This repo provides official code and checkpoints for iVideoGPT, a generic and efficient world model architecture that has been pre-trained on millions of human and robotic manipulation trajectories.
- 🚩 2024.09.26: iVideoGPT has been accepted by NeurIPS 2024, congrats!
- 🚩 2024.08.31: Training code is released (work in progress 🚧, please stay tuned!)
- 🚩 2024.05.31: Project website with video samples is released.
- 🚩 2024.05.30: Model pre-trained on Open X-Embodiment and inference code are released.
- 🚩 2024.05.27: Our paper is released on arXiv.
conda create -n ivideogpt python==3.9
conda activate ivideogpt
pip install -r requirements.txt
At the moment we provide the following models:
| Model | Resolution | Action-conditioned | Tokenizer Size (params) | Transformer Size (params) |
|---|---|---|---|---|
| ivideogpt-oxe-64-act-free | 64x64 | No | 114M | 138M |
If you cannot connect to Hugging Face, you can download the checkpoint manually from Tsinghua Cloud.
python inference/predict.py --pretrained_model_name_or_path "thuml/ivideogpt-oxe-64-act-free" --input_path inference/samples/fractal_sample.npz --dataset_name fractal20220817_data
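The sample passed via --input_path is a NumPy .npz archive. Its key names are not documented here, so this minimal sketch (describe_npz is a hypothetical helper) simply lists every stored array with its shape and dtype. That is a quick way to see what an episode contains before running inference.

```python
import numpy as np


def describe_npz(path):
    """Return {key: (shape, dtype)} for every array stored in an .npz archive."""
    # np.load on an .npz returns a lazy NpzFile; using it as a context manager
    # closes the underlying zip file once the arrays have been inspected.
    with np.load(path) as archive:
        return {key: (archive[key].shape, str(archive[key].dtype)) for key in archive.files}
```

For example, describe_npz("inference/samples/fractal_sample.npz") prints nothing by itself but returns the per-key summary.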
To try more samples, download the Open X-Embodiment dataset and extract individual episodes as follows:
python oxe_data_converter.py --dataset_name {dataset_name, e.g. bridge} --input_path {path to OXE} --output_path samples --max_num_episodes 10
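Once several episodes are extracted, inference can be run on each of them in turn. This sketch assumes, without confirmation from this README, that oxe_data_converter.py writes one .npz file per episode under the output directory (like the bundled fractal sample) and that predict.py accepts the same dataset name given to the converter. build_predict_commands is a hypothetical helper that only reuses the predict.py flags shown above.

```python
from pathlib import Path


def build_predict_commands(sample_dir, dataset_name,
                           model="thuml/ivideogpt-oxe-64-act-free"):
    """Build one inference/predict.py argv per .npz episode found under sample_dir.

    Assumption: the converter emits .npz episodes; other files are ignored.
    """
    commands = []
    # Sort so the episodes are processed in a stable, predictable order.
    for npz_path in sorted(Path(sample_dir).rglob("*.npz")):
        commands.append([
            "python", "inference/predict.py",
            "--pretrained_model_name_or_path", model,
            "--input_path", str(npz_path),
            "--dataset_name", dataset_name,
        ])
    return commands
```

Each command can then be executed with subprocess.run(cmd, check=True), e.g. for every cmd in build_predict_commands("samples", "bridge").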
To finetune our pretrained iVideoGPT, download it into pretrained_models/ivideogpt-oxe-64-act-free.
To evaluate the FVD metric, download the pretrained I3D model to pretrained_models/i3d/i3d_torchscript.pt.
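Before fine-tuning, it can help to confirm that the pretrained assets sit where the training commands expect them. This minimal sketch checks the paths named in this README; the tokenizer/ and transformer/ subfolders are inferred from the --pretrained_model_name_or_path and --pretrained_transformer_path arguments rather than from documented repository contents. The optional download_checkpoint helper (a hypothetical name) calls huggingface_hub.snapshot_download and needs network access; the I3D model still has to be obtained separately.

```python
from pathlib import Path

# Paths taken from the fine-tuning and FVD instructions in this README.
REQUIRED_ASSETS = [
    "pretrained_models/ivideogpt-oxe-64-act-free/tokenizer",
    "pretrained_models/ivideogpt-oxe-64-act-free/transformer",
    "pretrained_models/i3d/i3d_torchscript.pt",
]


def missing_assets(root="."):
    """Return the required pretrained assets that do not exist under root."""
    return [p for p in REQUIRED_ASSETS if not (Path(root) / p).exists()]


def download_checkpoint(local_dir="pretrained_models/ivideogpt-oxe-64-act-free"):
    """Optionally mirror the Hugging Face model repo locally (requires network access)."""
    from huggingface_hub import snapshot_download  # imported lazily: only needed here
    return snapshot_download(repo_id="thuml/ivideogpt-oxe-64-act-free", local_dir=local_dir)
```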
BAIR Robot Pushing: Download the dataset and preprocess it with the following commands:
wget http://rail.eecs.berkeley.edu/datasets/bair_robot_pushing_dataset_v0.tar -P .
tar -xvf ./bair_robot_pushing_dataset_v0.tar -C .
python datasets/preprocess_bair.py --input_path bair_robot_pushing_dataset_v0/softmotion30_44k --save_path bair_preprocessed
Then update the saved paths (e.g. bair_preprocessed/train and bair_preprocessed/test) in DATASET.yaml.
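Editing DATASET.yaml by hand works; if many entries share one root directory, a small script can rewrite them together. The layout of DATASET.yaml is not described here, so this sketch assumes only that paths are stored as string values somewhere in the nested mapping and replaces a given prefix in all of them. The helper names are hypothetical, and yaml.safe_dump (PyYAML) does not preserve comments in the rewritten file.

```python
def rewrite_path_prefix(node, old_prefix, new_prefix):
    """Recursively replace old_prefix with new_prefix in every matching string value.

    Operates on the nested dicts/lists produced by a YAML loader. Include a
    trailing slash in old_prefix so "/data/bair" does not also match "/data/bair2".
    """
    if isinstance(node, dict):
        return {k: rewrite_path_prefix(v, old_prefix, new_prefix) for k, v in node.items()}
    if isinstance(node, list):
        return [rewrite_path_prefix(v, old_prefix, new_prefix) for v in node]
    if isinstance(node, str) and node.startswith(old_prefix):
        return new_prefix + node[len(old_prefix):]
    return node  # numbers, booleans, None and non-matching strings are left untouched


def update_dataset_yaml(path, old_prefix, new_prefix):
    """Load a YAML file, rewrite matching path prefixes, and write it back in place."""
    import yaml  # PyYAML, imported lazily so rewrite_path_prefix has no dependency
    with open(path) as f:
        cfg = yaml.safe_load(f)
    with open(path, "w") as f:
        yaml.safe_dump(rewrite_path_prefix(cfg, old_prefix, new_prefix), f, sort_keys=False)
```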
To fine-tune the tokenizer on BAIR, run the following:
accelerate launch train_tokenizer.py \
--exp_name bair_tokenizer_ft --output_dir log_vqgan --seed 0 --mixed_precision bf16 \
--model_type ctx_vqgan \
--train_batch_size 16 --gradient_accumulation_steps 1 --disc_start 1000005 \
--oxe_data_mixes_type bair --resolution 64 --dataloader_num_workers 16 \
--rand_select --video_stepsize 1 --segment_horizon 16 --segment_length 8 --context_length 1 \
--pretrained_model_name_or_path pretrained_models/ivideogpt-oxe-64-act-free/tokenizer
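When launching with accelerate on several GPUs, the number of samples per optimizer step grows with the process count. Assuming --train_batch_size is a per-device value (as in the diffusers training scripts this codebase builds on; not confirmed here for train_tokenizer.py), the effective batch size is the product computed below. effective_batch_size is a hypothetical helper name.

```python
def effective_batch_size(per_device_batch_size, gradient_accumulation_steps, num_processes):
    """Samples contributing to each optimizer step under data-parallel training.

    Each process draws per_device_batch_size samples per forward pass, and
    gradients are accumulated over gradient_accumulation_steps passes before
    the optimizer steps, so the three factors multiply.
    """
    return per_device_batch_size * gradient_accumulation_steps * num_processes
```

For the tokenizer command above on 8 GPUs, effective_batch_size(16, 1, 8) gives 128; keeping the single-GPU value of 16 would require lowering the per-device batch size instead.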
Then, to fine-tune the transformer for action-conditioned video prediction, run the following:
accelerate launch train_gpt.py \
--exp_name bair_llama_ft --output_dir log_trm --seed 0 --mixed_precision bf16 \
--vqgan_type ctx_vqgan \
--pretrained_model_name_or_path {log directory of finetuned tokenizer}/unwrapped_model \
--config_name configs/llama/config.json --load_internal_llm --action_conditioned --action_dim 4 \
--pretrained_transformer_path pretrained_models/ivideogpt-oxe-64-act-free/transformer \
--per_device_train_batch_size 16 --gradient_accumulation_steps 1 \
--learning_rate 1e-4 --lr_scheduler_type cosine \
--oxe_data_mixes_type bair --resolution 64 --dataloader_num_workers 16 \
--video_stepsize 1 --segment_length 16 --context_length 1 \
--use_eval_dataset --use_fvd --use_frame_metrics \
--weight_decay 0.01 --llama_attn_drop 0.1 --embed_no_wd
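The transformer command uses --learning_rate 1e-4 with --lr_scheduler_type cosine. Its argument names match the Hugging Face transformers causal-LM training examples, so, as an assumption, the schedule follows the shape of transformers' get_cosine_schedule_with_warmup (default num_cycles=0.5), reproduced below in plain Python. The command sets no warmup explicitly, so warmup_steps defaults to 0 here, which is also an assumption; cosine_lr is a hypothetical helper.

```python
import math


def cosine_lr(step, base_lr, total_steps, warmup_steps=0):
    """Learning rate at an optimizer step: linear warmup, then half a cosine down to 0."""
    if step < warmup_steps:
        # Linear ramp from 0 to base_lr over the warmup steps.
        return base_lr * step / max(1, warmup_steps)
    # Fraction of the post-warmup phase completed, in [0, 1] during training.
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
```

With base_lr=1e-4 and no warmup, the rate starts at 1e-4, halves to 5e-5 at the midpoint of training, and reaches 0 at the final step.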
For action-free video prediction, remove the --load_internal_llm and --action_conditioned flags.
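Switching the transformer command to action-free prediction amounts to deleting two value-less flags. If the command is held as an argument list (for example via shlex.split), a minimal sketch with the hypothetical helper to_action_free is:

```python
# The two boolean flags that enable action conditioning in the command above.
ACTION_FLAGS = ("--load_internal_llm", "--action_conditioned")


def to_action_free(argv):
    """Return a copy of a train_gpt.py argv with the action-conditioning flags removed.

    Both flags take no value, so each is dropped on its own; every other
    argument, including --action_dim, is kept unchanged.
    """
    return [arg for arg in argv if arg not in ACTION_FLAGS]
```

Only the two flags named in the instruction are removed; --action_dim 4 stays in place, matching the text.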
For the model-based RL experiments, install the Metaworld version we used:
pip install git+https://github.com/Farama-Foundation/Metaworld.git@83ac03ca3207c0060112bfc101393ca794ebf1bd
Set the paths in mbrl/cfgs/mbpo_config.yaml to your own (currently only absolute paths are supported).
python mbrl/train_metaworld_mbpo.py task=plate_slide num_train_frames=100002 demo=true
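Because mbpo_config.yaml currently accepts only absolute paths, a quick check after editing can catch leftovers. The config's key names are not listed here, so this sketch uses an unverified heuristic: any string containing "/" is treated as a path, POSIX rules decide whether it is absolute, and values with ${...} interpolations are skipped. relative_paths is a hypothetical helper operating on the dict returned by yaml.safe_load.

```python
import posixpath


def relative_paths(cfg, prefix=""):
    """List (dotted_key, value) pairs whose string value looks like a relative path."""
    found = []
    for key, value in cfg.items():
        dotted = f"{prefix}.{key}" if prefix else str(key)
        if isinstance(value, dict):
            # Recurse into nested config groups, tracking the dotted key path.
            found.extend(relative_paths(value, dotted))
        elif (isinstance(value, str) and "/" in value and "${" not in value
              and not posixpath.isabs(value)):
            found.append((dotted, value))
    return found
```

For example, relative_paths(yaml.safe_load(open("mbrl/cfgs/mbpo_config.yaml"))) should return an empty list once every path-like value is absolute.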
If you find this project useful, please cite our paper as:
@article{wu2024ivideogpt,
title={iVideoGPT: Interactive VideoGPTs are Scalable World Models},
author={Jialong Wu and Shaofeng Yin and Ningya Feng and Xu He and Dong Li and Jianye Hao and Mingsheng Long},
journal={arXiv preprint arXiv:2405.15223},
year={2024},
}
If you have any questions, please contact wujialong0229@gmail.com.
Our codebase builds heavily on huggingface/diffusers and facebookresearch/drqv2.