google / trax

Trax — Deep Learning with Clear Code and Speed

nan CrossEntropyLossWithLogSoftmax while training an NMT Reformer

amrnablus opened this issue

Description

I'm training a Reformer-based NMT model. The code is nearly identical to https://github.com/google/trax/blob/283cbda9cb87f4a25a952d4c302aedfe54a65850/trax/examples/NMT_with_Transformers_Reformers_using_Trax.ipynb, except that it uses a custom dataset. The model itself looks like this:

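import trax

# Encoder-decoder Reformer; output_vocab_size is not set, so it defaults to the input vocab.
model = trax.models.Reformer(
    input_vocab_size=39901,
    d_model=512, d_ff=2048, dropout=0.1,
    n_heads=8, n_encoder_layers=6, n_decoder_layers=6,
    max_len=512, mode='train')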
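Because the setup swaps in a custom dataset, one thing worth ruling out is malformed examples. The sketch below is a hypothetical, plain-Python helper (`check_example` is not a Trax function) that flags token ids outside `input_vocab_size` and sequences longer than `max_len`. It assumes both inputs and targets are expected to fit within `max_len`, and it does not claim that either problem is the cause of the nan.

```python
# Hypothetical sanity check (not part of Trax): verify that one tokenized
# (input, target) pair fits the model configuration shown above.
INPUT_VOCAB_SIZE = 39901  # matches input_vocab_size in the model config
MAX_LEN = 512             # matches max_len in the model config


def check_example(input_ids, target_ids,
                  vocab_size=INPUT_VOCAB_SIZE, max_len=MAX_LEN):
    """Return a list of human-readable problems found in one example."""
    problems = []
    for name, ids in (("input", input_ids), ("target", target_ids)):
        # Assumption: sequences longer than max_len should not reach the model.
        if len(ids) > max_len:
            problems.append(f"{name} length {len(ids)} exceeds max_len {max_len}")
        # Token ids must index into the embedding table: 0 <= id < vocab_size.
        bad = [t for t in ids if not 0 <= t < vocab_size]
        if bad:
            problems.append(f"{name} has out-of-range token ids {bad[:5]}")
    return problems


print(check_example([5, 17, 1], [42, 1]))        # []
print(check_example([5, 39901, 1], [42] * 600))
```

Running such a check over the whole training stream (or a large sample of it) before a long run is cheap compared to 1000-step chunks that take about 20 minutes each.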

After training for roughly 10000 steps, the loss turns into nan:

Step 11000: Ran 1000 train steps in 1182.59 secs
Step 11000: train CrossEntropyLossWithLogSoftmax | nan
Step 11000: eval CrossEntropyLossWithLogSoftmax | nan
Step 11000: eval WeightedCategoryAccuracy | 0.00000000
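For context, here is a minimal plain-Python sketch of cross-entropy computed through a log-softmax. It illustrates the general technique only and is not Trax's implementation of `CrossEntropyLossWithLogSoftmax`. The log-sum-exp form stays finite even for very large logits, but once a logit is already `inf` or `nan`, the loss is `nan` no matter how it is computed. A nan loss together with 0.0 accuracy is therefore consistent with the model's outputs having gone non-finite (for example, diverging weights), though this sketch alone cannot confirm that.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax using the log-sum-exp trick."""
    # Subtracting the max keeps exp() from overflowing; a naive
    # exp(x) / sum(exp(x)) in float32 overflows to inf, and inf / inf is nan.
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def cross_entropy(logits, target):
    """Negative log-probability of the target class index."""
    return -log_softmax(logits)[target]


# Large but finite logits are handled fine by the stable form.
print(cross_entropy([1000.0, 0.0, -1000.0], 1))  # 1000.0

# Non-finite logits propagate to a nan loss regardless of stability tricks.
print(cross_entropy([float("nan"), 0.0], 1))     # nan
print(cross_entropy([float("inf"), 0.0], 0))     # nan
```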

Any idea what the reason might be?

...

Environment information

AWS g3.8xlarge (2 Tesla M60 GPUs) running Ubuntu 18

OS: Ubuntu 18.04

$ pip freeze | grep trax
trax==1.4.1


$ pip freeze | grep tensor
mesh-tensorflow==0.1.19
tensor2tensor==1.15.7
tensorboard==2.7.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.0
tensorflow==2.4.4
tensorflow-addons==0.14.0
tensorflow-datasets==4.4.0
tensorflow-estimator==2.4.0
tensorflow-gan==2.1.0
tensorflow-gpu==2.4.0
tensorflow-hub==0.12.0
tensorflow-metadata==1.4.0
tensorflow-probability==0.7.0
tensorflow-text==2.4.1


$ pip freeze | grep jax
jax==0.2.24
jaxlib @ https://storage.googleapis.com/jax-releases/cuda110/jaxlib-0.1.70+cuda110-cp37-none-manylinux2010_x86_64.whl


$ python -V
Python 3.7.12

For bugs: reproduction and error logs

# Steps to reproduce:
...
# Error logs:
...