This repository contains the Python implementation for our generative clustering method VaDE.
Further details about VaDE can be found in our paper:
Requirements:
- Python-3.4.4
- keras-1.1.0
- scikit-learn-0.17.1
Replace keras/engine/training.py in your Keras installation with the training.py from this repository.
(The modified version of keras/engine/training.py enables the gmm parameters and the network parameters of our model to be updated simultaneously.)
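The replacement step can also be scripted. The following is a minimal sketch, not part of this repository: the helper name `replace_keras_training` and the `.orig` backup suffix are assumptions made for illustration. It locates the installed keras package, keeps a copy of the stock file, and then copies the patched training.py over it.

```python
import importlib.util
import os
import shutil


def replace_keras_training(patched_file, keras_dir=None):
    """Copy patched_file over <keras_dir>/engine/training.py, keeping a backup.

    keras_dir defaults to the directory of the keras package that is
    importable in the current environment.
    """
    if keras_dir is None:
        spec = importlib.util.find_spec("keras")
        if spec is None or spec.origin is None:
            raise RuntimeError("keras is not installed in this environment")
        # spec.origin points at keras/__init__.py
        keras_dir = os.path.dirname(spec.origin)

    target = os.path.join(keras_dir, "engine", "training.py")
    if not os.path.isfile(target):
        raise FileNotFoundError(target)

    # Keep the stock file once so the change can be undone later.
    backup = target + ".orig"
    if not os.path.exists(backup):
        shutil.copy2(target, backup)

    shutil.copy2(patched_file, target)
    return target
```

Calling `replace_keras_training("training.py")` from the repository root would patch the keras installation of the active environment. Restoring the `.orig` file undoes the change.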
- To train the VaDE model on the MNIST, Reuters, or HHAR dataset:
python ./VaDE.py db
db must be one of mnist, reuters10k, or har.
- To achieve the 94.46% clustering accuracy on the MNIST dataset and generate class-specific digits (note that, unlike Conditional GAN, we do not use any supervised information during training):
python ./VaDE_test_mnist.py
- To achieve the 79.38% clustering accuracy on the Reuters(685K) dataset:
cd $VaDE_ROOT/dataset/reuters
./get_data.sh
cd $VaDE_ROOT
python ./VaDE_test_reuters_all.py
Note: the data preprocessing code for the Reuters dataset is taken from https://github.com/piiswrong/dec.
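For readers unfamiliar with the metric, the sketch below shows how an unsupervised clustering accuracy is commonly computed: find the one-to-one mapping between cluster ids and ground-truth labels that maximises the number of matched samples. It assumes the figures quoted above use this standard definition. The function name `cluster_accuracy` is hypothetical and this is not the code used in the test scripts. It searches all mappings by brute force, which is only practical for a small number of clusters; for larger problems an assignment solver such as `scipy.optimize.linear_sum_assignment` is the usual choice.

```python
from itertools import permutations


def cluster_accuracy(y_true, y_pred):
    """Best-mapping accuracy between cluster ids (y_pred) and labels (y_true)."""
    if len(y_true) == 0 or len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must be non-empty and of equal length")

    # counts[(cluster, label)] = number of samples with that pair
    counts = {}
    for label, cluster in zip(y_true, y_pred):
        counts[(cluster, label)] = counts.get((cluster, label), 0) + 1

    labels = sorted(set(y_true))
    clusters = sorted(set(y_pred))

    # Pad the shorter side with None so every cluster can be paired with
    # exactly one label (or left unmatched).
    n = max(len(labels), len(clusters))
    labels += [None] * (n - len(labels))
    clusters += [None] * (n - len(clusters))

    best = 0
    for perm in permutations(labels):
        matched = sum(counts.get((c, l), 0) for c, l in zip(clusters, perm))
        best = max(best, matched)
    return best / len(y_true)
```

For example, `cluster_accuracy([0, 0, 1, 1], [1, 1, 0, 0])` evaluates to 1.0, because the cluster ids are a pure relabelling of the ground truth.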
(DCGAN-like network architecture)
- Attribute-conditioned generation (sampled from each cluster) without using any supervised information
Rows 1-6: 1. black/short hair, man; 2. black/long hair, woman; 3. gold/long hair, woman; 4. bald, sunglasses, man; 5. left side face, woman; 6. right side face, woman.