google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home Page: https://flax.readthedocs.io


Non-working model when exporting to Hugging Face

peregilk opened this issue · comments

commented

I have trained a RoBERTa base Norwegian according to instructions given at https://github.com/huggingface/transformers/tree/master/examples/flax/language-modeling#masked-language-modeling.

The final mlm accuracy is 0.63, indicating a working model.

I am trying to load the model and export it to PyTorch (or TF) for use with the inference widget on Hugging Face.

The following code runs without errors (after fixing the method name to save_pretrained):

from transformers import AutoTokenizer, RobertaForMaskedLM
model = RobertaForMaskedLM.from_pretrained('model_dir', from_flax=True)
tokenizer = AutoTokenizer.from_pretrained('model_dir')
model.save_pretrained('.')
tokenizer.save_pretrained('.')

Example widget here: https://huggingface.co/pere/norwegian-roberta-base?text=Dette+er+en+%3Cmask%3E.

The outputs make absolutely no sense. What is the correct way of exporting a Flax model (with and without the MLM head)?
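As a minimal sanity-check sketch (not the thread's resolution), the export path with the MLM head goes through RobertaForMaskedLM, and the head-less encoder can be loaded from the same checkpoint via RobertaModel. A tiny randomly initialized config is used here so nothing needs to be downloaded; the config sizes are arbitrary illustration values.

```python
import tempfile

import torch
from transformers import RobertaConfig, RobertaForMaskedLM, RobertaModel

# Tiny illustrative config (arbitrary sizes) so no pretrained weights are fetched.
config = RobertaConfig(
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=64,
)

# With the MLM head: RobertaForMaskedLM carries the language-modeling head.
mlm_model = RobertaForMaskedLM(config)

with tempfile.TemporaryDirectory() as export_dir:
    # save_pretrained (not save_frompretrained) writes config.json plus the weights.
    mlm_model.save_pretrained(export_dir)
    reloaded = RobertaForMaskedLM.from_pretrained(export_dir)

    # Without the MLM head: the same checkpoint loaded as a bare encoder;
    # transformers drops the head weights with a warning.
    encoder_only = RobertaModel.from_pretrained(export_dir)

# The save/load round trip should preserve the weights exactly.
same_embeddings = torch.equal(
    mlm_model.roberta.embeddings.word_embeddings.weight,
    reloaded.roberta.embeddings.word_embeddings.weight,
)
```

If the round trip preserves the weights but the widget output is still garbage, the mismatch is more likely in the tokenizer files or in the Flax-to-PyTorch conversion step than in save_pretrained itself.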

Hi @peregilk, can you please file this against the huggingface transformers repo?

commented

Of course. For reference, here is the link to the new post: huggingface/transformers#12506