SketchKnitter
In this repository, you can find the PyTorch implementation of SketchKnitter: Vectorized Sketch Generation with Diffusion Models, ICLR 2023, Spotlight.
Authors: Qiang Wang, Haoge Deng, Yonggang Qi, Da Li, Yi-Zhe Song. Beijing University of Posts and Telecommunications, Samsung AI Centre Cambridge, University of Surrey.
Abstract: We show vectorized sketch generation can be identified as a reversal of the stroke deformation process. This relationship was established by means of a diffusion model that learns data distributions over the stroke-point locations and pen states of real human sketches. Given randomly scattered stroke-points, sketch generation becomes a process of deformation-based denoising, where the generator rectifies positions of stroke points at each timestep to converge at a recognizable sketch. A key innovation was to embed recognizability into the reverse time diffusion process. It was observed that the estimated noise during the reversal process is strongly correlated with sketch classification accuracy. An auxiliary recurrent neural network (RNN) was consequently used to quantify recognizability during data sampling. It follows that, based on the recognizability scores, a sampling shortcut function can also be devised that renders better quality sketches with fewer sampling steps. Finally it is shown that the model can be easily extended to a conditional generation framework, where given incomplete and unfaithful sketches, it yields one that is more visually appealing and with higher recognizability.
Datasets
Please go to the QuickDraw official website to download the datasets. The class list used in the paper: moon, airplane, fish, umbrella, train, spider, shoe, apple, lion, bus. You can also replace it with any other category.
The complete dataset used in the paper can be downloaded from this link. Due to size limitations, this repo does not contain any datasets. You can also download all of the QuickDraw .npz datasets from Google Cloud for local use. Each category is stored in its own file and contains training/validation/test sets of 70000/2500/2500 examples.
In addition to the QuickDraw dataset, you can train the model on any dataset, provided you organize the data into vector format and package it into a .npz file. With small datasets, watch out for over-fitting. If you want to create your own dataset, you can follow the official SketchRNN tutorial.
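As a rough illustration of the packaging step, the sketch below converts strokes given as absolute points into the stroke-3 layout used by the QuickDraw .npz files (rows of dx, dy, pen-lifted flag) and saves train/valid/test splits. The helper names `to_stroke3` and `pack_npz` are hypothetical and are not part of this repo; the exact layout expected by the training code should be checked against the SketchRNN tutorial.

```python
import os
import tempfile

import numpy as np


def to_stroke3(strokes):
    """Convert strokes (each an (n, 2) array of absolute x, y points)
    into stroke-3 rows of (dx, dy, pen_lifted)."""
    points = np.concatenate(strokes).astype(np.float32)
    # The pen flag is 1 on the last point of every stroke, 0 elsewhere.
    pen = np.zeros(len(points), dtype=np.float32)
    pen[np.cumsum([len(s) for s in strokes]) - 1] = 1
    # Offsets relative to the previous point; the first point is taken
    # relative to the origin (an assumption of this sketch).
    deltas = np.diff(points, axis=0, prepend=np.zeros((1, 2), dtype=np.float32))
    return np.column_stack([deltas, pen]).astype(np.int16)


def pack_npz(path, train, valid, test):
    """Save variable-length sketches as object arrays under the
    'train', 'valid' and 'test' keys."""
    def as_object_array(sketches):
        arr = np.empty(len(sketches), dtype=object)
        for i, sketch in enumerate(sketches):
            arr[i] = sketch
        return arr

    np.savez_compressed(
        path,
        train=as_object_array(train),
        valid=as_object_array(valid),
        test=as_object_array(test),
    )


# Toy example: one two-stroke sketch reused for every split.
toy = to_stroke3([np.array([[0, 0], [10, 0], [10, 10]]),
                  np.array([[0, 10], [0, 0]])])
toy_path = os.path.join(tempfile.mkdtemp(), "toy.npz")
pack_npz(toy_path, [toy, toy], [toy], [toy])

# Object arrays are pickled, so loading needs allow_pickle=True.
loaded = np.load(toy_path, allow_pickle=True, encoding="latin1")
print(loaded["train"][0])
```

Object arrays are used because sketches have different numbers of points, so they cannot be stacked into one rectangular array.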
Installation
The requirements of this repo can be found in requirements.txt.
pip install -r requirements.txt
Train and Inference
Hyperparameters
Here is the full list of options for the model:
lr, # learning rate.
log_dir, # path for saving logs.
dropout, # dropout rate.
use_fp16, # whether to use mixed-precision training.
ema_rate, # comma-separated list of EMA rates.
category, # list of category names to train on.
data_dir, # path to the datasets.
use_ddim, # whether to sample with DDIM (otherwise DDPM).
save_path, # path to save the vector results.
pen_break, # empirical threshold that determines stroke breaks.
image_size, # the max numbers of datasets.
model_path, # path to save the trained model checkpoint.
class_cond, # whether to use guidance.
batch_size, # training batch size.
emb_channels, # number of embedding channels in the U-Net.
num_channels, # number of channels in the U-Net backbone.
out_channels, # number of output channels of the U-Net.
save_interval, # interval between model checkpoints.
noise_schedule, # noise schedule for the forward process (linear by default).
num_res_blocks, # number of residual blocks in the U-Net backbone.
diffusion_steps, # number of diffusion steps in the forward process.
schedule_sampler, # the sampler schedule.
fp16_scale_growth, # scale growth for mixed-precision training.
use_scale_shift_norm, # whether to use scale-shift normalization.
Example Usage:
python train.py --data_dir [/path/to/datasets] \
--lr 1e-4 \
--batch_size 4 \
--use_fp16 False \
--log_dir [/path/to/log] \
--diffusion_steps 100 \
--noise_schedule linear \
--image_size 96 \
--num_channels 96 \
--num_res_blocks 3
python sample.py --model_path [/path/to/save_models] \
--pen_break 0.1 \
--save_path [/path/to/save_results] \
--use_ddim True \
--log_dir [/path/to/save_log] \
--diffusion_steps 100 \
--noise_schedule linear \
--image_size 96 \
--num_channels 96 \
--num_res_blocks 3
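To give a feel for what `--noise_schedule linear` and `--diffusion_steps 100` control, here is a minimal NumPy sketch of a linear beta schedule and the closed-form forward (noising) step of a DDPM. The scaling by `1000 / diffusion_steps` follows the convention of OpenAI's improved-diffusion code, which the option names above resemble. That the repo uses the same convention is an assumption, and the function names are illustrative only.

```python
import numpy as np


def linear_beta_schedule(diffusion_steps):
    """Linear betas, rescaled so that different step counts reach a similar
    final noise level (improved-diffusion convention, assumed here)."""
    scale = 1000 / diffusion_steps
    return np.linspace(scale * 1e-4, scale * 0.02, diffusion_steps, dtype=np.float64)


def q_sample(x0, t, alphas_cumprod, rng):
    """Draw x_t ~ q(x_t | x_0) = N(sqrt(a_bar_t) * x_0, (1 - a_bar_t) * I)."""
    noise = rng.standard_normal(x0.shape)
    x_t = np.sqrt(alphas_cumprod[t]) * x0 + np.sqrt(1.0 - alphas_cumprod[t]) * noise
    return x_t, noise


betas = linear_beta_schedule(100)
alphas_cumprod = np.cumprod(1.0 - betas)

rng = np.random.default_rng(0)
x0 = rng.standard_normal((96, 3))  # stand-in for 96 stroke points (dx, dy, pen)
x_mid, _ = q_sample(x0, 50, alphas_cumprod, rng)
x_last, _ = q_sample(x0, 99, alphas_cumprod, rng)
print(betas[0], betas[-1], alphas_cumprod[-1])
```

At the last timestep the cumulative product of alphas is close to zero, so `x_last` is essentially pure Gaussian noise, which matches the "randomly scattered stroke-points" starting state described in the abstract.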
Visualization
Because the results are trained, inferred, and stored as relative coordinate offset vectors, visualizing them requires one extra step: set the path of the saved .npz file in SketchData(dataPath='./datasets_npz'), then run the following script. The resulting .jpg files will be saved in ./save_sketch.
python draw_sketch.py
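For readers who want to inspect the saved vectors directly, the following sketch shows how relative offsets can be turned back into absolute polylines before rendering. It assumes stroke-3 rows of (dx, dy, pen) where a pen value above a threshold ends the current stroke. The function name and the `pen_threshold` parameter are hypothetical and are not taken from draw_sketch.py.

```python
import numpy as np


def stroke3_to_polylines(sketch, pen_threshold=0.5):
    """Turn rows of (dx, dy, pen) into a list of absolute (n, 2) polylines."""
    sketch = np.asarray(sketch, dtype=np.float32)
    # A running sum of the offsets recovers absolute coordinates.
    xy = np.cumsum(sketch[:, :2], axis=0)
    # A stroke ends right after every row whose pen value exceeds the threshold.
    breaks = np.where(sketch[:, 2] > pen_threshold)[0] + 1
    return [segment for segment in np.split(xy, breaks) if len(segment) > 0]


example = [[0, 0, 0], [10, 0, 0], [0, 10, 1], [-10, 0, 0], [0, -10, 1]]
polylines = stroke3_to_polylines(example)
for line in polylines:
    print(line.tolist())
```

Each returned polyline can then be drawn with any plotting library, one line per stroke.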
Evaluation
Please package the results to be evaluated in .npz format. The evaluation script reports FID, IS, Precision and Recall.
python evaluations/evaluator.py [/path/to/reference-data] [/path/to/generate-data]
We find that IS cannot accurately describe the distribution of vectorized data, so in the paper we use GS instead of IS to measure diversity. The Geometry Score (GS) can be computed directly on data in vector format; please go to the official website for instructions.
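As background on the FID metric, the sketch below computes the Fréchet distance between two sets of feature vectors, each modelled as a Gaussian. It is not the code in evaluations/evaluator.py: a real FID evaluation extracts the features with an Inception network, whereas this example uses random arrays as stand-ins, and `frechet_distance` is an illustrative name.

```python
import numpy as np


def frechet_distance(feats_ref, feats_gen):
    """||mu1 - mu2||^2 + Tr(S1 + S2 - 2 (S1 S2)^(1/2)) for Gaussian fits
    to two (num_samples, dim) feature arrays."""
    mu1, mu2 = feats_ref.mean(axis=0), feats_gen.mean(axis=0)
    s1 = np.cov(feats_ref, rowvar=False)
    s2 = np.cov(feats_gen, rowvar=False)
    # Tr((S1 S2)^(1/2)) equals the sum of the square roots of the eigenvalues
    # of S1 S2, which are real and non-negative up to round-off.
    eigvals = np.linalg.eigvals(s1 @ s2)
    trace_sqrt = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()
    diff = mu1 - mu2
    return float(diff @ diff + np.trace(s1) + np.trace(s2) - 2.0 * trace_sqrt)


rng = np.random.default_rng(0)
reference = rng.standard_normal((500, 8))
same = frechet_distance(reference, reference)
shifted = frechet_distance(reference, reference + 2.0)
print(same, shifted)
```

Shifting every feature by a constant leaves the covariance unchanged, so the distance reduces to the squared distance between the means, which is why lower values indicate closer distributions.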
Results
Simple | FID↓ | GS↓ | Prec↑ | Rec↑ |
---|---|---|---|---|
SketchPix2seq | 13.3 | 7.0 | 0.40 | 0.79 |
SketchHealer | 10.3 | 5.9 | 0.45 | 0.81 |
SketchRNN | 10.8 | 5.4 | 0.44 | 0.82 |
Diff-HW | 13.3 | 6.8 | 0.42 | 0.81 |
SketchODE | 11.5 | 9.4 | 0.48 | 0.74 |
Ours (full 1000 steps) | 6.9 | 3.4 | 0.52 | 0.88 |
Ours (r-Shortcut, S=30) | 7.4 | 3.9 | 0.47 | 0.87 |
Ours (Linear-DDIMs, S=30) | 11.9 | 6.4 | 0.38 | 0.81 |
Ours (Quadratic-DDIMs, S=30) | 12.3 | 6.6 | 0.41 | 0.79 |
Ours (Abs) | 20.7 | 12.1 | 0.18 | 0.55 |
Ours (Point-Shuffle) | 9.5 | 5.3 | 0.35 | 0.72 |
Ours (Stroke-Shuffle) | 8.2 | 3.8 | 0.36 | 0.74 |
Moderate | FID↓ | GS↓ | Prec↑ | Rec↑ |
---|---|---|---|---|
SketchPix2seq | 16.4 | 49.7 | 0.38 | 0.75 |
SketchHealer | 12.9 | 9.8 | 0.39 | 0.79 |
SketchRNN | 13.0 | 11.0 | 0.42 | 0.77 |
Diff-HW | 15.9 | 23.4 | 0.37 | 0.76 |
SketchODE | 18.8 | 29.6 | 0.31 | 0.66 |
Ours (full 1000 steps) | 8.4 | 4.7 | 0.45 | 0.87 |
Ours (r-Shortcut, S=30) | 8.9 | 5.2 | 0.44 | 0.85 |
Ours (Linear-DDIMs, S=30) | 13.3 | 8.8 | 0.36 | 0.78 |
Ours (Quadratic-DDIMs, S=30) | 13.8 | 8.7 | 0.35 | 0.76 |
Ours (Abs) | 23.4 | 64.6 | 0.13 | 0.48 |
Ours (Point-Shuffle) | 11.3 | 7.5 | 0.31 | 0.65 |
Ours (Stroke-Shuffle) | 9.6 | 7.4 | 0.34 | 0.66 |
Complex | FID↓ | GS↓ | Prec↑ | Rec↑ |
---|---|---|---|---|
SketchPix2seq | 18.0 | 73.3 | 0.36 | 0.72 |
SketchHealer | 25.9 | 93.2 | 0.29 | 0.63 |
SketchRNN | 21.4 | 97.6 | 0.35 | 0.72 |
Diff-HW | 18.3 | 64.4 | 0.23 | 0.64 |
SketchODE | 33.5 | 68.1 | 0.20 | 0.58 |
Ours (full 1000 steps) | 9.4 | 5.2 | 0.42 | 0.85 |
Ours (r-Shortcut, S=30) | 10.5 | 6.1 | 0.39 | 0.81 |
Ours (Linear-DDIMs, S=30) | 15.1 | 9.6 | 0.33 | 0.72 |
Ours (Quadratic-DDIMs, S=30) | 15.4 | 9.9 | 0.34 | 0.75 |
Ours (Abs) | 29.4 | 98.9 | 0.10 | 0.39 |
Ours (Point-Shuffle) | 12.4 | 8.1 | 0.20 | 0.61 |
Ours (Stroke-Shuffle) | 10.3 | 7.6 | 0.25 | 0.62 |
Only part of the results are listed here. For more detailed results, please see our paper and supplementary materials.
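The Linear-DDIMs and Quadratic-DDIMs rows sample with only S=30 of the 1000 timesteps. Assuming those labels refer to the linear and quadratic sub-sequence selection commonly used with DDIM, the sketch below shows one way such timestep subsets can be picked. The function is illustrative, is not taken from this repo, and does not reproduce the recognizability-based r-Shortcut.

```python
import numpy as np


def ddim_timesteps(total_steps, num_samples, spacing="linear"):
    """Pick an increasing subset of timesteps in [0, total_steps - 1]."""
    if spacing == "linear":
        steps = np.linspace(0, total_steps - 1, num_samples)
    elif spacing == "quadratic":
        # Evenly spaced in sqrt(t), so the steps cluster near t = 0.
        steps = np.linspace(0, np.sqrt(total_steps - 1), num_samples) ** 2
    else:
        raise ValueError(f"unknown spacing: {spacing}")
    return np.unique(np.round(steps).astype(int))


linear_steps = ddim_timesteps(1000, 30, "linear")
quadratic_steps = ddim_timesteps(1000, 30, "quadratic")
print(linear_steps[:5], quadratic_steps[:5])
```

Sampling then visits the chosen timesteps in reverse order; the quadratic variant spends more of its budget at low noise levels.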
License
This project is released under the MIT License.
Citation
If you find this repository useful for your research, please cite the following.
@inproceedings{wangsketchknitter,
title={SketchKnitter: Vectorized Sketch Generation with Diffusion Models},
author={Wang, Qiang and Deng, Haoge and Qi, Yonggang and Li, Da and Song, Yi-Zhe},
booktitle={The Eleventh International Conference on Learning Representations}
}
Acknowledgements
Contact
If you have any questions about the code, please contact wanqqiang@bupt.edu.cn