google / trax

Trax — Deep Learning with Clear Code and Speed

RAM crash while optimizing a linear regression model

aycandv opened this issue

Description

I was running an optimization to test memory usage, using a model made of 10000 Dense(2) layers in series. In general, the time spent in step 1 grows as the model becomes more complex or the batch size increases.
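To put the model size in perspective, here is a rough estimate of how much memory its tensors need. It is only a sketch: the input width of 2, the batch size of 32, and an Adam-style optimizer with two state slots per parameter are assumptions, not details from this report, and `estimate_bytes` is a hypothetical helper written for illustration.

```python
# Rough memory estimate for a stack of 10,000 Dense(2) layers.
# Everything marked ASSUMPTION is not stated in the issue.

BYTES_PER_FLOAT32 = 4


def dense_param_count(n_in: int, n_out: int) -> int:
    """A dense layer holds an (n_in x n_out) weight matrix plus an n_out bias."""
    return n_in * n_out + n_out


def estimate_bytes(n_layers: int, width: int, n_in: int,
                   batch_size: int, optimizer_slots: int) -> dict:
    """Hypothetical helper: bytes for parameters and stored activations."""
    params = (dense_param_count(n_in, width)
              + (n_layers - 1) * dense_param_count(width, width))
    # Parameters, their gradients, and per-parameter optimizer state.
    param_bytes = params * (2 + optimizer_slots) * BYTES_PER_FLOAT32
    # One (batch_size x width) output per layer kept for backprop.
    activation_bytes = n_layers * batch_size * width * BYTES_PER_FLOAT32
    return {"params": params,
            "param_bytes": param_bytes,
            "activation_bytes": activation_bytes}


est = estimate_bytes(
    n_layers=10_000,
    width=2,
    n_in=2,             # ASSUMPTION: 2 input features
    batch_size=32,      # ASSUMPTION: batch size
    optimizer_slots=2,  # ASSUMPTION: Adam-style optimizer (two slots per parameter)
)
print(est["params"])                                # 60000
print(round(est["param_bytes"] / 2**20, 2))         # 0.92 (MiB)
print(round(est["activation_bytes"] / 2**20, 2))    # 2.44 (MiB)
```

Under these assumptions the parameters, gradients, optimizer state, and stored activations add up to only a few MiB, far below the gigabytes of RAM a Colab runtime provides. That suggests the growth in the graph does not come from these tensors alone, although the estimate cannot show where the memory actually goes.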

For this model, the RAM usage graph looked like this:
[Screenshot: Colab RAM usage graph]

When I took this screenshot, step 1 had not completed yet.

I tried the same task in PyTorch, and RAM usage was not a problem there. Am I missing something in Trax, or is this the expected behavior?
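One possible factor, not confirmed by this report: JAX compiles a jitted function on its first call, and a Python loop over layers is unrolled into the traced program, so the program to be compiled grows with depth. The plain-JAX sketch below is not Trax's code; it counts traced operations for an unrolled stack of dense layers versus `jax.lax.scan` over stacked weights. The function names are hypothetical, and `lax.scan` is shown only for contrast, not as a Trax feature.

```python
import jax
import jax.numpy as jnp


def unrolled_forward(weights, biases, x):
    # The Python loop runs at trace time, so every layer adds its own
    # indexing, matmul, and add operations to the traced program.
    for i in range(weights.shape[0]):
        x = x @ weights[i] + biases[i]
    return x


def scanned_forward(weights, biases, x):
    # lax.scan traces the layer body once and loops over the stacked
    # parameters, so the top-level program does not grow with depth.
    def step(h, wb):
        w, b = wb
        return h @ w + b, None
    out, _ = jax.lax.scan(step, x, (weights, biases))
    return out


def n_eqns(fn, n_layers, width=2, batch=4):
    """Count top-level equations in the jaxpr traced from fn."""
    weights = jnp.zeros((n_layers, width, width))
    biases = jnp.zeros((n_layers, width))
    x = jnp.ones((batch, width))
    return len(jax.make_jaxpr(fn)(weights, biases, x).jaxpr.eqns)


# The unrolled count grows with depth; the scanned count stays the same.
print(n_eqns(unrolled_forward, 100), n_eqns(unrolled_forward, 200))
print(n_eqns(scanned_forward, 100), n_eqns(scanned_forward, 200))
```

This only illustrates why a very deep stack could make the first (compiling) step slow and memory-hungry; whether it explains the RAM curve above would need to be checked against Trax itself.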

Environment information

OS: Colab

$ pip freeze | grep trax
trax @ git+https://github.com/google/trax@6151599222f7f96d9be7336a78c8f596fe509059


$ pip freeze | grep tensor
mesh-tensorflow==0.1.19
tensorboard==2.6.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.0
tensorflow @ file:///tensorflow-2.6.0-cp37-cp37m-linux_x86_64.whl
tensorflow-datasets==4.0.1
tensorflow-estimator==2.6.0
tensorflow-gcs-config==2.6.0
tensorflow-hub==0.12.0
tensorflow-metadata==1.2.0
tensorflow-probability==0.13.0
tensorflow-text==2.6.0

$ pip freeze | grep jax
jax==0.2.19
jaxlib @ https://storage.googleapis.com/jax-releases/cuda110/jaxlib-0.1.70+cuda110-cp37-none-manylinux2010_x86_64.whl

$ python -V
Python 3.7.11