araffin / sbx

SBX: Stable Baselines Jax (SB3 + Jax)

Custom env with FrameStack wrapper causes invalid actions to be passed to `env.step`

capnspacehook opened this issue

🤖 Custom Gym Environment

Describe the bug

When using gymnasium.wrappers.frame_stack.FrameStack with a simple custom env, I get an exception inside `step`: the action passed to the env is a non-scalar array, so using it to index the env's action list raises a TypeError.

Code example

import itertools
from typing import Any, List, Tuple

import gymnasium as gym
import numpy as np
from gymnasium.spaces import Box, Discrete
from gymnasium.wrappers.frame_stack import FrameStack
from sbx import PPO
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.vec_env import DummyVecEnv


class MyEnv(gym.Env):
    def __init__(self) -> None:
        self.actions, self.action_space = self.actionSpace()
        self.observation_space = Box(0, 1, shape=(1,))

        super().__init__()

    def step(self, action: Any) -> Tuple[Any, float, bool, bool, dict]:
        # Fails here once FrameStack is applied: `action` arrives as a
        # non-scalar array, which cannot index a Python list
        chosenAction = self.actions[action]

        return self.obs(), 0.0, False, False, {}

    def reset(
        self, *, seed: int | None = None, options: dict | None = None
    ) -> Tuple[Any, dict]:
        super().reset(seed=seed, options=options)
        return self.obs(), {}

    def obs(self):
        return np.array([0.5], dtype=np.float32)

    def render(self) -> Any | List[Any] | None:
        pass

    def actionSpace(self):
        baseActions = [0, 1, 2, 3, 4]

        # All ordered pairs of distinct base actions
        totalActionsWithRepeats = list(itertools.permutations(baseActions, 2))
        withoutRepeats = []

        # Drop reversed duplicates so each pair is kept only once
        for combination in totalActionsWithRepeats:
            reversedCombination = combination[::-1]
            if reversedCombination not in withoutRepeats:
                withoutRepeats.append(combination)

        # Action list: the single actions plus the unordered pairs
        filteredActions = [[action] for action in baseActions] + withoutRepeats

        return filteredActions, Discrete(len(filteredActions))


if __name__ == "__main__":
    env = MyEnv()
    check_env(env)

    env = FrameStack(env, 4)  # frame stacking applied before vectorization
    env = DummyVecEnv([lambda: env])

    algo = PPO("MlpPolicy", env)
    algo.learn(total_timesteps=1000)

Traceback (most recent call last):
  File "/home/user/sbx_ppo_repro.py", line 61, in <module>
    algo.learn(total_timesteps=1000)
  File "/home/user/jax-venv/lib/python3.10/site-packages/sbx/ppo/ppo.py", line 315, in learn
    return super().learn(
  File "/home/user/jax-venv/lib/python3.10/site-packages/stable_baselines3/common/on_policy_algorithm.py", line 259, in learn
    continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)
  File "/home/user/jax-venv/lib/python3.10/site-packages/sbx/common/on_policy_algorithm.py", line 152, in collect_rollouts
    new_obs, rewards, dones, infos = env.step(clipped_actions)
  File "/home/user/jax-venv/lib/python3.10/site-packages/stable_baselines3/common/vec_env/base_vec_env.py", line 197, in step
    return self.step_wait()
  File "/home/user/jax-venv/lib/python3.10/site-packages/stable_baselines3/common/vec_env/dummy_vec_env.py", line 58, in step_wait
    obs, self.buf_rews[env_idx], terminated, truncated, self.buf_infos[env_idx] = self.envs[env_idx].step(
  File "/home/user/jax-venv/lib/python3.10/site-packages/gymnasium/wrappers/frame_stack.py", line 179, in step
    observation, reward, terminated, truncated, info = self.env.step(action)
  File "/home/user/sbx_ppo_repro.py", line 21, in step
    chosenAction = self.actions[action]
TypeError: only integer scalar arrays can be converted to a scalar index
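
For reference, that TypeError is what NumPy raises whenever a Python list is indexed with a non-scalar array, which suggests the action reaching env.step has an extra dimension. A minimal illustration, independent of SBX:

import numpy as np

actions = [0, 1, 2]
actions[np.array(1)]    # OK: a 0-d integer array is a valid scalar index
actions[np.array([1])]  # TypeError: only integer scalar arrays can be converted to a scalar index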

### System Info

  • OS: Linux-6.5.6-76060506-generic-x86_64-with-glibc2.35 # 202310061235169739694522.04~9283e32 SMP PREEMPT_DYNAMIC Sun O
  • Python: 3.10.12
  • Stable-Baselines3: 2.1.0
  • PyTorch: 2.1.0+cu121
  • GPU Enabled: True
  • GPU Model: Nvidia RTX 3080 Ti
  • Numpy: 1.26.1
  • Cloudpickle: 3.0.0
  • Gymnasium: 0.29.1

sbx at the latest commit was installed using pip: pip install git+https://github.com/araffin/sbx

### Checklist

  • I have read the documentation (required)
  • I have checked that there is no similar issue in the repo (required)
  • I have checked my env using the env checker (required)
  • I have provided a minimal working example to reproduce the bug (required)

Hello,
thanks for the bug report.
I guess the issue comes from a flatten layer which is not applied in SBX, so multi-dimensional observations (like the stacked ones FrameStack produces) reach the MLP policy unflattened.

A quick fix is to use VecFrameStack instead (it stacks on the last axis instead of the first):

from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack

vec_env = DummyVecEnv([lambda: env])
vec_env = VecFrameStack(vec_env, 4)
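
Applied to the original repro, the full workaround would look like this (a sketch; MyEnv is the class from the code example above):

from sbx import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack

vec_env = DummyVecEnv([lambda: MyEnv()])     # vectorize first...
vec_env = VecFrameStack(vec_env, n_stack=4)  # ...then stack on the last axis

algo = PPO("MlpPolicy", vec_env)
algo.learn(total_timesteps=1000)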

To reproduce with even more minimal code:

from typing import Any, List, Tuple

import gymnasium as gym
from gymnasium.spaces import Box, Discrete
from sbx import PPO


class MyEnv(gym.Env):
    def __init__(self) -> None:
        # Any multi-dimensional observation shape triggers the bug
        self.observation_space = Box(0, 1, shape=(2, 1), dtype="float32")
        self.action_space = Discrete(15)

    def step(self, action: Any) -> Tuple[Any, float, bool, bool, dict]:
        return self.observation_space.sample(), 0.0, False, False, {}

    def reset(
        self, *, seed: int | None = None, options: dict | None = None
    ) -> Tuple[Any, dict]:
        super().reset(seed=seed, options=options)
        return self.observation_space.sample(), {}

    def render(self) -> Any | List[Any] | None:
        pass

PPO("MlpPolicy", MyEnv()).learn(total_timesteps=1000)

I've pushed a fix in #18; you should be able to upgrade to sbx 0.9.0 soon =)