huggingface / jat

General multi-task deep RL Agent

RuntimeError: Could not infer dtype of PngImageFile error on training JAT on atari-pong

sanjayharesh opened this issue · comments

Hi,

I was trying to train JAT on atari-pong to test out training, and I'm getting the following error:

File ".../jat/env/lib/python3.10/site-packages/transformers/data/data_collator.py", line 158, in torch_default_data_collator
    batch[k] = torch.tensor([f[k] for f in features])
RuntimeError: Could not infer dtype of PngImageFile
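For context, this error occurs because torch.tensor cannot infer a dtype from raw PIL image objects, which is what the default transformers collator receives when the dataset still contains un-tokenized images. A minimal workaround sketch (the collate_images helper below is hypothetical, not part of JAT) that converts the images to numpy arrays before tensorizing:

```python
import numpy as np
import torch
from PIL import Image

def collate_images(features, key="image"):
    # torch.tensor cannot infer a dtype from PIL image objects,
    # so convert each image to a numpy array (H, W, C) first.
    arrays = [np.asarray(f[key]) for f in features]
    return torch.tensor(np.stack(arrays))

# Two dummy 4x4 RGB images standing in for dataset features.
imgs = [{"image": Image.new("RGB", (4, 4))} for _ in range(2)]
batch = collate_images(imgs)
print(batch.shape)  # torch.Size([2, 4, 4, 3])
```

Note this keeps the channels-last (H, W, C) layout; a real training pipeline would typically permute to (C, H, W) and normalize.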

Here's the command I'm using to launch training:

python -m scripts.train_jat_tokenized --output_dir checkpoints/jat_small_v100 --model_name_or_path jat-project/jat --tasks atari-pong --trust_remote_code --per_device_train_batch_size 20 --gradient_accumulation_steps 2 --save_steps 10000 --run_name train_jat_small --logging_steps 100 --logging_first_step --dispatch_batches False --dataloader_num_workers 16 --max_steps 250000

I used the commands in the README to set up the environment.
Note that I only changed the tasks flag.

Hey, thanks for reporting, but after double-checking, I can't reproduce this error. Can you share your system info? Maybe try updating datasets?

Does the error occur at the very beginning of training, or does it run for a few steps first?

Hi @qgallouedec ,

Thanks for your reply.
I'm using the environment from the repo. Here's the output of pip freeze, in case it helps.

absl-py==2.1.0
accelerate==0.30.0
aiohttp==3.9.5
aiosignal==1.3.1
ale-py==0.8.1
arch==5.3.0
asttokens==2.4.1
async-timeout==4.0.3
attrs==23.2.0
AutoROM==0.4.2
AutoROM.accept-rom-license==0.6.1
black==22.12.0
certifi==2024.2.2
cffi==1.16.0
charset-normalizer==3.3.2
click==8.1.7
cloudpickle==3.0.0
contourpy==1.2.1
cycler==0.12.1
Cython==0.29.37
datasets==2.19.1
decorator==5.1.1
decord==0.6.0
dill==0.3.8
docker-pycreds==0.4.0
etils==1.7.0
exceptiongroup==1.2.1
execnet==2.1.1
executing==2.0.1
Farama-Notifications==0.0.4
fasteners==0.15
filelock==3.14.0
fonttools==4.51.0
free-mujoco-py==2.1.6
frozenlist==1.4.1
fsspec==2024.3.1
gitdb==4.0.11
GitPython==3.1.43
glfw==1.12.0
gym==0.26.2
gym-notices==0.0.8
gymnasium==0.29.1
huggingface-hub==0.23.0
idna==3.7
imageio==2.34.1
importlib_resources==6.4.0
iniconfig==2.0.0
ipython==8.24.0
jat @ file:///local/mnt/jat
jedi==0.19.1
Jinja2==3.1.4
kiwisolver==1.4.5
MarkupSafe==2.1.5
matplotlib==3.8.4
matplotlib-inline==0.1.7
metaworld==0.1.0
minigrid==2.3.1
monotonic==1.6
mpmath==1.3.0
mujoco==3.1.5
mujoco-py==2.1.2.14
multidict==6.0.5
multiprocess==0.70.16
mypy-extensions==1.0.0
natsort==8.4.0
networkx==3.3
numpy==1.26.4
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.20.5
nvidia-nvjitlink-cu12==12.4.127
nvidia-nvtx-cu12==12.1.105
opencv-python==4.9.0.80
packaging==24.0
pandas==2.2.2
parso==0.8.4
pathspec==0.12.1
patsy==0.5.6
pexpect==4.9.0
pillow==10.3.0
platformdirs==4.2.1
pluggy==1.5.0
prompt-toolkit==3.0.43
property-cached==1.6.4
protobuf==4.25.3
psutil==5.9.8
ptyprocess==0.7.0
pure-eval==0.2.2
pyarrow==16.0.0
pyarrow-hotfix==0.6
pycparser==2.22
pygame==2.5.2
Pygments==2.18.0
PyOpenGL==3.1.7
pyparsing==3.1.2
pytest==8.2.0
pytest-xdist==3.6.1
python-dateutil==2.9.0.post0
pytz==2024.1
PyYAML==6.0.1
regex==2024.5.10
requests==2.31.0
rliable==1.0.8
ruff==0.4.4
safetensors==0.4.3
scipy==1.13.0
seaborn==0.13.2
sentry-sdk==2.1.1
setproctitle==1.3.3
Shimmy==0.2.1
six==1.16.0
smmap==5.0.1
stack-data==0.6.3
statsmodels==0.14.2
sympy==1.12
tokenize-rt==5.2.0
tokenizers==0.19.1
tomli==2.0.1
torch==2.3.0
torchvision==0.18.0
tqdm==4.66.4
traitlets==5.14.3
transformers==4.40.2
triton==2.3.0
typing_extensions==4.11.0
tzdata==2024.1
urllib3==2.2.1
wandb==0.17.0
wcwidth==0.2.13
xxhash==3.4.1
yarl==1.9.4
zipp==3.18.1

And yes, it happens at the very beginning.

Also, the dataset download here had failed because it was unable to find jat-project/jat-dataset-tokenized, so I replaced it with jat-project/jat-dataset, since I couldn't find the tokenized version on the Hugging Face Hub either. Could that be the issue?

Guess I should get my eyes checked.

Switching back to the tokenized dataset works! Thanks!

Guess I should get my eyes checked.

Your eyes are fine; I've just made it public. It was private ;)