HolyChen / draw_pytorch

DRAW: A Recurrent Neural Network For Image Generation



draw-pytorch

PyTorch implementation of DRAW: A Recurrent Neural Network For Image Generation, applied to the MNIST generation task.
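As background, the DRAW paper trains the model by minimising a two-part loss, L = L^x + L^z:

- L^x is the reconstruction loss, -log D(x | c_T), computed on the final canvas.
- L^z is the latent loss, the KL divergence between each step's Gaussian posterior and a unit-Gaussian prior, summed over the T steps.

The sketch below writes that loss for a single image with a Bernoulli output distribution. It is not taken from this repository. Whether train.py binarizes MNIST or uses exactly this form is an assumption. The name `draw_loss` and its arguments are hypothetical.

```python
import numpy as np


def draw_loss(x, c_T, mus, sigmas, eps=1e-8):
    """Hypothetical helper: DRAW loss terms (L^x, L^z) for one flattened image.

    x      -- pixel targets in [0, 1], shape (D,)
    c_T    -- final canvas logits, shape (D,)
    mus    -- posterior means for every step, shape (T, Z)
    sigmas -- posterior standard deviations, shape (T, Z)
    """
    p = 1.0 / (1.0 + np.exp(-c_T))  # Bernoulli means from the final canvas
    # Reconstruction loss L^x: binary cross-entropy summed over pixels
    L_x = -np.sum(x * np.log(p + eps) + (1 - x) * np.log(1 - p + eps))
    # Latent loss L^z: KL(N(mu, sigma^2) || N(0, 1)) summed over steps and latent dims,
    # i.e. 1/2 * sum(mu^2 + sigma^2 - log sigma^2) - T*Z/2
    L_z = 0.5 * np.sum(mus ** 2 + sigmas ** 2 - np.log(sigmas ** 2) - 1)
    return L_x, L_z
```

With zero posterior means and unit standard deviations the latent loss vanishes. In that case only the reconstruction term drives training.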

With Attention

Usage

python3 train.py downloads the MNIST dataset to ./data/mnist and trains the DRAW model with attention for both reading and writing. After training, the weights are written to ./save/weights_final.tar and the generated images are written to ./image/.png
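Reading and writing with attention in DRAW use a grid of N x N Gaussian filters. The decoder state supplies five parameters: grid centre, stride, filter variance and intensity. The sketch below follows the paper's parametrisation in NumPy rather than PyTorch, for a single greyscale image. It is not the repository's code, and the function names are hypothetical.

Pixel centres are placed at 1..A so that g̃ = 0 lands on the image centre. In the paper, write uses its own separately emitted set of parameters; the sketch reuses one set for brevity.

```python
import numpy as np


def attention_params(raw, N, A, B):
    """Map five raw decoder outputs to DRAW attention parameters (paper's parametrisation)."""
    gx_tilde, gy_tilde, log_sigma2, log_delta_tilde, log_gamma = raw
    g_x = (A + 1) / 2 * (gx_tilde + 1)  # grid centre, x coordinate
    g_y = (B + 1) / 2 * (gy_tilde + 1)  # grid centre, y coordinate
    delta = (max(A, B) - 1) / (N - 1) * np.exp(log_delta_tilde)  # stride between filters
    return g_x, g_y, np.exp(log_sigma2), delta, np.exp(log_gamma)


def filterbank(g_x, g_y, sigma2, delta, N, A, B, eps=1e-8):
    """Build normalised N x A (horizontal) and N x B (vertical) Gaussian filterbanks."""
    i = np.arange(1, N + 1)  # filter indices 1..N, as in the paper
    mu_x = g_x + (i - N / 2 - 0.5) * delta  # filter centres along x
    mu_y = g_y + (i - N / 2 - 0.5) * delta  # filter centres along y
    a = np.arange(1, A + 1)  # pixel centres along x
    b = np.arange(1, B + 1)  # pixel centres along y
    F_x = np.exp(-((a[None, :] - mu_x[:, None]) ** 2) / (2 * sigma2))
    F_y = np.exp(-((b[None, :] - mu_y[:, None]) ** 2) / (2 * sigma2))
    # Normalise each filter (row) so its weights sum to 1
    F_x /= np.maximum(F_x.sum(axis=1, keepdims=True), eps)
    F_y /= np.maximum(F_y.sum(axis=1, keepdims=True), eps)
    return F_x, F_y


def read(x, F_x, F_y, gamma):
    """Extract an N x N glimpse from a B x A image: gamma * F_y x F_x^T."""
    return gamma * F_y @ x @ F_x.T


def write(w, F_x, F_y, gamma):
    """Project an N x N patch back to B x A canvas space: (1 / gamma) * F_y^T w F_x."""
    return (1.0 / gamma) * F_y.T @ w @ F_x
```

Try N = A = B with a unit stride and a very small variance. Each filter then centres on one pixel, so read returns the image almost unchanged. With N = 5 on a 28 x 28 digit, read produces a 5 x 5 glimpse and write maps a 5 x 5 patch back to 28 x 28.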

python3 generate.py loads the weights from save/weights_final.tar and generates images.
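At generation time DRAW has no encoder. Each step samples z_t from the unit-Gaussian prior and feeds it to the decoder, which adds its output to the canvas. The final canvas is squashed through a sigmoid to give pixel intensities.

The toy sketch below illustrates only that loop; it does not reproduce generate.py. All of the following are hypothetical stand-ins:

- the weights are random (a real run would restore the trained parameters from the checkpoint);
- the decoder LSTM is replaced by a tanh RNN;
- the attentive write is replaced by a linear write;
- the names `generate`, `T`, `z_size` and `h_size` are made up for illustration.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def generate(T=10, z_size=10, h_size=32, height=28, width=28, seed=0):
    """Sample one image from a toy, untrained DRAW-style decoder driven by the prior."""
    rng = np.random.default_rng(seed)
    # Hypothetical random weights; a real run would restore trained parameters instead.
    W_zh = rng.normal(scale=0.1, size=(h_size, z_size))
    W_hh = rng.normal(scale=0.1, size=(h_size, h_size))
    W_write = rng.normal(scale=0.1, size=(height * width, h_size))
    h = np.zeros(h_size)
    canvas = np.zeros(height * width)  # c_0 = 0
    for _ in range(T):
        z = rng.standard_normal(z_size)  # z_t ~ N(0, I): no encoder at generation time
        h = np.tanh(W_zh @ z + W_hh @ h)  # tanh RNN standing in for the decoder LSTM
        canvas = canvas + W_write @ h  # c_t = c_{t-1} + write(h_dec_t), non-attentive write
    return sigmoid(canvas).reshape(height, width)  # per-pixel Bernoulli means
```

Because every step only adds to the canvas, intermediate canvases can be saved inside the loop to visualise how the image builds up over the T steps.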

The weights_final.tar file was trained for 50 epochs with a minibatch size of 64 on a GTX 970 GPU.
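This setting corresponds to roughly 47,000 parameter updates, under two assumptions:

- training uses the standard 60,000-image MNIST training split;
- train.py keeps the final partial batch of each epoch, as PyTorch's DataLoader does by default (drop_last=False).

The arithmetic is sketched below; the function name is hypothetical.

```python
import math


def total_updates(n_train=60_000, batch_size=64, epochs=50, drop_last=False):
    """Hypothetical helper: optimizer steps per epoch and in total."""
    if drop_last:
        per_epoch = n_train // batch_size  # the last partial batch is discarded
    else:
        per_epoch = math.ceil(n_train / batch_size)  # the last partial batch is kept
    return per_epoch, per_epoch * epochs
```

60,000 / 64 = 937.5. That gives 938 steps per epoch (46,900 in total) when the partial batch is kept, or 937 steps per epoch (46,850 in total) when it is dropped.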

Reference



Languages

Python 100.0%