ray-project / ray

Ray is a unified framework for scaling AI and Python applications. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.

Home Page:https://ray.io

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

[Rllib] Rllib provides wrong state batch size during "bug check" batches on torch custom model

Phirefly9 opened this issue · comments

What happened + What you expected to happen

when starting PPO rllib does several "bug check" batches (or at least I think that is their use) to verify the model is working and functional before it starts training

when using a model that requires a state rllib is providing the incorrect batch size for the hidden state. causing the training process to shut down and crash.

It is possible to hack around it by checking the batch size of the two and expanding the state size, but I have no idea if this makes training unstable and is very annoying

The script below produces the error (PPO pid=2931067) RuntimeError: Input batch size 32 doesn't match hidden0 batch size 4 [repeated 5x across cluster]

Versions / Dependencies

ray 2.9.3 and ray 2.20.0 have this issue

Reproduction script

import gymnasium
from torch.nn import LSTMCell
from ray.rllib.models import ModelCatalog
from ray.rllib.models.torch.misc import SlimFC, normc_initializer
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import Dict, List, ModelConfigDict, TensorType
import os
from ray import tune

torch, nn = try_import_torch()

from ray.tune.registry import get_trainable_cls


class LSTMCELL(TorchModelV2, nn.Module):

    def __init__(
        self,
        obs_space: gymnasium.spaces.Space,
        action_space: gymnasium.spaces.Space,
        num_outputs: int,
        model_config: ModelConfigDict,
        name: str,
        *args,
        input_dim: int = 64,
        input_hdim: int = 64,
        **kwargs,
    ):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)

        self.input_dim = input_dim
        self.hidden_dim = input_hdim

        prev_layer_size = int(obs_space.shape[-1])

        self.encoder = SlimFC(
            in_size=prev_layer_size,
            out_size=input_dim,
            initializer=normc_initializer(0.01),
            activation_fn="tanh",
        )
        self.lstm = LSTMCell(input_size=input_dim, hidden_size=self.hidden_dim)
        # Postprocess LSTM output with another hidden layer and compute values.
        self._logits = SlimFC(self.hidden_dim, self.num_outputs)
        self._value_branch = SlimFC(self.hidden_dim, 1)
        self._features = None
        self._last_flat_in = None

    @override(TorchModelV2)
    def forward(
        self,
        input_dict: Dict[str, TensorType],
        state: List[TensorType],
        seq_lens: TensorType,
    ):
        self._last_flat_in = input_dict["obs_flat"].float()

        obs_embed = self.encoder(self._last_flat_in)
        out = self.lstm(obs_embed, state)
        self._features = out[0]
        logits = self._logits(self._features)

        return logits, list(out)

    @override(TorchModelV2)
    def get_initial_state(self):
        h = [torch.zeros(self.hidden_dim), torch.zeros(self.hidden_dim)]
        return h

    @override(TorchModelV2)
    def value_function(self) -> TensorType:
        assert self._features is not None, "must call forward() first"
        return self._value_branch(self._features).squeeze(1)


ModelCatalog.register_custom_model("LSTMCELL", LSTMCELL)


if __name__ == "__main__":
    algo_cls = get_trainable_cls("PPO")
    config = algo_cls.get_default_config()

    config.environment(env="CartPole-v1").resources(
        num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0"))
    ).framework("torch").reporting(min_time_s_per_iteration=0.1).training(
        model={
            "vf_share_layers": True,
            "custom_model": "LSTMCELL",
            # Extra kwargs to be passed to your model's c'tor.
            "custom_model_config": {
                "input_dim": 32,
                "input_hdim": 32,
            },
        },
    )

    tuner = tune.Tuner(
        "PPO",
        param_space=config,
        # run_config=air.RunConfig(
        #     stop=stop,
        # ),
    )
    results = tuner.fit()

Issue Severity

Medium: It is a significant difficulty but I can work around it.