google-deepmind / acme

A library of reinforcement learning components and agents

Avoid inlining large arrays in JaxInMemoryRandomSampleIterator

ethanluoyc opened this issue

The JaxInMemoryRandomSampleIterator currently inlines the in-memory dataset. See
https://github.com/deepmind/acme/blob/master/acme/datasets/tfds.py#L199-L200

This causes out-of-memory (OOM) errors because of issues in XLA, and when running on GPU the process may also hang. I have filed a more detailed issue in the JAX project (google/jax#14080), and the JAX developers recommend not inlining the array. I can create a PR if the maintainers would like this fixed.
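For readers unfamiliar with the distinction, here is a minimal, hypothetical sketch of the two patterns. It is not the acme code: the dataset layout, `DATASET_SIZE`, `BATCH_SIZE`, `_sample`, `make_inlined_sampler`, and `sample_with_arg` are all illustrative names. In the first pattern the arrays are captured by a closure, so `jax.jit` treats them as constants of the traced computation (in the JAX versions discussed in the linked issue, these end up embedded in the compiled XLA program). In the second pattern the same arrays are passed as a jitted function argument, so they reach XLA as runtime inputs.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical stand-in for an in-memory dataset: a pytree of arrays that
# share a leading (example) axis. Sizes are arbitrary.
DATASET_SIZE = 10_000
BATCH_SIZE = 32
rng = np.random.default_rng(0)
dataset = {
    "observation": jnp.asarray(rng.random((DATASET_SIZE, 8), dtype=np.float32)),
    "reward": jnp.asarray(rng.random(DATASET_SIZE, dtype=np.float32)),
}


def _sample(data, key):
    """Draws BATCH_SIZE examples uniformly at random, with replacement."""
    key1, key2 = jax.random.split(key)
    idx = jax.random.randint(key1, (BATCH_SIZE,), 0, DATASET_SIZE)
    batch = jax.tree_util.tree_map(lambda x: jnp.take(x, idx, axis=0), data)
    return batch, key2


# Inlined pattern: `data` is captured by the closure, so jax.jit sees it as
# a constant of the traced computation rather than as an input.
def make_inlined_sampler(data):
    return jax.jit(lambda key: _sample(data, key))


# Non-inlined pattern: the dataset is an explicit argument, so it is passed
# to the compiled function as an input buffer on every call.
sample_with_arg = jax.jit(_sample)

# Usage: both samplers draw the same batch for the same key.
key = jax.random.PRNGKey(0)
inlined = make_inlined_sampler(dataset)
batch_a, _ = inlined(key)
batch_b, next_key = sample_with_arg(dataset, key)
```

Both functions compute the same batch for the same key; only the way the dataset arrays are handed to XLA differs. With the argument-passing form the compiled function depends only on the shapes and dtypes of the inputs, not on their contents.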