RobertTLange / gymnax

RL Environments in JAX 🌍

Potential bug in Tuple space

joefarrington opened this issue

Hi Rob,

Great library, thanks for all the hard work. I have been using some custom Gymnax environments in recent work (e.g. https://arxiv.org/abs/2303.10672).

There seem to be issues with the sample() and contains() methods for the Tuple space.

For example:

import jax
import gymnax.environments.spaces as spaces
s = spaces.Tuple([spaces.Discrete(5), spaces.Discrete(5)])
s.sample(rng=jax.random.PRNGKey(0))

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[95], line 3
      1 import gymnax.environments.spaces as spaces
      2 s = spaces.Tuple([spaces.Discrete(5), spaces.Discrete(5)])
----> 3 s.sample(rng=jax.random.PRNGKey(0))

File ~/miniconda3/envs/gymnax_and_libs/lib/python3.9/site-packages/gymnax/environments/spaces.py:119, in Tuple.sample(self, rng)
    116 """Sample random action from all subspaces."""
    117 key_split = jax.random.split(rng, self.num_spaces)
    118 return tuple(
--> 119     [self.spaces[k].sample(key_split[i]) for i, k in enumerate(self.spaces)]
    120 )

File ~/miniconda3/envs/gymnax_and_libs/lib/python3.9/site-packages/gymnax/environments/spaces.py:119, in <listcomp>(.0)
    116 """Sample random action from all subspaces."""
    117 key_split = jax.random.split(rng, self.num_spaces)
    118 return tuple(
--> 119     [self.spaces[k].sample(key_split[i]) for i, k in enumerate(self.spaces)]
    120 )

TypeError: list indices must be integers or slices, not Discrete

s.contains((1, 1))
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[96], line 1
----> 1 s.contains((1, 1))

File ~/miniconda3/envs/gymnax_and_libs/lib/python3.9/site-packages/gymnax/environments/spaces.py:129, in Tuple.contains(self, x)
    127 out_of_space = 0
    128 for space in self.spaces:
--> 129     out_of_space += 1 - space.contains(x)
    130 return out_of_space == 0

File ~/miniconda3/envs/gymnax_and_libs/lib/python3.9/site-packages/gymnax/environments/spaces.py:44, in Discrete.contains(self, x)
     41 """Check whether specific object is within space."""
     42 # type_cond = isinstance(x, self.dtype)
     43 # shape_cond = (x.shape == self.shape)
---> 44 range_cond = jnp.logical_and(x >= 0, x < self.n)
     45 return range_cond

TypeError: '>=' not supported between instances of 'tuple' and 'int'
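Both tracebacks appear to come from how the loops treat `self.spaces`. `enumerate` yields `(index, element)` pairs, so in `sample()` the name `k` is already a `Discrete` object rather than an integer, and `self.spaces[k]` fails. In `contains()`, every subspace receives the whole tuple `x` instead of its own element `x[i]`, so the range check compares a tuple with an int. The sketch below reproduces both failures with plain Python; `FakeDiscrete` and `tuple_contains` are hypothetical stand-ins, not part of gymnax.

```python
class FakeDiscrete:
    """Hypothetical stand-in for gymnax's Discrete space (values 0..n-1)."""

    def __init__(self, n: int):
        self.n = n

    def contains(self, x) -> bool:
        # Same range check as Discrete.contains, using plain Python comparisons
        return x >= 0 and x < self.n


subspaces = [FakeDiscrete(5), FakeDiscrete(5)]

# Bug 1 (sample): enumerate yields (index, element), so `k` is a space object
# and subspaces[k] is not a valid list index.
try:
    [subspaces[k] for i, k in enumerate(subspaces)]
    sample_error = None
except TypeError as err:
    sample_error = str(err)

# Bug 2 (contains): every subspace receives the whole tuple, not its own element.
try:
    [space.contains((1, 1)) for space in subspaces]
    contains_error = None
except TypeError as err:
    contains_error = str(err)


def tuple_contains(spaces, x) -> bool:
    # Fix: check element i against subspace i
    return all(space.contains(x_i) for space, x_i in zip(spaces, x))


print(sample_error)    # list indices must be integers or slices, not FakeDiscrete
print(contains_error)  # '>=' not supported between instances of 'tuple' and 'int'
print(tuple_contains(subspaces, (1, 1)))  # True
print(tuple_contains(subspaces, (1, 7)))  # False
```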

I think the following fixes it; happy to raise a pull request if you want.

from typing import Sequence

import chex
import jax
from gymnax.environments.spaces import Space


class Tuple(Space):
    """Minimal jittable class for tuple (product) of jittable spaces."""

    def __init__(self, spaces: Sequence[Space]):
        self.spaces = spaces
        self.num_spaces = len(spaces)

    def sample(self, rng: chex.PRNGKey) -> tuple:
        """Sample random action from all subspaces."""
        # One independent key per subspace
        key_split = jax.random.split(rng, self.num_spaces)
        # Iterate over the subspaces themselves instead of using them as indices
        return tuple(
            [s.sample(key_split[i]) for i, s in enumerate(self.spaces)]
        )

    def contains(self, x: tuple) -> bool:
        """Check whether each element of x lies within its matching subspace."""
        # type_cond = isinstance(x, tuple)
        # num_space_cond = len(x) == len(self.spaces)
        # Check element i against subspace i
        out_of_space = 0
        for i, space in enumerate(self.spaces):
            out_of_space += 1 - space.contains(x[i])
        return out_of_space == 0
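With the fix in place, `sample()` and `contains()` should be traceable by `jax.jit`, since the Python loop over subspaces is unrolled at trace time. The minimal standalone sketch below assumes only that `jax` is installed. `MiniDiscrete` and `MiniTuple` are hypothetical names that mirror the structure of the patched class, not gymnax's actual code, and `contains()` here combines per-subspace results with `jnp.logical_and` instead of counting failures, which gives the same answer.

```python
from typing import Sequence

import jax
import jax.numpy as jnp


class MiniDiscrete:
    """Hypothetical discrete space over {0, ..., n-1}."""

    def __init__(self, n: int):
        self.n = n

    def sample(self, rng: jax.Array) -> jax.Array:
        # Uniform integer in [0, n)
        return jax.random.randint(rng, shape=(), minval=0, maxval=self.n)

    def contains(self, x) -> jax.Array:
        return jnp.logical_and(x >= 0, x < self.n)


class MiniTuple:
    """Hypothetical product space mirroring the patched Tuple."""

    def __init__(self, spaces: Sequence[MiniDiscrete]):
        self.spaces = spaces
        self.num_spaces = len(spaces)

    def sample(self, rng: jax.Array) -> tuple:
        # One key per subspace; iterate over the subspaces themselves
        keys = jax.random.split(rng, self.num_spaces)
        return tuple(s.sample(keys[i]) for i, s in enumerate(self.spaces))

    def contains(self, x: tuple) -> jax.Array:
        # Check element i against subspace i and AND the results together
        ok = jnp.array(True)
        for i, space in enumerate(self.spaces):
            ok = jnp.logical_and(ok, space.contains(x[i]))
        return ok


space = MiniTuple([MiniDiscrete(5), MiniDiscrete(3)])
sample = jax.jit(space.sample)(jax.random.PRNGKey(0))
inside = jax.jit(space.contains)(sample)
outside = space.contains((1, 7))
```

Because the number of subspaces is fixed when the space is constructed, the loop has a static length, which is what lets `jax.jit` trace it without any control-flow primitives.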

Hi @joefarrington, that is so awesome! Thank you for sharing your publication. Parallelization of classic dynamic programming techniques seems like a great application for JAX/gymnax.

And thank you for sharing the bug and the fix. I quickly added it in commit 4c0628e, since I want to make a release fairly soon. If you find any other bugs or have recommendations and feel like opening PRs, I would be happy to review and integrate them.

Have a good Sunday, Rob