google / trax

Trax — Deep Learning with Clear Code and Speed

Cannot export model from trax to TF

AndriCcos opened this issue.

Description

I would like to export a Trax-trained model as a TensorFlow object so that I can serve it with TensorFlow Serving.
...

Environment information

Google Colab

$ pip freeze | grep trax
trax==1.4.1

$ pip freeze | grep tensor
mesh-tensorflow==0.1.19
tensor2tensor==1.15.7
tensorboard==2.7.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.0
tensorflow @ file:///tensorflow-2.7.0-cp37-cp37m-linux_x86_64.whl
tensorflow-addons==0.15.0
tensorflow-datasets==4.0.1
tensorflow-estimator==2.7.0
tensorflow-gan==2.1.0
tensorflow-gcs-config==2.7.0
tensorflow-hub==0.12.0
tensorflow-io-gcs-filesystem==0.22.0
tensorflow-metadata==1.4.0
tensorflow-probability==0.7.0
tensorflow-text==2.7.3

$ pip freeze | grep jax
jax==0.2.25
jaxlib @ https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.1.74+cuda11.cudnn805-cp37-none-manylinux2010_x86_64.whl

$ python -V
Python 3.7.12

For bugs: reproduction and error logs

# Steps to reproduce:
I followed this guide to build a neural machine translation (NMT) model with a Transformer:
https://colab.research.google.com/github/OmarAlsaqa/trax/blob/master/trax/examples/NMT_with_Transformers_Reformers_using_Trax.ipynb

After training the model, I tried to convert it to Keras and save it by following this guide:
https://trax-ml.readthedocs.io/en/latest/notebooks/tf_numpy_and_keras.html#2.-Convert-Trax-to-Keras

However, the conversion failed with the following error:

# Error logs:

---------------------------------------------------------------------------
StagingError                              Traceback (most recent call last)
<ipython-input-103-7c737325b026> in <module>()
      1 # Create a full Keras  model using the layer from Trax.
      2 inputs = tf.keras.Input(shape=(None,), dtype='int32')
----> 3 hidden = keras_layer(inputs)
      4 # You can add other Keras layers here operating on hidden.
      5 outputs = hidden

1 frames
/usr/local/lib/python3.7/dist-packages/tensorflow/python/autograph/impl/api.py in wrapper(*args, **kwargs)
    697       except Exception as e:  # pylint:disable=broad-except
    698         if hasattr(e, 'ag_error_metadata'):
--> 699           raise e.ag_error_metadata.to_exception(e)
    700         else:
    701           raise

StagingError: Exception encountered when calling layer "as_keras_3" (type AsKeras).

in user code:

    File "/usr/local/lib/python3.7/dist-packages/trax/trax2keras.py", line 184, in call  *
        outputs, new_state = self._trax_layer.pure_fn(inputs, weights=weights,
    File "/usr/local/lib/python3.7/dist-packages/trax/layers/base.py", line 605, in pure_fn  *
        raise LayerError(name, 'pure_fn',

    LayerError: Exception passing through layer Serial (in pure_fn):
      layer created in file [...]/trax/models/transformer.py, line 390
      layer input shapes: ShapeDtype{shape:(None, None), dtype:<class 'numpy.int32'>}
    
      File [...]/autograph/operators/control_flow.py, line 1324, in if_stmt
        _py_if_stmt(cond, body, orelse)
    
      File [...]/autograph/operators/control_flow.py, line 1377, in _py_if_stmt
        return body() if cond else orelse()
    
      File [...]//tmp/__autograph_generated_file26d10zeu.py, line 61, in if_body_2
        outputs = ag__.converted_call(ag__.ld(self).forward, (ag__.ld(x),), None, fscope)
    
      File [...]/autograph/impl/api.py, line 447, in converted_call
        result = converted_f(*effective_args)
    
      File [...]//tmp/__autograph_generated_fileiyqcj2t8.py, line 11, in tf__forward
        ag__.converted_call(ag__.ld(self)._validate_forward_inputs, (ag__.ld(xs),), None, fscope)
    
      File [...]/autograph/impl/api.py, line 447, in converted_call
        result = converted_f(*effective_args)
    
      File [...]//tmp/__autograph_generated_files2ty70o0.py, line 20, in tf___validate_forward_inputs
        ag__.if_stmt(ag__.and_((lambda : ag__.not_(ag__.converted_call(ag__.ld(isinstance), (ag__.ld(xs), (ag__.ld(tuple), ag__.ld(list))), None, fscope))), (lambda : (ag__.ld(self)._n_in != 1))), if_body, else_body, get_state, set_state, (), 0)
    
      File [...]/autograph/operators/control_flow.py, line 1324, in if_stmt
        _py_if_stmt(cond, body, orelse)
    
      File [...]/autograph/operators/control_flow.py, line 1377, in _py_if_stmt
        return body() if cond else orelse()
    
      File [...]//tmp/__autograph_generated_files2ty70o0.py, line 16, in if_body
        raise ag__.converted_call(ag__.ld(TypeError), (f'Serial.forward input must be a tuple or list; instead got {ag__.converted_call(ag__.ld(type), (ag__.ld(xs),), None, fscope)}.',), None, fscope)
    
    TypeError: Serial.forward input must be a tuple or list; instead got <class 'tensorflow.python.framework.ops.Tensor'>.


Call arguments received:
  • inputs=tf.Tensor(shape=(None, None), dtype=int32)

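Answering my own question: the export works with the code below. The key change is that the Transformer's top-level Serial layer expects a tuple of inputs (source tokens and target tokens) rather than a single tensor, which is exactly what the TypeError above complains about, so the Keras layer must be called with (inputs, inputs).
Also make sure to build the model with mode='eval'. In the default 'train' mode, layers such as dropout stay active, and the exported model will not behave as intended.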

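import tensorflow as tf
import trax

# Rebuild the Transformer with the same hyperparameters used for training,
# but in 'eval' mode so that dropout is disabled.
model = trax.models.Transformer(
    input_vocab_size=32768,
    d_model=512, d_ff=1024,
    n_heads=8, n_encoder_layers=6, n_decoder_layers=6,
    max_len=1024,
    mode='eval')

# Load the trained checkpoint; model_folder is the training output
# directory defined earlier in the notebook.
model.init_from_file(model_folder + '/model.pkl.gz')

# Wrap the Trax model as a Keras layer (batch_size=1 as in the working setup).
keras_layer = trax.AsKeras(model, batch_size=1)
inputs = tf.keras.Input(shape=(1024,), dtype='int32')
# The Transformer takes (source, target) tokens, so pass a tuple,
# not a single tensor; a bare tensor triggers the TypeError above.
hidden = keras_layer((inputs, inputs))

outputs = hidden
keras_model = tf.keras.Model(inputs=inputs, outputs=outputs)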
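Why the tuple call fixes the error: the log shows Serial._validate_forward_inputs raising the TypeError when the input is not a tuple or list and the layer's _n_in is not 1. Below is a simplified, hypothetical re-creation of just that check in plain Python. SerialSketch and FakeTensor are made-up names, not trax classes, and this is not trax's implementation; it also assumes the Transformer's top-level Serial takes two inputs (source and target tokens).

```python
# Simplified, HYPOTHETICAL re-creation of the input check seen in the
# traceback (Serial._validate_forward_inputs). Not trax's actual code.


class FakeTensor:
    """Hypothetical stand-in for a single tf.Tensor."""


class SerialSketch:
    """Hypothetical layer that only mimics Serial's input-arity check."""

    def __init__(self, n_in):
        self._n_in = n_in  # how many inputs the combinator expects

    def validate_forward_inputs(self, xs):
        # Condition taken from the log: a bare (non-tuple/list) input is
        # rejected unless the layer expects exactly one input.
        if not isinstance(xs, (tuple, list)) and self._n_in != 1:
            raise TypeError(
                'Serial.forward input must be a tuple or list; '
                f'instead got {type(xs)}.')


# Assumption: the encoder-decoder Transformer takes (source, target) tokens.
transformer_top = SerialSketch(n_in=2)

try:
    # Like keras_layer(inputs): one tensor, so the check fails.
    transformer_top.validate_forward_inputs(FakeTensor())
except TypeError as err:
    print(err)

# Like keras_layer((inputs, inputs)): a tuple, so the check passes.
transformer_top.validate_forward_inputs((FakeTensor(), FakeTensor()))
```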