RobertTLange / gymnax

RL Environments in JAX 🌍



On the use of `jnp.int_`

kngwyu opened this issue

Hi, thank you for your work on this great library. When I first tried gymnax, I frequently ran into the following warning, caused by the use of `jnp.int_`:

UserWarning: Explicitly requested dtype <class 'jax.numpy.int64'> requested in array is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
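A minimal sketch reproducing the truncation, assuming a default JAX install where `jax_enable_x64` is off. `jnp.zeros` is just one convenient call that accepts an explicit `dtype`; any array constructor given `jnp.int_` behaves the same way in this mode:

```python
import warnings

import jax
import jax.numpy as jnp

# Assumption: 64-bit mode is disabled (JAX's default); set explicitly here.
jax.config.update("jax_enable_x64", False)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # jnp.int_ requests a 64-bit integer, which JAX cannot provide in this mode.
    x = jnp.zeros(3, dtype=jnp.int_)

print(x.dtype)  # int32: the requested int64 was truncated
for w in caught:
    # Show whatever warnings were raised, e.g. the UserWarning quoted above.
    print(type(w.message).__name__, str(w.message)[:60])
```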

It looks like `jnp.int_` is used in:

  • Catch-bsuite
  • Freeway-minatar
  • `spaces.Discrete`
  • `spaces.contains`

So my question is: is there a specific reason for using `jnp.int_`? If not, I would like to contribute a change replacing these with `jnp.int32`, since I don't need `JAX_ENABLE_X64` for anything else.
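Replacing `jnp.int_` with an explicit `jnp.int32`, as proposed above, could look like the following for a discrete space. `DiscreteSketch` is a hypothetical stand-in written for illustration, not gymnax's actual `spaces.Discrete`:

```python
import jax
import jax.numpy as jnp


class DiscreteSketch:
    """Hypothetical discrete space (illustration only, not gymnax's class).

    Uses an explicit jnp.int32 instead of jnp.int_, so no dtype request
    exceeds what JAX provides when 64-bit mode is disabled.
    """

    def __init__(self, num_categories: int):
        self.n = num_categories
        self.shape = ()
        self.dtype = jnp.int32

    def sample(self, rng: jax.Array) -> jax.Array:
        # Draw one integer uniformly from [0, n).
        return jax.random.randint(
            rng, self.shape, minval=0, maxval=self.n, dtype=self.dtype
        )

    def contains(self, x) -> jax.Array:
        # Range check only; x is assumed to already be an integer.
        x = jnp.asarray(x)
        return jnp.logical_and(x >= 0, x < self.n)


space = DiscreteSketch(3)
action = space.sample(jax.random.PRNGKey(0))
print(action.dtype, space.contains(action))
```

Hard-coding `jnp.int32` avoids the warning, but it also means the space never produces 64-bit integers, even for users who do enable `jax_enable_x64`.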

This issue should be fixed by making use of `jax.dtypes.canonicalize_dtype`. I'll submit a pull request when I have time.
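Following the `jax.dtypes.canonicalize_dtype` suggestion, the default integer type can be resolved against the active x64 setting instead of being hard-coded. A sketch, where `default_int` is a hypothetical helper name, assuming the default configuration with 64-bit mode off:

```python
import warnings

import jax
import jax.numpy as jnp
from jax import dtypes

# Assumption: default configuration with 64-bit mode off.
jax.config.update("jax_enable_x64", False)


def default_int():
    """Hypothetical helper: jnp.int_ mapped to what JAX can actually provide.

    Returns int32 while jax_enable_x64 is off and int64 when it is on.
    """
    return dtypes.canonicalize_dtype(jnp.int_)


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    y = jnp.zeros(3, dtype=default_int())

# Keep only dtype-truncation warnings; unrelated warnings are ignored.
truncation = [w for w in caught if "truncated" in str(w.message)]
print(y.dtype, len(truncation))  # int32 0
```

Calling the helper at use time, rather than caching its result in a module-level constant, keeps the choice in step with the configuration in effect when the code runs; with 64-bit mode enabled, the same code yields `int64`.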