talcron / frame-prediction-pytorch

PyTorch implementation of WGAN-GP-based video generation. Includes functionality for measuring Frechet Video Distance and implementing recent research improvements of WGAN-GP. Read paper at https://github.com/talcron/frame-prediction-pytorch/blob/media/paper.pdf

WGAN Video Generation

Video generation is a rapidly growing field within deep learning. A highly accurate video generation model has potential applications ranging from robotics to film production. Such a model would contain a powerful semantic representation of objects, backgrounds, and scene dynamics. We build upon a generative model, iVGAN, by incorporating recently proposed improvements. We compare results for methods including spectral normalization, drift penalties, and modified gradient penalty losses (one-sided and zero-centered). We find that while spectral norm and drift penalties appear to offer benefits by controlling the scale of gradients, zero-centered gradient penalties yield the most realistic generated videos. For more information, refer to the paper.
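
The gradient penalty variants compared above differ only in how the critic's gradient norm is penalized. The sketch below is illustrative and is not taken from this repository: the function name `gradient_penalty`, the `mode` strings, and the choice to evaluate all three penalties at random interpolates between real and generated batches are assumptions. The zero-centered variant in particular may be evaluated at different points in the paper.

```python
import torch


def gradient_penalty(critic, real, fake, mode="two-sided"):
    """Gradient penalty on random interpolates (hypothetical helper, not the repo's code).

    mode: "two-sided"     -> (||grad|| - 1)^2         (original WGAN-GP)
          "one-sided"     -> max(0, ||grad|| - 1)^2   (only norms above 1 are penalized)
          "zero-centered" -> ||grad||^2               (pushes gradient norms toward 0)
    """
    batch = real.size(0)
    # One interpolation coefficient per sample, broadcast over the remaining dims
    # (e.g. channels, frames, height, width for video tensors).
    alpha = torch.rand(batch, *([1] * (real.dim() - 1)), device=real.device)
    x_hat = (alpha * real.detach() + (1 - alpha) * fake.detach()).requires_grad_(True)

    scores = critic(x_hat)
    # create_graph=True keeps the graph so the penalty itself can be backpropagated.
    (grads,) = torch.autograd.grad(scores.sum(), x_hat, create_graph=True)
    norms = grads.flatten(1).norm(2, dim=1)

    if mode == "two-sided":
        return ((norms - 1) ** 2).mean()
    if mode == "one-sided":
        return (torch.clamp(norms - 1, min=0) ** 2).mean()
    if mode == "zero-centered":
        return (norms ** 2).mean()
    raise ValueError(f"unknown mode: {mode}")
```

The returned value is unweighted; the caller would multiply it by a penalty coefficient before adding it to the critic loss.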

The fvd module incorporates code from Google Research, licensed under the Apache 2.0 license, the text of which can be found in the LICENSE file.

The model is implemented in PyTorch and adapted from the iVGAN Tensorflow implementation.

[Images: Generated Sample 1, Generated Sample 3]

Setup

Requires:

pytorch
torchvision
tensorflow-gpu
tensorflow-hub
tensorflow-gan
comet_ml
matplotlib
opencv

Environment

Conda

We recommend using a conda environment. Once conda is installed (Miniconda is sufficient), create and activate the environment with:

conda env create -f environment.yml
conda activate torch

Docker

Pull the Docker image used to run this project:

docker pull ianpegg9/torch:tf

The included Dockerfile contains the spec for this image.

CometML

To use CometML to log experimental results, create a .comet.config file with contents like this:

[comet]
api_key=MYKEY
project_name=frame-prediction-pytorch
workspace=username

To run without CometML logging, pass the argument --exp-disable to the script.
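
It can be convenient to check the .comet.config file before launching a long run. The sketch below parses it with Python's standard `configparser` and confirms that the three keys shown above are present. The helper name `check_comet_config` is hypothetical; comet_ml reads the file on its own, so this is only a sanity check and not part of the repository.

```python
import configparser

# Keys shown in the example .comet.config above.
REQUIRED_KEYS = ("api_key", "project_name", "workspace")


def check_comet_config(path=".comet.config"):
    """Return the [comet] section as a dict, or raise if something is missing."""
    parser = configparser.ConfigParser()
    # ConfigParser.read returns the list of files it managed to parse.
    if not parser.read(path):
        raise FileNotFoundError(f"could not read {path}")
    if not parser.has_section("comet"):
        raise KeyError("missing [comet] section")
    section = parser["comet"]
    missing = [key for key in REQUIRED_KEYS if not section.get(key)]
    if missing:
        raise KeyError(f"missing keys in [comet]: {', '.join(missing)}")
    return dict(section)


if __name__ == "__main__":
    config = check_comet_config()
    print("CometML project:", config["project_name"])
```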

Data

Our dataloaders support UCF-101, UCF-sports, and tinyvideo.

Tinyvideo

We recommend downloading 'Beach only' or 'Golf only'. Whichever subset you download, keep it as a tarball and save it in a directory named 'tinyvideo'.

Pre-process everything to 64x64 to save space, since the dataset is very large. Use a machine with many processors; otherwise, processing the entire tarball will take several days.

python scripts/process_tinyvideo.py <path-to-data-tarball>

Create the index file:

find /path/to/tinyvideo -name '*.jpg' > /path/to/tinyvideo/index.txt
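
Where `find` is unavailable, the index can be built in Python instead. This is a minimal sketch, assuming the index file is a plain list with one file path per line, as the `find` command produces. The function name `write_index` and its parameters are hypothetical, and the paths are sorted here for reproducibility, which `find` does not guarantee.

```python
from pathlib import Path


def write_index(root, pattern="*.jpg", index_name="index.txt"):
    """Recursively list files matching `pattern` under `root` and write them,
    one path per line, to root/index_name. Returns the number of files listed."""
    root = Path(root)
    paths = sorted(str(p) for p in root.rglob(pattern) if p.is_file())
    text = "\n".join(paths) + ("\n" if paths else "")
    (root / index_name).write_text(text)
    return len(paths)


if __name__ == "__main__":
    count = write_index("/path/to/tinyvideo")
    print(f"indexed {count} files")
```

Passing `pattern="*.mjpeg"` would cover the pre-processed UCF data in the same way.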

UCF-101 and UCF sports

Download and unzip the data file.

Pre-process the videos for easy loading. The pre-processing script does not delete the original videos.

bash scripts/process_ucf.sh /path/to/ucf-data

Create the index file:

find /path/to/ucf-data -name '*.mjpeg' > /path/to/ucf-data/index.txt

Usage

The sections below describe the main use cases. For further options, run: python main_train.py --help

Train

--root-dir /path/to/data/dir

--index-file root-dir/index.txt

--save-dir /where/you/save/your/results

--exp-name <tag or tags for CometML>

Continue training

--resume /path/to/checkpoint.model

Optional:

--exp-key <comet-ml experiment key> to continue a CometML experiment.

Evaluate

Evaluate instead of training.

--evaluate

Real data to compare to:

--index-file root-dir/index.txt

Checkpoint to load:

--resume /path/to/checkpoint.model
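
The flags listed in the Train, Continue training, and Evaluate sections can be combined into a full command line. The sketch below assembles such a command in Python using only the flags named in this README. The helper `build_command` is hypothetical, the experiment name is passed as a single string because how multiple tags should be supplied is not specified here, and `python main_train.py --help` remains the authoritative list of options.

```python
import shlex


def build_command(index_file, root_dir=None, save_dir=None, exp_name=None,
                  resume=None, exp_key=None, evaluate=False, exp_disable=False):
    """Assemble an argument list for main_train.py (illustrative helper only)."""
    cmd = ["python", "main_train.py"]
    if root_dir:
        cmd += ["--root-dir", root_dir]
    cmd += ["--index-file", index_file]
    if save_dir:
        cmd += ["--save-dir", save_dir]
    if exp_name:
        cmd += ["--exp-name", exp_name]
    if resume:
        cmd += ["--resume", resume]      # checkpoint to continue from or evaluate
    if exp_key:
        cmd += ["--exp-key", exp_key]    # continue an existing CometML experiment
    if evaluate:
        cmd.append("--evaluate")
    if exp_disable:
        cmd.append("--exp-disable")      # run without CometML logging
    return cmd


if __name__ == "__main__":
    train = build_command("/path/to/data/dir/index.txt",
                          root_dir="/path/to/data/dir",
                          save_dir="/where/you/save/your/results",
                          exp_name="baseline")
    # Print a shell-quoted command; pass the list to subprocess.run to execute it.
    print(shlex.join(train))
```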

License: Apache License 2.0

