EBGU / Rectified_Flow

A replication of rectified flow paper with the Oxford Flowers dataset

Rectified Flow

A replication of the Rectified Flow paper using PyTorch and U-ViT.

Training

To train a new model, edit the yaml file for your experiment and run:

python multi_gpu_trainer.py example

The Oxford Flowers training data must be split manually. The numpy version of the labels is included in this repo.
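One way to do the manual split is a per-class (stratified) split, so that every flower class appears in both subsets. The sketch below is only an illustration: `split_by_class`, `val_fraction`, and `seed` are hypothetical names, not part of this repo, and the 10% validation fraction is an arbitrary assumption.

```python
import random
from collections import defaultdict

def split_by_class(labels, val_fraction=0.1, seed=0):
    """Return (train_idx, val_idx), splitting every class in the same proportion.

    `labels` is a sequence with one class label per image.
    Hypothetical helper, not taken from this repo.
    """
    # Group image indices by their class label.
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    rng = random.Random(seed)  # fixed seed so the split is reproducible
    train_idx, val_idx = [], []
    for label in sorted(by_class):
        idxs = by_class[label]
        rng.shuffle(idxs)
        # Keep at least one validation image per class.
        n_val = max(1, round(len(idxs) * val_fraction))
        val_idx.extend(idxs[:n_val])
        train_idx.extend(idxs[n_val:])
    return sorted(train_idx), sorted(val_idx)
```

If the labels are stored as a numpy array, converting them with `.tolist()` first keeps the dictionary keys as plain Python integers. Note that a class with a single image ends up entirely in the validation subset.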

Inference

To run inference, download my pretrained weights, then run:

python sample_img.py --device "cuda:0" --load "last" --SavedDir tmp/ --ExpConfig example/example.yaml --n_sqrt 16 --steps 200

or use an ODE solver:

pip install torchdiffeq

python sample_img_ODESolver.py --device "cuda:0" --load "last" --SavedDir tmp/ --ExpConfig example/example.yaml --n_sqrt 16 --rtol 0.001

The inference process is controlled by the following parameters (each script takes six of them):

"device", usually "cuda:0";

"load", which checkpoint to load: the best epoch or the last epoch;

"SavedDir", where to save the images;

"ExpConfig", the yaml file of your experiment;

"n_sqrt", you will get n_sqrt² samples for each class (256 with --n_sqrt 16);

"steps", the number of sampling steps, used by sample_img.py; in my experiments, 200 is a good choice;

"rtol", the acceptable relative error per step, used with the ODE solver; 1e-3 is good enough.
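To make the difference between "steps" and "rtol" concrete, here is a minimal sketch of the two ways to integrate the flow ODE dx/dt = v(x, t) from t = 0 to t = 1. It is not the code in sample_img.py or sample_img_ODESolver.py: `euler_sample`, `adaptive_sample`, and `toy_velocity` are hypothetical names, the toy velocity field stands in for the trained U-ViT, the time direction is an assumption, and the adaptive scheme is a simple Heun-Euler pair rather than whichever solver torchdiffeq is configured to use.

```python
import math

def euler_sample(velocity, x0, steps=200):
    """Fixed-step sampling: `steps` equal Euler steps from t=0 to t=1."""
    x = list(x0)
    dt = 1.0 / steps
    for k in range(steps):
        v = velocity(x, k * dt)
        x = [xi + dt * vi for xi, vi in zip(x, v)]
    return x

def adaptive_sample(velocity, x0, rtol=1e-3, atol=1e-6, h=0.1):
    """Adaptive sampling: the step size is chosen so that the estimated
    per-step error stays below atol + rtol * |x|.
    Returns the final state and the number of accepted steps."""
    x, t, n_steps = list(x0), 0.0, 0
    while t < 1.0:
        last = h >= 1.0 - t          # this step would reach t = 1
        if last:
            h = 1.0 - t
        k1 = velocity(x, t)
        x_low = [xi + h * a for xi, a in zip(x, k1)]            # Euler (1st order)
        k2 = velocity(x_low, t + h)
        x_high = [xi + 0.5 * h * (a + b)                        # Heun (2nd order)
                  for xi, a, b in zip(x, k1, k2)]
        # Scaled error estimate: difference between the two results.
        err = max(abs(lo - hi) / (atol + rtol * max(abs(xi), abs(hi)))
                  for xi, lo, hi in zip(x, x_low, x_high))
        if err <= 1.0:               # accept the step
            x = x_high
            t = 1.0 if last else t + h
            n_steps += 1
        # Grow or shrink the next step based on the error estimate.
        h *= min(2.0, max(0.2, 0.9 / math.sqrt(max(err, 1e-12))))
    return x, n_steps

def toy_velocity(x, t):
    """Stand-in for the trained network: dx/dt = x, so x(1) = e * x(0)."""
    return list(x)

fixed = euler_sample(toy_velocity, [1.0], steps=200)
adaptive, n_accepted = adaptive_sample(toy_velocity, [1.0], rtol=1e-3)
```

With a fixed step count the cost is known in advance, while a tolerance lets the solver take as few or as many steps as the velocity field requires. A smaller "rtol" generally means more network evaluations per image.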

The results should look like the images shown at the top of this README.

Image interpolation

python image_interpolation.py --device "cuda:0" --load "last" --SavedDir tmp/ --ExpConfig example/example.yaml --input_image images/image1.jpg --target_image images/image2.jpg --rtol 0.0001 --mix_depth -0.02 --spherical True

This function is experimental and currently does not work well!
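For readers unfamiliar with the idea behind the --spherical flag, the following is a generic sketch of spherical linear interpolation (slerp) between two latent vectors. It is an assumption that image_interpolation.py does something along these lines; `slerp` is a hypothetical helper written for illustration, and it expects non-zero vectors.

```python
import math

def slerp(a, b, t):
    """Spherical linear interpolation between vectors a and b, 0 <= t <= 1."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    # Angle between the two vectors, clamped for numerical safety.
    cos_theta = max(-1.0, min(1.0, dot / (norm_a * norm_b)))
    theta = math.acos(cos_theta)
    if theta < 1e-6:
        # Nearly parallel vectors: fall back to linear interpolation.
        return [(1.0 - t) * x + t * y for x, y in zip(a, b)]
    sin_theta = math.sin(theta)
    w_a = math.sin((1.0 - t) * theta) / sin_theta
    w_b = math.sin(t * theta) / sin_theta
    return [w_a * x + w_b * y for x, y in zip(a, b)]

# Midpoint between two orthogonal unit vectors stays on the unit circle.
mid = slerp([1.0, 0.0], [0.0, 1.0], 0.5)
```

Slerp is often preferred over plain linear interpolation for Gaussian noise latents because the linear midpoint of two such vectors has a smaller norm than either endpoint, while the slerp midpoint of two equal-norm vectors keeps that norm.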

Enjoy!

License:MIT License

