google / trax

Trax — Deep Learning with Clear Code and Speed

ValueError on predict mode Transformer model

shashank2123 opened this issue

Description

I get the following error when I load the model in predict mode. It works correctly in eval mode.

ValueError: Incompatible shapes for matmul arguments: (8, 1, 64) and (256, 64, 2048)

Model definition:

import trax

mode = 'predict'  # 'eval' works; 'predict' raises the error above
model = trax.models.Transformer(
    input_vocab_size=33600,
    d_model=512, d_ff=2048, dropout=0.1,
    n_heads=8, n_encoder_layers=6, n_decoder_layers=6,
    max_len=2048, mode=mode)

...
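With the configuration above (d_model=512, n_heads=8, max_len=2048), the per-head attention shapes follow from simple arithmetic. The sketch below assumes, as the (8, 1, 64) query shape in the error log suggests, that Trax folds the heads into the leading axis, and that in predict mode the key/value cache is preallocated for max_len positions (the 2048 in the log is consistent with this). `attention_shapes` is a hypothetical helper, not part of Trax.

```python
def attention_shapes(batch, seq_len, d_model, n_heads, max_len):
    """Return (query_shape, kv_cache_shape) under the heads-into-batch assumption.

    Assumption: per-head tensors have shape (batch * n_heads, length, d_model // n_heads).
    """
    if d_model % n_heads != 0:
        raise ValueError("d_model must be divisible by n_heads")
    d_head = d_model // n_heads
    # One decoding step processes seq_len new positions per sequence.
    query_shape = (batch * n_heads, seq_len, d_head)
    # Assumed predict-mode cache: room for max_len positions per sequence and head.
    kv_cache_shape = (batch * n_heads, max_len, d_head)
    return query_shape, kv_cache_shape


# Config from this issue, decoding one token at a time for a single sentence.
q, kv = attention_shapes(batch=1, seq_len=1, d_model=512, n_heads=8, max_len=2048)
print(q, kv)  # (8, 1, 64) (8, 2048, 64)
```

For one sentence, the query shape matches the (8, 1, 64) in the log, so any key cache whose leading axis is not 8 points to a cache built for a different batch size.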

Environment information

OS: Google Colab

$ pip freeze | grep trax
trax                          1.3.9  

$ pip freeze | grep tensor
mesh-tensorflow==0.1.19
tensorboard==2.5.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.0
tensorflow==2.5.0
tensorflow-datasets==4.0.1
tensorflow-estimator==2.5.0
tensorflow-gcs-config==2.5.0
tensorflow-hub==0.12.0
tensorflow-metadata==1.1.0
tensorflow-probability==0.13.0
tensorflow-text==2.5.0

$ pip freeze | grep jax
jax==0.2.17
jaxlib==0.1.69+cuda110

$ python -V
Python 3.7

For bugs: reproduction and error logs

Steps to reproduce:

...

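# tokenize, detokenize and next_symbol are helper functions defined elsewhere
# in the notebook (not shown in this issue).
def sampling_decode(input_sentence, model=None, temperature=0.0, vocab_file=None, vocab_dir=None):
    # Encode the source sentence into token ids.
    input_tokens = tokenize(input_sentence, vocab_file=vocab_file, vocab_dir=vocab_dir)

    cur_output_tokens = []  # tokens generated so far
    cur_output = 0          # initial value; anything other than EOS enters the loop
    EOS = 1                 # end-of-sentence token id

    # Generate one token per iteration until EOS is produced.
    while cur_output != EOS:
        cur_output, log_prob = next_symbol(model, input_tokens, cur_output_tokens, temperature)
        cur_output_tokens.append(cur_output)

    sentence = detokenize(cur_output_tokens, vocab_file=vocab_file, vocab_dir=vocab_dir)

    return cur_output_tokens, log_prob, sentence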

import random

# eval_data, eval_model, vocab_file and vocab_dir are defined earlier in the notebook (not shown).
eval_point = random.choice(eval_data)

incorrect_sentence = eval_point[0]
correct_sentence = eval_point[1]

print("Incorrect sentence :- ", incorrect_sentence)
print("Correct sentence :- ", correct_sentence)

pred_token, log_prob, pred_sentence = sampling_decode(
    incorrect_sentence, eval_model, temperature=0.0, vocab_file=vocab_file, vocab_dir=vocab_dir)

print("Predicted sentence :- ", pred_sentence)
print("Predicted tokens :- ", pred_token)
print("log_prob :- ", log_prob)
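The helpers `tokenize`, `next_symbol` and `detokenize` are not shown, so it is only an assumption that `next_symbol` re-runs the model on the whole output prefix at every step. That style fits eval mode; a predict-mode model instead keeps a key/value cache and expects one new token per call. Below is a minimal stdlib sketch of that one-token-per-step greedy loop. The `step` callable is a hypothetical stand-in for a stateful predict-mode model, and the `max_length` bound is an addition (the `while` loop in `sampling_decode` stops only at EOS).

```python
from typing import Callable, List, Tuple


def incremental_greedy_decode(
    step: Callable[[int], Tuple[int, float]],
    start_id: int = 0,
    eos_id: int = 1,
    max_length: int = 100,
) -> Tuple[List[int], float]:
    """Greedy decoding where each call to `step` receives only the newest token.

    `step` is hypothetical: it keeps its own state (like a predict-mode cache)
    and returns (next_token, log_prob_of_that_token).
    """
    tokens: List[int] = []
    total_log_prob = 0.0
    prev = start_id
    # Bounded loop so a model that never emits EOS cannot hang the caller.
    for _ in range(max_length):
        prev, log_prob = step(prev)
        tokens.append(prev)
        total_log_prob += log_prob
        if prev == eos_id:
            break
    return tokens, total_log_prob


def make_toy_step(script):
    """Toy stand-in for a model: emits a fixed token script, ignoring its input."""
    it = iter(script)

    def step(_last_token):
        return next(it), -0.5

    return step


tokens, lp = incremental_greedy_decode(make_toy_step([5, 7, 1]))
print(tokens, lp)  # [5, 7, 1] -1.5
```

Unlike `sampling_decode`, this sketch sums the per-step log probabilities instead of returning only the last one. In Trax itself, `trax.supervised.decoding.autoregressive_sample` runs this kind of loop for predict-mode models; check its arguments against the installed version.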

Error logs:

...
LayerError: Exception passing through layer Serial (in pure_fn):
layer created in file [...]/trax/models/transformer.py, line 390
layer input shapes: (ShapeDtype{shape:(1, 204), dtype:int64}, ShapeDtype{shape:(1, 1), dtype:int64})

File [...]/trax/layers/combinators.py, line 88, in forward
outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)

LayerError: Exception passing through layer Serial (in pure_fn):
layer created in file [...]/trax/models/transformer.py, line 566
layer input shapes: ShapeDtype{shape:(1, 1, 512), dtype:float32}

File [...]/trax/layers/combinators.py, line 88, in forward
outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)

LayerError: Exception passing through layer Branch (in pure_fn):
layer created in file [...]/trax/models/transformer.py, line 566
layer input shapes: ShapeDtype{shape:(1, 1, 512), dtype:float32}

File [...]/trax/layers/combinators.py, line 88, in forward
outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)

LayerError: Exception passing through layer Parallel (in pure_fn):
layer created in file [...]/trax/models/transformer.py, line 566
layer input shapes: (ShapeDtype{shape:(1, 1, 512), dtype:float32}, ShapeDtype{shape:(1, 1, 512), dtype:float32})

File [...]/trax/layers/combinators.py, line 211, in forward
sub_outputs, sub_state = layer.pure_fn(x, w, s, r, use_cache=True)

LayerError: Exception passing through layer Serial (in pure_fn):
layer created in file [...]/trax/models/transformer.py, line 566
layer input shapes: ShapeDtype{shape:(1, 1, 512), dtype:float32}

File [...]/trax/layers/combinators.py, line 88, in forward
outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)

LayerError: Exception passing through layer Serial (in pure_fn):
layer created in file [...]/trax/models/transformer.py, line 556
layer input shapes: ShapeDtype{shape:(1, 1, 512), dtype:float32}

File [...]/trax/layers/combinators.py, line 88, in forward
outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)

LayerError: Exception passing through layer Serial (in pure_fn):
layer created in file [...]/trax/models/transformer.py, line 556
layer input shapes: ShapeDtype{shape:(1, 1, 512), dtype:float32}

File [...]/trax/layers/combinators.py, line 88, in forward
outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)

LayerError: Exception passing through layer Serial (in pure_fn):
layer created in file [...]/trax/models/transformer.py, line 556
layer input shapes: ShapeDtype{shape:(1, 1, 512), dtype:float32}

File [...]/trax/layers/combinators.py, line 88, in forward
outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)

LayerError: Exception passing through layer DotProductCausalAttention (in pure_fn):
layer created in file [...]/trax/models/transformer.py, line 556
layer input shapes: (ShapeDtype{shape:(8, 1, 64), dtype:float32}, ShapeDtype{shape:(8, 1, 64), dtype:float32}, ShapeDtype{shape:(8, 1, 64), dtype:float32})

File [...]/trax/layers/assert_shape.py, line 122, in forward_wrapper
y = forward(self, x, *args, **kwargs)

File [...]/trax/layers/attention.py, line 520, in forward
q, k, v, mask, dropout=self._dropout, mode=self._mode, rng=self.rng)

File [...]/trax/layers/attention.py, line 281, in _per_head_attention
dots = jnp.matmul(queries, jnp.swapaxes(keys, -1, -2)) / jnp.sqrt(d_feature)

File [...]/_src/numpy/lax_numpy.py, line 4211, in matmul
.format(shape(a), shape(b)))

ValueError: Incompatible shapes for matmul arguments: (8, 1, 64) and (256, 64, 2048)
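The two shapes in the final error can be decoded directly. Assuming Trax folds the 8 heads into the leading axis, the query (8, 1, 64) is 1 sequence × 8 heads, one position, and d_head = 512 / 8 = 64. The transposed keys (256, 64, 2048) correspond to 256 / 8 = 32 sequences and 2048 cached positions (= max_len). In other words, the decoder's key/value cache appears to have been allocated for a batch of 32, not for the single sentence being decoded. `diagnose_matmul` is a hypothetical helper that does this arithmetic:

```python
def diagnose_matmul(q_shape, kT_shape, n_heads):
    """Explain a clash between queries (B, Lq, d) and transposed keys (B', d, Lk).

    Assumes the leading axis of both is batch * n_heads (heads folded into batch).
    """
    # A leading axis of size 1 would broadcast; otherwise the sizes must match.
    lead_ok = q_shape[0] == kT_shape[0] or 1 in (q_shape[0], kT_shape[0])
    # The contracted axis: last dim of q must equal second-to-last dim of k^T.
    inner_ok = q_shape[-1] == kT_shape[-2]
    return {
        "query_batch": q_shape[0] // n_heads,
        "key_batch": kT_shape[0] // n_heads,
        "cache_length": kT_shape[-1],
        "compatible": lead_ok and inner_ok,
    }


print(diagnose_matmul((8, 1, 64), (256, 64, 2048), n_heads=8))
# {'query_batch': 1, 'key_batch': 32, 'cache_length': 2048, 'compatible': False}
```

One plausible source of a batch-32 cache is model state restored together with the weights from a training-time checkpoint. The predict-mode example in the Trax README loads its checkpoint with `model.init_from_file(..., weights_only=True)`, so that is a reasonable first thing to try. This is a hypothesis drawn from the shapes, not a confirmed fix.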

I'm also having exactly the same problem. My configuration:

trax 1.3.8
mesh-tensorflow 0.1.19
tensor2tensor 1.15.7
tensorboard 2.6.0
tensorboard-data-server 0.6.1
tensorboard-plugin-wit 1.8.0
tensorflow 2.5.0
tensorflow-addons 0.14.0
tensorflow-datasets 4.4.0
tensorflow-estimator 2.5.0
tensorflow-gan 2.1.0
tensorflow-hub 0.12.0
tensorflow-metadata 1.2.0
tensorflow-probability 0.7.0
tensorflow-text 2.5.0
jax 0.2.21
jaxlib 0.1.71+cuda111

OS: Ubuntu 20.04.3 LTS

Python version: 3.8.12
Have you found a solution?