takuseno / d3rlpy

An offline deep reinforcement learning library

Home Page: https://takuseno.github.io/d3rlpy

[BUG] error when running dt on atari

AndssY opened this issue

Using d3rlpy v2.2.0 (https://github.com/takuseno/d3rlpy/releases/tag/v2.2.0), running the following script:

import d3rlpy

dataset, env = d3rlpy.datasets.get_atari_transitions(
    'breakout',
    fraction=0.01,
    num_stack=4,
)

dt = d3rlpy.algos.DiscreteDecisionTransformerConfig(
    batch_size=64,
    num_heads=1,
    learning_rate=1e-4,
    max_timestep=1000,
    num_layers=3,
    position_encoding_type=d3rlpy.PositionEncodingType.SIMPLE,
    encoder_factory=d3rlpy.models.VectorEncoderFactory([128], exclude_last_activation=True),
    observation_scaler=d3rlpy.preprocessing.StandardObservationScaler(),
    context_size=20,
    warmup_tokens=100000,
).create()

dt.fit(
    dataset,
    n_steps=100000,
    n_steps_per_epoch=1000,
    eval_env=env,
    eval_target_return=500,
)


raises this error:

Traceback (most recent call last):
  ......
  ......
  File "...d3rlpy/d3rlpy/preprocessing/observation_scalers.py", line 322, in fit_with_trajectory_slicer
    total_sum += np.sum(traj.observations, axis=0)
ValueError: non-broadcastable output operand with shape (1,84,84) doesn't match the broadcast shape (4,84,84)
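
For anyone reading the traceback: the message itself is a plain NumPy in-place broadcasting error, and it can be reproduced without d3rlpy. The sketch below assumes (from the shapes in the message and `num_stack=4`, not from the library source) that the accumulator is sized for a single stored frame while the summed observations are 4-frame stacks. The names `total_sum` and `stacked_sum` are only illustrative.

```python
import numpy as np

# Accumulator shaped like a single stored frame: (1, 84, 84).
total_sum = np.zeros((1, 84, 84))

# Sum over a hypothetical trajectory of 10 stacked observations,
# each shaped (4, 84, 84); the result has shape (4, 84, 84).
stacked_sum = np.ones((10, 4, 84, 84)).sum(axis=0)

try:
    # In-place add cannot grow the left-hand array from 1 to 4 channels.
    total_sum += stacked_sum
    error_message = None
except ValueError as exc:
    error_message = str(exc)

print(error_message)
```

An out-of-place `total_sum + stacked_sum` would broadcast silently, which is why only the `+=` form fails here.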

@AndssY Hi, thank you for the issue. Please check this example:

You need to change encoder_factory and observation_scaler to resolve the issue.
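
A minimal sketch of what that change might look like, assuming the intent is to use d3rlpy's pixel-oriented components for image observations: `d3rlpy.models.PixelEncoderFactory` in place of `VectorEncoderFactory`, and `d3rlpy.preprocessing.PixelObservationScaler` in place of `StandardObservationScaler`. This is not the linked example itself; the `feature_size=128` value is an assumption chosen to mirror the `[128]` hidden unit in the original report, and the remaining hyperparameters are carried over unchanged from the report.

```python
import d3rlpy

# Same settings as the report, except for the two arguments named above.
config = d3rlpy.algos.DiscreteDecisionTransformerConfig(
    batch_size=64,
    num_heads=1,
    learning_rate=1e-4,
    max_timestep=1000,
    num_layers=3,
    position_encoding_type=d3rlpy.PositionEncodingType.SIMPLE,
    # Assumption: a convolutional encoder for (C, 84, 84) image observations.
    encoder_factory=d3rlpy.models.PixelEncoderFactory(
        feature_size=128,
        exclude_last_activation=True,
    ),
    # Assumption: scale raw pixel values instead of fitting mean/std statistics.
    observation_scaler=d3rlpy.preprocessing.PixelObservationScaler(),
    context_size=20,
    warmup_tokens=100000,
)

dt = config.create()
```

The `dt.fit(...)` call from the report should then be usable as written, though I have not verified the resulting training run.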

Thanks!