Avoid inlining large arrays in JaxInMemoryRandomSampleIterator
ethanluoyc opened this issue
Yicheng Luo commented
The JaxInMemoryRandomSampleIterator currently inlines the in-memory dataset. See
https://github.com/deepmind/acme/blob/master/acme/datasets/tfds.py#L199-L200
This causes out-of-memory (OOM) errors because of an issue in XLA, and when running on GPU the process may also hang. I have filed a more detailed issue in the JAX project (google/jax#14080), and the JAX authors recommend not inlining the array. I can create a PR if the developers would like to fix this.
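A minimal sketch of the difference, assuming the dataset is a pytree of arrays that share a leading dimension and that batches are sampled uniformly with replacement. Closing over a concrete array inside `jax.jit` embeds it in the compiled program as a constant. Passing it as an argument keeps it as a runtime input buffer instead. The names `make_inlined_sampler`, `sample_batch`, and `InMemoryRandomSampleIterator` are hypothetical and do not reproduce the Acme code at the linked lines.

```python
import functools

import jax
import numpy as np


def make_inlined_sampler(dataset, batch_size):
    """Problematic pattern: `dataset` is closed over, so jit bakes it in as a constant."""
    dataset = jax.device_put(dataset)
    size = jax.tree_util.tree_leaves(dataset)[0].shape[0]

    @jax.jit
    def sample(key):
        idx = jax.random.randint(key, (batch_size,), 0, size)
        return jax.tree_util.tree_map(lambda x: x[idx], dataset)

    return sample


@functools.partial(jax.jit, static_argnames="batch_size")
def sample_batch(dataset, key, batch_size):
    """Recommended pattern: the dataset is an explicit (non-static) argument."""
    size = jax.tree_util.tree_leaves(dataset)[0].shape[0]
    idx = jax.random.randint(key, (batch_size,), 0, size)
    return jax.tree_util.tree_map(lambda x: x[idx], dataset)


class InMemoryRandomSampleIterator:
    """Hypothetical iterator that passes the data to the jitted sampler as an argument."""

    def __init__(self, dataset, batch_size, seed):
        # Transfer the data to the device once; it is never captured by the jitted function.
        self._dataset = jax.device_put(dataset)
        self._batch_size = batch_size
        self._key = jax.random.PRNGKey(seed)

    def __iter__(self):
        return self

    def __next__(self):
        self._key, subkey = jax.random.split(self._key)
        return sample_batch(self._dataset, subkey, batch_size=self._batch_size)


if __name__ == "__main__":
    data = {
        "obs": np.arange(20, dtype=np.float32).reshape(10, 2),
        "act": np.arange(10, dtype=np.int32),
    }
    batch = next(InMemoryRandomSampleIterator(data, batch_size=4, seed=0))
    print({k: v.shape for k, v in batch.items()})
```

Both samplers draw the same indices for the same key. The only difference is whether the array is part of the compiled program or an input to it, and the argument-passing version also lets one compiled function serve any dataset with the same shapes and dtypes.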