RobertTLange / gymnax

RL Environments in JAX 🌍

env step accumulates memory

ManuelEberhardinger opened this issue

Hi Robert,

Thanks for this awesome library!

I use the gymnax library on CPU to collect data for the Breakout MinAtar environment. I generate thousands of random programs and want to execute each of them in the environment. Somehow memory accumulates over time until I run into RAM problems. Using the Python memory profiler, I found that the env's step function adds about 10 MB after each call. Do you know why that is the case? Could this happen only when running JAX on CPU?

I had problems getting JAX and PyTorch to run in the same virtual environment with CUDA, so I decided to run gymnax on the CPU to avoid CUDA issues. The memory is also not released in the next iteration of the loop or at the end of the function.
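For reference, here is a minimal sketch of forcing JAX onto the CPU backend. It assumes a recent JAX release that honors the `JAX_PLATFORMS` environment variable, which must be set before `jax` is first imported. This is not necessarily how the setup described above was configured.

```python
import os

# Must be set before jax is imported; "cpu" restricts JAX to the CPU backend.
os.environ["JAX_PLATFORMS"] = "cpu"

import jax
import jax.numpy as jnp

# Confirm which backend JAX selected and list the visible devices.
backend = jax.default_backend()
x = jnp.ones((2, 2))
print(backend, jax.devices())
```

Setting the variable in the shell (`JAX_PLATFORMS=cpu python script.py`) has the same effect and keeps the choice out of the code.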

I used the code from the visualization notebook as a reference.

Thanks a lot for your answer!

Best wishes,
Manuel

```
Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    75    899.3 MiB    899.3 MiB           1   @profile
    76                                         def rollout_episode(env_data, env_params, seq_length, program):
    77    899.3 MiB      0.0 MiB           1       env, obs, env_state = env_data
    78    899.3 MiB      0.0 MiB           1       rng = jax.random.PRNGKey(0)
    79    899.3 MiB      0.0 MiB           1       examples = []
    80    908.6 MiB      0.0 MiB           6       for _ in range(seq_length):
    81    908.6 MiB      0.0 MiB           6           rng, rng_act, rng_step = jax.random.split(rng, 3)
    82    908.6 MiB      0.0 MiB           6           obs_inp = convert_to_task_input(obs)
    83    908.6 MiB      0.0 MiB           6           input_ex = (obs_inp,)
    84                                                 # output ex is the action
    85    908.6 MiB      0.0 MiB          18           output_ex = runWithTimeout(lambda: program.runWithArguments(input_ex), None)
    86    908.6 MiB      0.0 MiB           6           examples.append((input_ex, output_ex))
    87                                         
    88    908.6 MiB      9.3 MiB          12           next_obs, next_env_state, reward, done, info = env.step(
    89    908.6 MiB      0.0 MiB           6               rng_step, env_state, output_ex, env_params
    90                                                 )
    91                                         
    92    908.6 MiB      0.0 MiB           6           if done:  # or t_counter == max_frames:
    93    908.6 MiB      0.0 MiB           1               break
    94                                                 else:
    95    908.6 MiB      0.0 MiB           5               env_state = next_env_state
    96    908.6 MiB      0.0 MiB           5               obs = next_obs
    97    908.6 MiB      0.0 MiB           1       return examples
```

I couldn't figure out why JAX behaves this way, so I switched back to the official MinAtar implementation. The code now runs 10x faster and I no longer have memory issues. I will close this issue for now, but I still find it strange that JAX accumulates memory without releasing it.
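The thread never identifies the cause. One hypothetical mechanism worth ruling out is recompilation. If a step method is wrapped in `jax.jit` with its instance marked static (`static_argnums=(0,)`), JAX traces and compiles a separate version for every distinct instance, and each compiled version stays in the cache. Constructing a fresh environment object per program would then grow memory with every program, while reusing one object compiles only once. The sketch uses a hypothetical `ToyEnv` with a Python-side trace counter. Whether gymnax's `env.step` behaves this way in the version used here is an assumption, not something the issue confirms.

```python
import functools

import jax
import jax.numpy as jnp

TRACE_COUNT = 0  # incremented only while JAX traces ToyEnv.step


class ToyEnv:
    """Hypothetical stand-in for an environment with a jitted step method.

    The instance (self) is a static argument, so JAX hashes it by identity
    and compiles a separate version of step for each distinct instance.
    """

    @functools.partial(jax.jit, static_argnums=(0,))
    def step(self, state, action):
        global TRACE_COUNT
        TRACE_COUNT += 1  # Python side effect: runs at trace time, not per call
        return state + action


def rollout(env, n_steps):
    """Run n_steps of the toy step function on one environment object."""
    state = jnp.zeros((), dtype=jnp.float32)
    for _ in range(n_steps):
        state = env.step(state, 1)
    return state


# Reusing one environment object: traced and compiled once.
shared_env = ToyEnv()
for _ in range(3):
    result = rollout(shared_env, 5)
traces_shared = TRACE_COUNT

# A fresh object per "program": one new trace and compilation each time.
# The list keeps the objects alive so every instance is genuinely distinct.
fresh_envs = [ToyEnv() for _ in range(3)]
for env in fresh_envs:
    rollout(env, 5)
traces_fresh = TRACE_COUNT - traces_shared

print(float(result), traces_shared, traces_fresh)
```

If this kind of growth is suspected, creating the environment once outside the loop over programs and passing the same object to every rollout keeps the number of compiled versions fixed. Keeping argument shapes and dtypes consistent between calls matters for the same reason, since a change in either also triggers a new trace. Calling `jax.config.update("jax_log_compiles", True)` logs each compilation, which is one way to check whether this is happening in a real rollout.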