env step accumulates memory
ManuelEberhardinger opened this issue
Hi Robert,
Thanks for this awesome library!
I use the gymnax library on the CPU to collect data for the Breakout MinAtar environment. I generate thousands of random programs and want to execute them in the env. Memory accumulates over time until I run into RAM problems. Using the Python memory profiler, I found that the env's step function adds about 10 MB on every call. Do you know why that is the case? Could this happen only when running JAX on the CPU?
I had problems getting JAX and PyTorch running in the same virtual env with CUDA, so I decided to run gymnax on the CPU to avoid the CUDA problems. The memory is also not released in the next iteration of the loop or at the end of the function.
I used the code from the visualization notebook as a reference.
Thanks a lot for your answer!
Best wishes,
Manuel
```
Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    75    899.3 MiB    899.3 MiB            1   @profile
    76                                          def rollout_episode(env_data, env_params, seq_length, program):
    77    899.3 MiB      0.0 MiB            1       env, obs, env_state = env_data
    78    899.3 MiB      0.0 MiB            1       rng = jax.random.PRNGKey(0)
    79    899.3 MiB      0.0 MiB            1       examples = []
    80    908.6 MiB      0.0 MiB            6       for _ in range(seq_length):
    81    908.6 MiB      0.0 MiB            6           rng, rng_act, rng_step = jax.random.split(rng, 3)
    82    908.6 MiB      0.0 MiB            6           obs_inp = convert_to_task_input(obs)
    83    908.6 MiB      0.0 MiB            6           input_ex = (obs_inp,)
    84                                                  # output ex is the action
    85    908.6 MiB      0.0 MiB           18           output_ex = runWithTimeout(lambda: program.runWithArguments(input_ex), None)
    86    908.6 MiB      0.0 MiB            6           examples.append((input_ex, output_ex))
    87
    88    908.6 MiB      9.3 MiB           12           next_obs, next_env_state, reward, done, info = env.step(
    89    908.6 MiB      0.0 MiB            6               rng_step, env_state, output_ex, env_params
    90                                                  )
    91
    92    908.6 MiB      0.0 MiB            6           if done:  # or t_counter == max_frames:
    93    908.6 MiB      0.0 MiB            1               break
    94                                                  else:
    95    908.6 MiB      0.0 MiB            5               env_state = next_env_state
    96    908.6 MiB      0.0 MiB            5               obs = next_obs
    97    908.6 MiB      0.0 MiB            1       return examples
```
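The profiler report attributes the growth to the `env.step` call, but it does not say whether the memory is held by Python objects or by native buffers. One way to narrow this down is to cross-check with the standard-library `tracemalloc` module. The following is only a minimal sketch: `measure_growth`, `leaky_step`, and `clean_step` are hypothetical names introduced for illustration and stand in for the real rollout, and no gymnax or JAX call is involved. Note that `tracemalloc` only sees allocations made through Python's allocators, so memory held by a native runtime may not appear here; if the process RSS grows while these numbers stay flat, the growth is probably outside the Python heap.

```python
import tracemalloc


def measure_growth(step_fn, n_calls=5):
    """Return the change in traced Python memory (bytes) across each call."""
    tracemalloc.start()
    growth = []
    try:
        for _ in range(n_calls):
            before, _peak = tracemalloc.get_traced_memory()
            step_fn()
            after, _peak = tracemalloc.get_traced_memory()
            growth.append(after - before)
    finally:
        tracemalloc.stop()
    return growth


# Hypothetical stand-ins for a step function (not the gymnax API).
retained = []


def leaky_step():
    # Keeps a reference to a ~1 MB buffer, so memory grows on every call.
    retained.append(bytearray(1_000_000))


def clean_step():
    # Allocates the same buffer but drops it immediately.
    bytearray(1_000_000)


leaky = measure_growth(leaky_step)
clean = measure_growth(clean_step)
print("leaky step growth per call (bytes):", leaky)
print("clean step growth per call (bytes):", clean)
```

In the real rollout, `step_fn` could be a closure around one loop iteration. Steady per-call growth reported here would point at retained Python objects (for example the `examples` list), while flat numbers would point elsewhere.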
I didn't find out why JAX behaves this way, so I switched back to the official MinAtar implementation. The code now runs 10x faster and I no longer have any memory issues. I will close this issue for now, but I still find it strange that JAX accumulates memory without releasing it.