Cannot export model from trax to TF
AndriCcos opened this issue
Description
I would like to export a Trax-trained model as a TensorFlow object so that I can serve it with TensorFlow Serving.
...
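Since the goal is TensorFlow Serving, here is a minimal sketch of what a client request could look like once the exported model is being served. It assumes TensorFlow Serving's REST predict endpoint (`POST /v1/models/<name>:predict`, default REST port 8501) with the row-format `instances` field. It also assumes a hypothetical model name `nmt_transformer`, padding token ID 0, and a fixed input length of 1024 to match the `shape=(1024,)` Keras input in the workaround at the end of this issue. It only builds the request with the standard library. It does not tokenize text or send anything.

```python
import json
import urllib.request

MAX_LEN = 1024  # assumed fixed input length of the exported Keras model
PAD_ID = 0      # assumption: token ID 0 is the padding token


def pad(token_ids, max_len=MAX_LEN, pad_id=PAD_ID):
    """Right-pad (or truncate) a sequence of token IDs to exactly max_len."""
    ids = list(token_ids)[:max_len]
    return ids + [pad_id] * (max_len - len(ids))


def build_predict_request(batch, model_name='nmt_transformer',
                          host='localhost', port=8501):
    """Build (but do not send) a TF Serving REST predict request.

    `model_name`, `host` and `port` are hypothetical defaults.
    """
    body = json.dumps({'instances': [pad(ids) for ids in batch]})
    url = f'http://{host}:{port}/v1/models/{model_name}:predict'
    return urllib.request.Request(
        url,
        data=body.encode('utf-8'),
        headers={'Content-Type': 'application/json'},
        method='POST',
    )


# Made-up token IDs; real ones would come from the training tokenizer.
req = build_predict_request([[17, 42, 5, 1]])
print(req.full_url)
# Sending it would be: urllib.request.urlopen(req).read()
```

The padding step exists only because the Keras `Input` in that workaround has a fixed length of 1024.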
Environment information
Google Colab
$ pip freeze | grep trax
trax==1.4.1
$ pip freeze | grep tensor
mesh-tensorflow==0.1.19
tensor2tensor==1.15.7
tensorboard==2.7.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.0
tensorflow @ file:///tensorflow-2.7.0-cp37-cp37m-linux_x86_64.whl
tensorflow-addons==0.15.0
tensorflow-datasets==4.0.1
tensorflow-estimator==2.7.0
tensorflow-gan==2.1.0
tensorflow-gcs-config==2.7.0
tensorflow-hub==0.12.0
tensorflow-io-gcs-filesystem==0.22.0
tensorflow-metadata==1.4.0
tensorflow-probability==0.7.0
tensorflow-text==2.7.3
$ pip freeze | grep jax
jax==0.2.25
jaxlib @ https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.1.74+cuda11.cudnn805-cp37-none-manylinux2010_x86_64.whl
$ python -V
Python 3.7.12
For bugs: reproduction and error logs
# Steps to reproduce:
I followed this guide to build a neural machine translation (NMT) model with a Transformer:
https://colab.research.google.com/github/OmarAlsaqa/trax/blob/master/trax/examples/NMT_with_Transformers_Reformers_using_Trax.ipynb
After training the model, I tried to convert it to Keras using this guide (section "2. Convert Trax to Keras"):
https://trax-ml.readthedocs.io/en/latest/notebooks/tf_numpy_and_keras.html#2.-Convert-Trax-to-Keras
However, calling the converted Keras layer on a Keras input raised the following error:
# Error logs:
---------------------------------------------------------------------------
StagingError Traceback (most recent call last)
<ipython-input-103-7c737325b026> in <module>()
1 # Create a full Keras model using the layer from Trax.
2 inputs = tf.keras.Input(shape=(None,), dtype='int32')
----> 3 hidden = keras_layer(inputs)
4 # You can add other Keras layers here operating on hidden.
5 outputs = hidden
/usr/local/lib/python3.7/dist-packages/tensorflow/python/autograph/impl/api.py in wrapper(*args, **kwargs)
697 except Exception as e: # pylint:disable=broad-except
698 if hasattr(e, 'ag_error_metadata'):
--> 699 raise e.ag_error_metadata.to_exception(e)
700 else:
701 raise
StagingError: Exception encountered when calling layer "as_keras_3" (type AsKeras).
in user code:
File "/usr/local/lib/python3.7/dist-packages/trax/trax2keras.py", line 184, in call *
outputs, new_state = self._trax_layer.pure_fn(inputs, weights=weights,
File "/usr/local/lib/python3.7/dist-packages/trax/layers/base.py", line 605, in pure_fn *
raise LayerError(name, 'pure_fn',
LayerError: Exception passing through layer Serial (in pure_fn):
layer created in file [...]/trax/models/transformer.py, line 390
layer input shapes: ShapeDtype{shape:(None, None), dtype:<class 'numpy.int32'>}
File [...]/autograph/operators/control_flow.py, line 1324, in if_stmt
_py_if_stmt(cond, body, orelse)
File [...]/autograph/operators/control_flow.py, line 1377, in _py_if_stmt
return body() if cond else orelse()
File [...]//tmp/__autograph_generated_file26d10zeu.py, line 61, in if_body_2
outputs = ag__.converted_call(ag__.ld(self).forward, (ag__.ld(x),), None, fscope)
File [...]/autograph/impl/api.py, line 447, in converted_call
result = converted_f(*effective_args)
File [...]//tmp/__autograph_generated_fileiyqcj2t8.py, line 11, in tf__forward
ag__.converted_call(ag__.ld(self)._validate_forward_inputs, (ag__.ld(xs),), None, fscope)
File [...]/autograph/impl/api.py, line 447, in converted_call
result = converted_f(*effective_args)
File [...]//tmp/__autograph_generated_files2ty70o0.py, line 20, in tf___validate_forward_inputs
ag__.if_stmt(ag__.and_((lambda : ag__.not_(ag__.converted_call(ag__.ld(isinstance), (ag__.ld(xs), (ag__.ld(tuple), ag__.ld(list))), None, fscope))), (lambda : (ag__.ld(self)._n_in != 1))), if_body, else_body, get_state, set_state, (), 0)
File [...]/autograph/operators/control_flow.py, line 1324, in if_stmt
_py_if_stmt(cond, body, orelse)
File [...]/autograph/operators/control_flow.py, line 1377, in _py_if_stmt
return body() if cond else orelse()
File [...]//tmp/__autograph_generated_files2ty70o0.py, line 16, in if_body
raise ag__.converted_call(ag__.ld(TypeError), (f'Serial.forward input must be a tuple or list; instead got {ag__.converted_call(ag__.ld(type), (ag__.ld(xs),), None, fscope)}.',), None, fscope)
TypeError: Serial.forward input must be a tuple or list; instead got <class 'tensorflow.python.framework.ops.Tensor'>.
Call arguments received:
• inputs=tf.Tensor(shape=(None, None), dtype=int32)
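The last frames of the traceback show the check that fails: if the input is not a tuple or list and the layer's `_n_in` is not 1, `Serial.forward` raises this `TypeError`. The simplified pure-Python re-creation below is based only on that condition as it appears in the autograph-generated code. `MiniSerial` and `FakeTensor` are hypothetical stand-ins, not the Trax API. It shows why a layer that expects more than one input rejects a single tensor but accepts a tuple:

```python
class FakeTensor:
    """Hypothetical placeholder for a symbolic tf.Tensor (neither tuple nor list)."""


class MiniSerial:
    """Hypothetical stand-in for a combinator layer that takes n_in inputs."""

    def __init__(self, n_in):
        self._n_in = n_in

    def _validate_forward_inputs(self, xs):
        # Same condition as in the traceback: a layer with more than one
        # input must receive its inputs packed in a tuple or list.
        if not isinstance(xs, (tuple, list)) and self._n_in != 1:
            raise TypeError(
                'Serial.forward input must be a tuple or list; '
                f'instead got {type(xs)}.')

    def forward(self, xs):
        self._validate_forward_inputs(xs)
        return xs  # a real layer would compute something here


layer = MiniSerial(n_in=2)  # e.g. a model taking (source, target)
x = FakeTensor()

try:
    layer.forward(x)  # a single tensor is rejected
except TypeError as err:
    print('rejected:', err)

print('accepted:', layer.forward((x, x)))  # a tuple of two tensors passes
```

In the working code at the end of this issue, `keras_layer((inputs, inputs))` plays the role of the tuple call.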
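Answering my own question: the export works with the following code.
Make sure to build the model with mode='eval'; otherwise it will not work as intended. Also note that the Keras layer must be called with a tuple of inputs rather than a single tensor, because the Transformer takes more than one input (source and target). That is what the TypeError above complains about.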
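import tensorflow as tf
import trax

# Rebuild the model with the training hyperparameters, but in 'eval' mode.
model = trax.models.Transformer(
    input_vocab_size=32768,
    d_model=512, d_ff=1024,
    n_heads=8, n_encoder_layers=6, n_decoder_layers=6,
    max_len=1024,
    mode='eval')
# model_folder is the directory that holds the training checkpoint.
model.init_from_file(model_folder + '/model.pkl.gz')

# Wrap the Trax model as a Keras layer with a fixed batch size.
keras_layer = trax.AsKeras(model, batch_size=1)
inputs = tf.keras.Input(shape=(1024,), dtype='int32')
# Call with a tuple: the Transformer expects more than one input.
hidden = keras_layer((inputs, inputs))
outputs = hidden
keras_model = tf.keras.Model(inputs=inputs, outputs=outputs)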
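The `mode='eval'` advice has at least one concrete basis. In Trax, mode-dependent layers such as Dropout only act when `mode='train'` (the default for `trax.models.Transformer`) and pass inputs through unchanged otherwise. This may not be the only reason the export needs eval mode. The pure-Python sketch below uses a hypothetical `dropout` helper (not Trax code) to show inverted dropout with a mode switch, and why a train-mode model would give random outputs at inference time:

```python
import random


def dropout(xs, rate, mode, rng):
    """Inverted dropout over a list of floats (illustrative sketch)."""
    if mode != 'train' or rate == 0.0:
        # Outside training, dropout is the identity.
        return list(xs)
    keep = 1.0 - rate
    # Zero each activation with probability `rate` and scale survivors by
    # 1/keep so that each output's expected value equals its input.
    return [x / keep if rng.random() < keep else 0.0 for x in xs]


activations = [1.0, 2.0, 3.0, 4.0]
rng = random.Random(0)
print(dropout(activations, rate=0.1, mode='eval', rng=rng))   # unchanged
print(dropout(activations, rate=0.1, mode='train', rng=rng))  # random mask
```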