RobertTLange / gymnax

RL Environments in JAX 🌍

Replace all `params_env_name` with `FrozenDict`

RobertTLange opened this issue.

Some parameter dictionaries specify the shape of the observation. Hence, when jitting, we need to mark them as static arguments via `static_argnums`. That in turn is only possible if the dictionary is hashable, which requires it to be immutable. I propose porting the flax `FrozenDict` and providing a helper function `update_env_params(params, x_name, x_value)`, which unfreezes the dictionary, changes the entry, and freezes it again.
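The proposal can be sketched in plain Python. `jax.jit` only accepts arguments listed in `static_argnums` if they are hashable, and a plain `dict` is not (`hash({})` raises `TypeError`). The `FrozenDict` below is a minimal illustrative stand-in, not the flax implementation; `update_env_params` follows the signature proposed above, and the parameter names `max_steps` and `obs_shape` are hypothetical.

```python
from collections.abc import Mapping
from typing import Any, Iterator


class FrozenDict(Mapping):
    """Minimal immutable, hashable mapping (illustrative stand-in, not flax's code)."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Copy the input so later changes to the source dict cannot leak in.
        self._dict = dict(*args, **kwargs)
        self._hash = None

    def __getitem__(self, key: str) -> Any:
        return self._dict[key]

    def __iter__(self) -> Iterator[str]:
        return iter(self._dict)

    def __len__(self) -> int:
        return len(self._dict)

    def __hash__(self) -> int:
        # Order-independent hash over the items; all values must be hashable
        # (e.g. shapes as tuples, not lists). Equality comes from Mapping.__eq__.
        if self._hash is None:
            self._hash = hash(frozenset(self._dict.items()))
        return self._hash

    def __repr__(self) -> str:
        return f"FrozenDict({self._dict!r})"

    def unfreeze(self) -> dict:
        # Return a mutable copy; the FrozenDict itself is left untouched.
        return dict(self._dict)


def update_env_params(params: FrozenDict, x_name: str, x_value: Any) -> FrozenDict:
    """Hypothetical helper: unfreeze, change one entry, freeze again."""
    new_params = params.unfreeze()
    new_params[x_name] = x_value
    return FrozenDict(new_params)


params = FrozenDict(max_steps=100, obs_shape=(4,))
new_params = update_env_params(params, "max_steps", 200)
print(params["max_steps"], new_params["max_steps"])  # → 100 200
print(hash(params) == hash(FrozenDict(obs_shape=(4,), max_steps=100)))  # → True
```

Because equal parameter sets compare and hash equal, a function jitted with `static_argnums` pointing at `params` should only be retraced when the parameters actually change. Item assignment such as `params["max_steps"] = 1` raises `TypeError`, so changes have to go through the helper.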

In order to reduce dependencies, it may make sense to simply copy the file and use the same Apache License.

https://github.com/google/flax/blob/ac0f57419f32c9924e094e7e0dc82a15be228b5d/flax/core/frozen_dict.py

Go through all envs and update the parameter dictionaries.

Addressed in #17.