Toni-SM / skrl

Modular reinforcement learning library (on PyTorch and JAX) with support for NVIDIA Isaac Gym, Omniverse Isaac Gym and Isaac Lab

Home Page: https://skrl.readthedocs.io/

PPO discrete action for gym.Env

pavelxx1 opened this issue

Hi @Toni-SM, I'm new to RL,
but do you have any example code for PPO with a discrete action space?
Thanks

Hi @pavelxx1

Sure. The key idea is to replace the Gaussian-based policy (for continuous action spaces) with a categorical-based policy (for discrete action spaces).
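As a rough, library-independent illustration of what "categorical-based" means: the policy network outputs one unnormalized score (logit) per discrete action, the logits are turned into probabilities with a softmax, and an action index is sampled from them. The sketch below uses only the standard library; `softmax` and `sample_action` are hypothetical helper names, not skrl functions, and skrl's own `CategoricalMixin` does this internally with its own implementation.

```python
import math
import random


def softmax(logits):
    """Turn unnormalized log-probabilities (logits) into probabilities."""
    m = max(logits)  # subtract the max logit for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_action(logits, rng=random):
    """Draw one discrete action index and return it with its log-probability."""
    probs = softmax(logits)
    action = rng.choices(range(len(probs)), weights=probs, k=1)[0]
    return action, math.log(probs[action])


# Example with 2 actions, as in CartPole (push left / push right)
rng = random.Random(0)
action, log_prob = sample_action([0.2, 1.5], rng)
```

The log-probability of the sampled action is what PPO's clipped objective compares between the old and the updated policy; a Gaussian policy plays the same role for continuous actions, just with a different distribution.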

The attached .zip file contains two examples, one for the OpenAI Gym environment interface and one for the Farama Gymnasium environment interface.
ppo_cartpole_examples.zip

Note that the maximum possible total reward differs between CartPole environment versions:

  • CartPole-v0: 200
  • CartPole-v1: 500
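The difference comes from the time limit: CartPole pays a reward of +1 for every step the pole stays up, so the total (undiscounted) reward of an episode is capped by the version's step limit. A small hypothetical helper for reading evaluation returns against that cap, with the limits taken from the list above:

```python
# CartPole gives +1 per step, so the episode time limit bounds the total reward.
# Limits mirror the list above; MAX_RETURN and fraction_of_max are illustrative names.
MAX_RETURN = {"CartPole-v0": 200, "CartPole-v1": 500}


def fraction_of_max(env_id: str, episode_return: float) -> float:
    """Express an evaluation return as a fraction of the version's maximum."""
    return episode_return / MAX_RETURN[env_id]


print(fraction_of_max("CartPole-v1", 250))  # 0.5
```

So a return of 200 is a perfect episode on v0 but only 40% of the maximum on v1.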

Thanks a lot! I will use your code as the starting point of my research.
Could you also give an example of code for deterministic testing of the agent after successful training?

I wrote some evaluation code, but is this the right way?

import gym

import torch.nn as nn
import torch.nn.functional as F

# Import the skrl components to build the RL system
from skrl.models.torch import Model, CategoricalMixin
from skrl.agents.torch.ppo import PPO, PPO_DEFAULT_CONFIG
from skrl.trainers.torch import SequentialTrainer
from skrl.envs.torch import wrap_env



class Policy(CategoricalMixin, Model):
    def __init__(self, observation_space, action_space, device, unnormalized_log_prob=True):
        Model.__init__(self, observation_space, action_space, device)
        CategoricalMixin.__init__(self, unnormalized_log_prob)

        self.linear_layer_1 = nn.Linear(self.num_observations, 64)
        self.linear_layer_2 = nn.Linear(64, 64)
        self.output_layer = nn.Linear(64, self.num_actions)

    def compute(self, inputs, role):
        x = F.relu(self.linear_layer_1(inputs["states"]))
        x = F.relu(self.linear_layer_2(x))
        return self.output_layer(x), {}


# TestEnv is the user's own gym.Env subclass (definition not shown here)
env = wrap_env(TestEnv())
device = env.device


models_ppo = {}
models_ppo["policy"] = Policy(env.observation_space, env.action_space, device)

cfg_ppo = PPO_DEFAULT_CONFIG.copy()
cfg_ppo["random_timesteps"] = 0  
cfg_ppo["experiment"]["checkpoint_interval"] = 0

agent_ppo = PPO(models=models_ppo,
                memory=None,
                cfg=cfg_ppo,
                observation_space=env.observation_space,
                action_space=env.action_space,
                device=device)

agent_ppo.load("./rl-ckpt/23-06-15_14-52-30-759890_PPO/checkpoints/agent_100000.pt")

# Configure and instantiate the RL trainer
cfg_trainer = {"timesteps": 1000, "headless": True, "disable_progressbar": True}
trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=agent_ppo)

# start evaluation
trainer.eval()
Output:
[skrl:INFO] Environment class: gym.core.Env
[skrl:INFO] Environment wrapper: Gym
[skrl:WARNING] Cannot load the value module. The agent doesn't have such an instance
[skrl:WARNING] Cannot load the optimizer module. The agent doesn't have such an instance
[skrl:WARNING] Cannot load the state_preprocessor module. The agent doesn't have such an instance
[skrl:WARNING] Cannot load the value_preprocessor module. The agent doesn't have such an instance

Hi @pavelxx1

Yes, the evaluation code looks good.
The warnings refer to components (the value model, the optimizer and the preprocessors) that are only required during training, so you can safely ignore them.
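If the warnings clutter the evaluation output, one option is to raise the level of the logger they come from. This sketch assumes skrl logs through a standard-library `logging` logger named `"skrl"` (inferred from the `[skrl:WARNING]` prefix in the output above, not confirmed here); place it after the skrl imports so it overrides any level the library sets at import time.

```python
import logging

# Assumption: skrl's messages go through the stdlib logger named "skrl".
skrl_logger = logging.getLogger("skrl")

# Only show errors; warnings such as "Cannot load the value module ..." are filtered out.
skrl_logger.setLevel(logging.ERROR)
```

This only hides messages; it doesn't change what the agent loads from the checkpoint.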

One thing: I typically load the agent checkpoint after instantiating the trainer, to make sure the agent initialization (which occurs when the trainer is instantiated) has already been done.

agent_ppo = PPO(models=models_ppo,
                memory=None,
                cfg=cfg_ppo,
                observation_space=env.observation_space,
                action_space=env.action_space,
                device=device)


# Configure and instantiate the RL trainer
cfg_trainer = {"timesteps": 1000, "headless": True, "disable_progressbar": True}
trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=agent_ppo)

# load checkpoint
agent_ppo.load("./rl-ckpt/23-06-15_14-52-30-759890_PPO/checkpoints/agent_100000.pt")

# start evaluation
trainer.eval()
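On the "deterministic testing" question: with a categorical policy, a common way to make evaluation deterministic is to take the most probable action (the argmax of the logits) instead of sampling. How `trainer.eval()` handles this for categorical policies may depend on the skrl version, so check the documentation for your release. The sketch below shows the idea in plain Python; `ToyEnv` and `toy_policy_logits` are hypothetical stand-ins for `TestEnv` and a trained policy network, not skrl code.

```python
import random


class ToyEnv:
    """Hypothetical stand-in for TestEnv: reward +1 when the action matches a target."""

    def __init__(self, seed=0, horizon=10):
        self.rng = random.Random(seed)
        self.horizon = horizon

    def reset(self):
        self.t = 0
        self.target = self.rng.randrange(2)
        return self.target  # the observation reveals the target (toy setup)

    def step(self, action):
        reward = 1.0 if action == self.target else 0.0
        self.t += 1
        self.target = self.rng.randrange(2)
        done = self.t >= self.horizon
        return self.target, reward, done


def toy_policy_logits(obs):
    """Hypothetical trained policy: strongly prefers the action equal to obs."""
    return [2.0, -2.0] if obs == 0 else [-2.0, 2.0]


def greedy_action(logits):
    """Deterministic choice: index of the largest logit (the most probable action)."""
    return max(range(len(logits)), key=lambda i: logits[i])


def evaluate(env, policy, episodes=3):
    """Run full episodes with greedy actions and return each episode's total reward."""
    returns = []
    for _ in range(episodes):
        obs, done, total = env.reset(), False, 0.0
        while not done:
            obs, reward, done = env.step(greedy_action(policy(obs)))
            total += reward
        returns.append(total)
    return returns


print(evaluate(ToyEnv(), toy_policy_logits))  # [10.0, 10.0, 10.0]
```

Because the greedy action never changes for a given observation, repeated runs on the same seeded environment give identical returns, which makes results easier to compare across checkpoints.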