haoningwu3639 / StoryGen

[CVPR 2024] Intelligent Grimm - Open-ended Visual Storytelling via Latent Diffusion Models

Home Page:https://haoningwu3639.github.io/StoryGen_Webpage/

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

Intelligent Grimm - Open-ended Visual Storytelling via Latent Diffusion Models (CVPR 2024)

This repository contains the official PyTorch implementation of StoryGen: https://arxiv.org/abs/2306.00973/

Some Information

Project Page $\cdot$ Paper $\cdot$ Dataset $\cdot$ Checkpoint

Requirements

A suitable conda environment named storygen can be created and activated with:

conda env create -f environment.yaml
conda activate storygen

Meta Data Preparation

Data from YouTube

We provide the metadata of our StorySalon dataset in ./data/metadata.json. It includes the id, name, url, duration and the keyframe list after filtering the videos.

To download these videos, we recommend to use youtube-dl via:

youtube-dl --write-auto-sub -o 'file\%(title)s.%(ext)s' -f 135 [url]

The keyframes extracted with the following data processing pipeline (step 1) can be filtered according to the keyframe list provided in the metadata to avoid manual selection.

The corresponding masks, story-level description and visual description can be extracted with the following data processing pipeline or downloaded from here.

Data from Open-source Libraries

For the open-source PDF data, you can directly download the frames, corresponding masks, description and narrative from StorySalon.

Data Processing Pipeline

The data processing pipeline includes several necessary steps:

  • Extract the keyframes and their corresponding subtitles;
  • Detect and remove duplicate frames;
  • Segment text, people, and headshots in images; and remove frames that only contain real people;
  • Inpaint the text, headshots and real hands in the frames according to the segmentation mask;
  • (Optional) Use Caption model combined with subtitles to generate a description of each image.

The keyframes and their corresponding subtitles can be extracted via:

python ./data_process/extract.py

The duplicate frames can be detected and removed via:

CUDA_VISIBLE_DEVICES=0 python ./data_process/dup_remove.py

The text, people and headshots can be segmented, and the frames that only contain real people are then removed via:

python ./data_process/yolov7/human_ocr_mask.py

The text, headshots and real hands in the frames can be inpainted with SDM-Inpainting, according to the segmentation mask via:

CUDA_VISIBLE_DEVICES=0 python ./data_process/SDM/inpaint.py

Besides, we also provide the code to get story-level paired image-text samples. We can align the subtitles with visual frames by using Dynamic Time Warping(DTW) algorithm via:

CUDA_VISIBLE_DEVICES=0 python ./data_process/align.py

(Optional) You can use TextBind or MiniGPT-v2 to obtain the caption of each image via:

CUDA_VISIBLE_DEVICES=0 python ./data_process/TextBind/main_caption.py
CUDA_VISIBLE_DEVICES=0 python ./data_process/MiniGPT-v2/main_caption.py

(Discarded) Previous method: You can also use ChatCaptioner to obtain the caption of each image via:

CUDA_VISIBLE_DEVICES=0 python ./data_process/ChatCaptioner/main_caption.py

For a more detailed introduction to the data processing pipeline, please refer to ./data_process/README.md and our paper.

Training

Before training, please download pre-trained StableDiffusion-1.5 from SDM (including vae, scheduler, tokenizer and unet). Then, all the pre-trained checkpoints should be placed into the corresponding location in the folder ./ckpt/stable-diffusion-v1-5/

For Stage 1, pre-train the self-attention layers in SDM for StyleTransfer via:

CUDA_VISIBLE_DEVICES=0 accelerate launch train_StorySalon_stage1.py

For Stage 2, train the Visual-Language Context Module via:

CUDA_VISIBLE_DEVICES=0 accelerate launch train_StorySalon_stage2.py

For replicating the experiments on MS-COCO, train via:

CUDA_VISIBLE_DEVICES=0 accelerate launch train_COCO.py

If you have multiple GPUs to accelerate the training process, you can use:

CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch --multi_gpu train_StorySalon_stage2.py

Inference

CUDA_VISIBLE_DEVICES=0 accelerate launch inference.py

TODO

  • Model & Training & Inference Code
  • Dataset Processing Pipeline
  • Meta Data
  • Code Update
  • Release Checkpoints
  • Data Update

License

The code and checkpoints in this repository are under MIT license. The open-source books in the StorySalon dataset come from multiple online open-source libraries (please refer to the Appendix of our paper for more details), and they are all under CC-BY 4.0 license. It should be noted that, for the data extracted from the video data, we only provide YouTube URLs and data processing pipelines, if you wish to use them for commercial purposes, we recommend that you obey the relevant regulations of YouTube.

Citation

If you use this code for your research or project, please cite:

@inproceedings{liu2024intelligent,
  title     = {Intelligent Grimm -- Open-ended Visual Storytelling via Latent Diffusion Models}, 
  author    = {Chang Liu, Haoning Wu, Yujie Zhong, Xiaoyun Zhang, Yanfeng Wang, Weidi Xie},
  booktitle = {The IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
  year      = {2024},
}

Acknowledgements

Many thanks to the code bases from diffusers and SimpleSDM.

Contact

If you have any questions, please feel free to contact haoningwu3639@gmail.com or liuchang666@sjtu.edu.cn.

About

[CVPR 2024] Intelligent Grimm - Open-ended Visual Storytelling via Latent Diffusion Models

https://haoningwu3639.github.io/StoryGen_Webpage/

License:MIT License


Languages

Language:Python 100.0%Language:Shell 0.0%