ameroyer / glow_jax

An implementation of the Glow generative model in JAX and Flax


Glow generative model in JAX

An implementation of the Glow generative model in JAX, using the high-level neural-network API Flax. Glow is a reversible generative model based on normalizing flows: an invertible mapping between images and a simple latent distribution, trained by exact maximum likelihood. The notebook can also be found on Kaggle, where the model was trained on a subset of the aligned CelebA dataset.
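The reversibility above comes from coupling layers. The following is a minimal, self-contained sketch of one affine coupling step (names and the linear parameterization are illustrative, not this repo's actual API; Glow uses a small conv net to predict the scale and shift):

```python
import jax
import jax.numpy as jnp

def coupling_forward(x, params):
    # Split the input in two; transform one half with a scale/shift
    # predicted from the other half, so the map is trivially invertible.
    x1, x2 = jnp.split(x, 2, axis=-1)
    h = x1 @ params["w"] + params["b"]   # stand-in for Glow's conv net
    log_s, t = jnp.split(h, 2, axis=-1)
    log_s = jnp.tanh(log_s)              # keep scales bounded for stability
    y2 = x2 * jnp.exp(log_s) + t
    logdet = jnp.sum(log_s, axis=-1)     # log |det J|, cheap to compute
    return jnp.concatenate([x1, y2], axis=-1), logdet

def coupling_inverse(y, params):
    # The untouched half y1 lets us recompute log_s and t exactly.
    y1, y2 = jnp.split(y, 2, axis=-1)
    h = y1 @ params["w"] + params["b"]
    log_s, t = jnp.split(h, 2, axis=-1)
    log_s = jnp.tanh(log_s)
    x2 = (y2 - t) * jnp.exp(-log_s)
    return jnp.concatenate([y1, x2], axis=-1)

key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)
d = 4  # half-dimension, arbitrary for the demo
params = {"w": jax.random.normal(k1, (d, 2 * d)) * 0.1,
          "b": jnp.zeros(2 * d)}
x = jax.random.normal(k2, (3, 2 * d))
y, logdet = coupling_forward(x, params)
x_rec = coupling_inverse(y, params)
print(bool(jnp.allclose(x, x_rec, atol=1e-5)))  # inverse recovers the input
```

Stacking many such steps (with permutations or 1x1 convolutions between them, as in Glow) gives an expressive yet exactly invertible model.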

Setup

Dependencies

pip install jax jaxlib
pip install flax

Sample from the model

Random samples can be generated as follows, here generating 16 samples with sampling temperature 0.7 and random seed 0:

python3 sample.py 16 -t 0.7 -s 0 --model_path [path]
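A minimal sketch of what the temperature flag typically does in flow models (the function name and latent dimension here are illustrative, not taken from `sample.py`): latents are drawn from a standard Gaussian scaled by the temperature, then decoded through the inverse flow.

```python
import jax
import jax.numpy as jnp

def sample_latents(key, num_samples, latent_dim, temperature=0.7):
    # Lower temperature shrinks the latent Gaussian toward its mode,
    # trading sample diversity for visual fidelity.
    z = jax.random.normal(key, (num_samples, latent_dim))
    return temperature * z

# 16 samples with a hypothetical latent dimension (not the repo's value)
z = sample_latents(jax.random.PRNGKey(0), 16, 1024, temperature=0.7)
print(z.shape)  # (16, 1024)
```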

Example

A pretrained model can be found in the kaggle notebook's outputs.

Note: The model was only trained for roughly 13 epochs due to compute limits. Compared to the original model, it also uses a shallower flow (K = 16 flow steps per scale).

Example results - training evolution

t=0.85

Example results - sampling

t=0.85

t=0.7

Example results - linear interpolation

Linear interpolation results
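The interpolations above can be sketched as follows (an illustrative snippet, not the repo's code): encode two images to latents, blend linearly, and decode each blend through the inverse flow.

```python
import jax.numpy as jnp

def interpolate(z_a, z_b, num_steps=8):
    # Linear path in latent space; each point would be decoded to an image.
    alphas = jnp.linspace(0.0, 1.0, num_steps)
    return jnp.stack([(1 - a) * z_a + a * z_b for a in alphas])

# Toy latents standing in for encoded images
z_a = jnp.zeros(16)
z_b = jnp.ones(16)
path = interpolate(z_a, z_b)
print(path.shape)  # (8, 16)
```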

About

An implementation of the Glow generative model in JAX and Flax

License:GNU General Public License v3.0


Languages

Jupyter Notebook 99.7%, Python 0.3%