lilujunai / DCP-GAN

[CVPR 2024] Diversity-aware Channel Pruning for StyleGAN Compression



Usage

To test our code, please follow these steps:

  1. Setup
  2. Pruning
  3. Train
  4. Inference
  5. Evaluation

Pre-trained weights

If you want to test our lightweight model only, please download the pre-trained model from this link and proceed to the Inference step.

Note that some of the provided checkpoint files are larger than the teacher model's. This is because each checkpoint stores the generator, generator_ema, and discriminator together in a single .pt file.
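To see what a checkpoint holds, or to keep only the EMA generator for inference, a small helper like the one below can be useful. It is a sketch that assumes the .pt file is a plain dict whose network entries are state dicts under keys such as `g`, `g_ema`, and `d` (the layout used by rosinality/stylegan2-pytorch). The key names in these particular files are an assumption, and `count_params`, `summarize_checkpoint`, `keep_only`, and `shrink_checkpoint` are hypothetical helpers, not part of this repository.

```python
def count_params(state_dict):
    """Sum the element counts of every tensor-like value in a state dict."""
    return sum(v.numel() for v in state_dict.values() if hasattr(v, "numel"))


def summarize_checkpoint(ckpt):
    """Map each dict-valued top-level entry (e.g. a network's state dict) to its parameter count."""
    return {name: count_params(sd) for name, sd in ckpt.items() if isinstance(sd, dict)}


def keep_only(ckpt, key="g_ema"):
    """Return a smaller checkpoint that holds a single entry (the key name is an assumption)."""
    if key not in ckpt:
        raise KeyError(f"{key!r} not found; available keys: {sorted(ckpt)}")
    return {key: ckpt[key]}


def shrink_checkpoint(src, dst, key="g_ema"):
    """Print per-entry sizes of the checkpoint at src and save only `key` to dst."""
    import torch  # imported lazily so the pure helpers above work without torch

    ckpt = torch.load(src, map_location="cpu")
    for name, n in summarize_checkpoint(ckpt).items():
        print(f"{name}: {n:,} parameters")
    torch.save(keep_only(ckpt, key), dst)
```

For example, `shrink_checkpoint("./Model/student_model/dcp_ffhq256.pt", "./dcp_ffhq256_g_ema.pt")`. Whether generate.py accepts such a stripped file depends on which keys it reads, so keep the original as well. Recent PyTorch releases default `torch.load` to `weights_only=True`, which may reject checkpoints that also pickle Python objects such as training arguments; pass `weights_only=False` only for files you trust.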

Setup

Our codebase is built on xuguodong03/StyleKD and rosinality/stylegan2-pytorch, and shares their architecture and dependencies.

I tested the code in the pytorch/pytorch:1.8.1-cuda11.1-cudnn8-devel Docker image. Inside the container, run:

pip install -r requirements.txt
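After installing, a quick check of the PyTorch and CUDA setup can save time, since the codebase inherits custom CUDA kernels from stylegan2-pytorch that are compiled at run time. The sketch below only compares the installed version with the tested image; `parse_version`, `report_torch`, and the choice to compare major.minor versions are our own assumptions, not repository code.

```python
import re

# Version shipped in the pytorch/pytorch:1.8.1-cuda11.1-cudnn8-devel image used for testing.
TESTED_TORCH = (1, 8, 1)


def parse_version(version):
    """Turn a version string such as '1.8.1+cu111' into a tuple like (1, 8, 1)."""
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", str(version))
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part or 0) for part in match.groups())


def report_torch():
    """Print the installed torch/CUDA setup and warn if major.minor differs from the tested one."""
    import torch  # imported lazily so parse_version works without torch installed

    print(f"torch {torch.__version__} (CUDA build {torch.version.cuda}), "
          f"GPU available: {torch.cuda.is_available()}")
    if parse_version(torch.__version__)[:2] != TESTED_TORCH[:2]:
        print("Warning: untested torch version; the custom CUDA ops compiled "
              "at run time may fail to build.")
```

Call `report_torch()` from a Python shell inside the container.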

Pruning

Run:

python prune.py
  • Pruning ratio ($p_r$) is set by the --remove_ratio parameter (default: 0.7).
  • Perturbation strength ($\alpha$) is set by the --edit_strength parameter (default: 5.0).
  • The number of perturbations per latent vector ($N$) is set by the --n_direction parameter (default: 10).
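As we read the paper, the idea behind these parameters is that channels whose outputs change strongly when the latent vector is perturbed contribute most to sample diversity, so they should be kept. The toy sketch below illustrates only that idea: a random `tanh` layer stands in for a generator layer, each latent is perturbed along `n_direction` random unit directions scaled by `edit_strength`, and the `remove_ratio` fraction of channels with the lowest mean absolute response is dropped. It is not the scoring in prune.py, which works on real StyleGAN feature maps and may aggregate differently; all function names are hypothetical.

```python
import numpy as np


def diversity_scores(layer, latents, n_direction=10, edit_strength=5.0, rng=None):
    """Toy per-channel sensitivity score (illustration, not prune.py).

    For each latent w, draw n_direction random unit directions d and average
    |layer(w + edit_strength * d) - layer(w)| per output channel.
    """
    rng = np.random.default_rng(0) if rng is None else rng
    scores = np.zeros(layer(latents[:1]).shape[1])
    for w in latents:
        base = layer(w[None, :])  # shape (1, channels)
        dirs = rng.standard_normal((n_direction, w.shape[0]))
        dirs /= np.linalg.norm(dirs, axis=1, keepdims=True)  # unit directions
        perturbed = layer(w[None, :] + edit_strength * dirs)  # (n_direction, channels)
        scores += np.abs(perturbed - base).mean(axis=0)
    return scores / len(latents)


def channels_to_keep(scores, remove_ratio=0.7):
    """Sorted indices of the highest-scoring channels after removing remove_ratio of them."""
    n_keep = max(1, int(round(len(scores) * (1.0 - remove_ratio))))
    return np.sort(np.argsort(scores)[::-1][:n_keep])


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    weight = rng.standard_normal((10, 16))
    weight[:5] *= 0.01  # channels 0-4 barely respond to the latent

    def toy_layer(x):
        return np.tanh(x @ weight.T)

    latents = rng.standard_normal((100, 16))
    print(channels_to_keep(diversity_scores(toy_layer, latents), remove_ratio=0.5))
```

In the demo, channels 0 to 4 are scaled down so they hardly react to latent changes, so with `remove_ratio=0.5` the sketch keeps channels 5 to 9.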

We use a total of 5,000 samples to compute the channel scores, but results are similar with any number above 1,000.
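To check the sample-count claim on your own model, one option is to compute channel scores from about 1,000 and from 5,000 latents and compare which channels survive pruning. The sketch assumes you already have the two per-channel score arrays (exporting them from prune.py is not covered here); `kept_set` and `keep_overlap` are hypothetical helpers, and Jaccard overlap is just one reasonable similarity measure.

```python
import numpy as np


def kept_set(scores, remove_ratio=0.7):
    """Indices of the channels that survive pruning at the given ratio."""
    n_keep = max(1, int(round(len(scores) * (1.0 - remove_ratio))))
    return set(np.argsort(scores)[::-1][:n_keep].tolist())


def keep_overlap(scores_a, scores_b, remove_ratio=0.7):
    """Jaccard overlap of the kept channels under two score estimates (1.0 = identical)."""
    a = kept_set(scores_a, remove_ratio)
    b = kept_set(scores_b, remove_ratio)
    return len(a & b) / len(a | b)
```

A value close to 1.0 for `keep_overlap(scores_1k, scores_5k)` would mean the smaller sample selects nearly the same channels.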

Train

The command below assumes 4 GPUs.

CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --master_port=8001 \
    distributed_train.py --name batch16_run0 \
    --load 0 --load_style 1 --g_step 0 \
    --kd_l1_lambda 3 --kd_lpips_lambda 3 --kd_simi_lambda 30 \
    --batch 4 --worker 8 \
    --teacher_ckpt ./Model/teacher_model/256px_full_size.pt \
    --student_ckpt ./Model/pruned_model/dcp_0.7_256px_a5.0_n10_t1.00_model.pt \
    --path /dataset/ffhq \
    --train_mode ffhq
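The run name `batch16_run0` together with `--batch 4` on 4 GPUs suggests that `--batch` is the per-GPU batch size, giving a global batch of 16, as in the distributed training script of rosinality/stylegan2-pytorch. Assuming that holds, the sketch below rebuilds the same command for another GPU count while keeping the global batch at 16. `build_train_cmd` is a hypothetical helper; every flag and path is copied from the command above.

```python
import shlex


def build_train_cmd(n_gpus, global_batch=16, port=8001,
                    teacher="./Model/teacher_model/256px_full_size.pt",
                    student="./Model/pruned_model/dcp_0.7_256px_a5.0_n10_t1.00_model.pt",
                    data="/dataset/ffhq"):
    """Return (env, argv) for distributed_train.py, assuming --batch is per GPU."""
    if global_batch % n_gpus:
        raise ValueError("global_batch must be divisible by n_gpus")
    per_gpu = global_batch // n_gpus
    env = {"CUDA_VISIBLE_DEVICES": ",".join(str(i) for i in range(n_gpus))}
    argv = [
        "python", "-m", "torch.distributed.launch",
        f"--nproc_per_node={n_gpus}", f"--master_port={port}",
        "distributed_train.py", "--name", f"batch{global_batch}_run0",
        "--load", "0", "--load_style", "1", "--g_step", "0",
        "--kd_l1_lambda", "3", "--kd_lpips_lambda", "3", "--kd_simi_lambda", "30",
        "--batch", str(per_gpu), "--worker", "8",
        "--teacher_ckpt", teacher, "--student_ckpt", student,
        "--path", data, "--train_mode", "ffhq",
    ]
    return env, argv


if __name__ == "__main__":
    env, argv = build_train_cmd(n_gpus=2)
    print(" ".join(f"{k}={v}" for k, v in env.items()), shlex.join(argv))
```

To launch it, `subprocess.run(argv, env={**os.environ, **env}, check=True)` is one option. Fewer GPUs means a larger per-GPU batch and more memory per device, and we have not verified that results match the 4-GPU setting.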

Inference

Download the weights from this link (or train your own model), then run:

python generate.py --ckpt ./Model/student_model/dcp_ffhq256.pt

Evaluation

For precision and recall (P&R), I used the code from the NVlabs/stylegan2-ada-pytorch repository; refer to that repository for details.
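For reference, the improved precision and recall metric in NVlabs/stylegan2-ada-pytorch (the `pr50k3` metric, using VGG16 features and k = 3) builds a k-nearest-neighbour ball around every real and every generated feature vector. Precision is the fraction of generated samples that fall inside some real ball, and recall is the fraction of real samples that fall inside some generated ball. The sketch below is a simplified, all-in-memory version on pre-extracted feature arrays, meant to show the definition rather than replace that repository's batched implementation; the function names are ours.

```python
import numpy as np


def pairwise_dist(a, b):
    """Euclidean distances between the rows of a (n, d) and b (m, d)."""
    sq = (a ** 2).sum(1)[:, None] + (b ** 2).sum(1)[None, :] - 2.0 * a @ b.T
    return np.sqrt(np.maximum(sq, 0.0))  # clamp tiny negatives from rounding


def knn_radii(feats, k=3):
    """Distance from each point to its k-th nearest neighbour (column 0 is the point itself)."""
    return np.sort(pairwise_dist(feats, feats), axis=1)[:, k]


def coverage(query, ref, k=3):
    """Fraction of query points inside at least one k-NN ball around the ref points."""
    radii = knn_radii(ref, k)
    return float((pairwise_dist(query, ref) <= radii[None, :]).any(axis=1).mean())


def precision_recall(real, fake, k=3):
    """Return (precision, recall) for real and generated feature arrays."""
    return coverage(fake, real, k), coverage(real, fake, k)
```

The pairwise distance matrices need O(n²) memory, so this version only suits small feature sets.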

L1 distance between teacher and student

python get_recon.py \
    --t ./Model/teacher_model/256px_full_size.pt \
    --s ./Model/student_model/dcp_ffhq256.pt
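We have not inspected how get_recon.py computes this number. A common definition, sketched below under that assumption, feeds the same latent codes to teacher and student and averages the absolute per-pixel difference of their outputs. `teacher` and `student` are stand-in callables returning NumPy image arrays, and `mean_l1` is a hypothetical name. With real StyleGAN generators you would also fix the per-layer noise inputs so that only the networks differ.

```python
import numpy as np


def mean_l1(teacher, student, latents, batch_size=16):
    """Mean absolute per-pixel difference between teacher(z) and student(z) over all latents."""
    total, count = 0.0, 0
    for start in range(0, len(latents), batch_size):
        z = latents[start:start + batch_size]  # same latents for both networks
        t_img, s_img = teacher(z), student(z)
        total += np.abs(t_img - s_img).sum()
        count += t_img.size
    return total / count
```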

FID

We use pre-calculated features for FID evaluation, available via this link. Before evaluating, copy the features for your chosen dataset into the "./Evaluation" directory (the FFHQ256 features are already included in the repository).

python get_fid.py --ckpt ./Model/student_model/dcp_ffhq256.pt --data_type ffhq # ffhq dataset
python get_fid.py --ckpt ./Model/student_model/dcp_horse.pt --data_type horse # lsun horse dataset
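FID statistics are usually stored as the mean and covariance of Inception features for each dataset; we have not checked the exact file format used in ./Evaluation. Given two such pairs of statistics, FID is the Fréchet distance between the corresponding Gaussians, $\lVert \mu_1 - \mu_2 \rVert^2 + \operatorname{Tr}(\Sigma_1 + \Sigma_2 - 2(\Sigma_1 \Sigma_2)^{1/2})$. The sketch below implements that standard formula with SciPy; it is not taken from get_fid.py, and `frechet_distance` is our own name.

```python
import numpy as np
from scipy import linalg


def frechet_distance(mu1, sigma1, mu2, sigma2):
    """||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrtm(sigma1 @ sigma2))."""
    diff = mu1 - mu2
    covmean = np.real(linalg.sqrtm(sigma1 @ sigma2))  # drop tiny imaginary parts from rounding
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * np.trace(covmean))
```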

Citation

I will add the BibTeX entry soon.
