araffin / sbx

SBX: Stable Baselines Jax (SB3 + Jax)

crash when using a custom network architecture

Robokan opened this issue

Important Note: We do not do technical support, nor consulting and don't answer personal questions per email.
Please post your question on the RL Discord, Reddit or Stack Overflow in that case.

🤖 Custom Gym Environment

Please check your environment first using:

from stable_baselines3.common.env_checker import check_env

env = CustomEnv(arg1, ...)
# It will check your custom environment and output additional warnings if needed
check_env(env)

It passes check_env.

### Describe the bug

When using a custom network architecture, dict(net_arch=[1000, 500]) fails in the SBX code.

A clear and concise description of what the bug is.

### Code example

from sbx import PPO

net = dict(net_arch=[1000, 500])
# env is the custom environment that passed check_env above
PPOmodel = PPO('MlpPolicy', env, policy_kwargs=net)

Please try to provide a minimal example to reproduce the bug.
For a custom environment, you need to give at least the observation space, action space, reset() and step() methods
(see working example below).
Error messages and stack traces are also helpful.

Traceback (most recent call last):
  File "/Users/eric/Documents/development/deepLearning/deepMind/sparky/train.py", line 12, in <module>
    t.train()
  File "/Users/eric/Documents/development/deepLearning/deepMind/sparky/trainEnviornment.py", line 280, in train
    PPOmodel = PPO('MlpPolicy', env,
  File "/Users/eric/miniconda3/envs/py39/lib/python3.9/site-packages/sbx/ppo/ppo.py", line 165, in __init__
    self._setup_model()
  File "/Users/eric/miniconda3/envs/py39/lib/python3.9/site-packages/sbx/ppo/ppo.py", line 171, in _setup_model
    self.policy = self.policy_class(  # pytype:disable=not-instantiable
  File "/Users/eric/miniconda3/envs/py39/lib/python3.9/site-packages/sbx/ppo/policies.py", line 102, in __init__
    self.n_units = net_arch[0]["pi"][0]
TypeError: 'int' object is not subscriptable
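Judging from the last frame, that line in policies.py appears to expect the older SB3-style net_arch format, a list whose first element is a dict with "pi" and "vf" keys (and, given the [0] indexing, it seems to read only a single width), so passing a flat list of ints like [1000, 500] is what triggers the TypeError. Below is only a hedged sketch of a call that indexing would accept, with a standard gym environment standing in for the custom one; it is inferred from the traceback, not from SBX documentation.

import gym
from sbx import PPO

# Hypothetical sketch, inferred only from the traceback above:
# policies.py line 102 does net_arch[0]["pi"][0], so it expects a list
# whose first element is a dict with "pi"/"vf" keys, and it appears to
# use only that first width.
env = gym.make("Pendulum-v1")  # stand-in for the custom environment
policy_kwargs = dict(net_arch=[dict(pi=[1000], vf=[1000])])
model = PPO("MlpPolicy", env, policy_kwargs=policy_kwargs)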

Please use the markdown code blocks
for both code and stack traces.

import gym
import numpy as np

from stable_baselines3 import A2C
from stable_baselines3.common.env_checker import check_env


class CustomEnv(gym.Env):

  def __init__(self):
    super(CustomEnv, self).__init__()
    self.observation_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(14,))
    self.action_space = gym.spaces.Box(low=-1, high=1, shape=(6,))

  def reset(self):
    return self.observation_space.sample()

  def step(self, action):
    obs = self.observation_space.sample()
    reward = 1.0
    done = False
    info = {}
    return obs, reward, done, info

env = CustomEnv()
check_env(env)

model = A2C("MlpPolicy", env, verbose=1).learn(1000)
Traceback (most recent call last): File ...

### System Info
Describe the characteristic of your environment:

  • Describe how the library was installed (pip, docker, source, ...)
  • GPU models and configuration
  • Python version
  • PyTorch version
  • Gym version
  • Versions of any other relevant libraries

You can use sb3.get_system_info() to print relevant package info:

import stable_baselines3 as sb3
sb3.get_system_info()

### Additional context

Add any other context about the problem here.

### Checklist

  • I have read the documentation (required)
  • I have checked that there is no similar issue in the repo (required)
  • I have checked my env using the env checker (required)
  • I have provided a minimal working example to reproduce the bug (required)

Hello,
SBX has only basic support for custom policies, and it only supports independent networks for the actor and critic.
For now, the best option is to take a look at the code and define a custom policy yourself.

I am not too familiar with the internals of PyTorch. Does SBX allow for different sized hidden layers? All I am trying to do is have a first hidden layer of 1000 neurons and a second of 500 on both the independent actor and critic networks.

I'm talking about SBX internals (so written in Jax).
Take a look at https://github.com/araffin/sbx/blob/master/sbx/ppo/policies.py#L21-L58; it should be pretty straightforward to adapt to what you want.
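As a starting point, here is a minimal Flax sketch of an MLP whose hidden layers have different widths (1000, then 500), in the spirit of the Actor/Critic modules in the linked policies.py. This is not the SBX API; the class name CustomMLP and the output size are placeholders.

import flax.linen as nn
import jax
import jax.numpy as jnp

class CustomMLP(nn.Module):
    # Hidden layer widths; (1000, 500) matches the architecture asked about above.
    hidden_sizes: tuple = (1000, 500)
    out_dim: int = 6  # placeholder, e.g. the 6-D action space above

    @nn.compact
    def __call__(self, x):
        for size in self.hidden_sizes:
            x = nn.tanh(nn.Dense(size)(x))
        return nn.Dense(self.out_dim)(x)

# Quick shape check with a dummy 14-D observation
params = CustomMLP().init(jax.random.PRNGKey(0), jnp.zeros((1, 14)))

The actor in policies.py additionally outputs a distribution (mean and log std) and the critic a scalar value, so you would adapt the final layer(s) accordingly.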

Great, yes that looks pretty easy to do.