mit-han-lab / gan-compression

[CVPR 2020] GAN Compression: Efficient Architectures for Interactive Conditional GANs

Distill Problem

saijo0404 opened this issue

I tried to train a pix2pix model on the edges2shoes-r dataset using train_full.sh.

```bash
#!/usr/bin/env bash
python distill.py --dataroot database/edges2shoes-r \
  --distiller resnet \
  --log_dir logs/pix2pix/edges2shoes-r/distill \
  --batch_size 4 \
  --restore_teacher_G_path logs/pix2pix/edges2shoes-r/train/checkpoints/latest_net_G.pth \
  --restore_pretrained_G_path logs/pix2pix/edges2shoes-r/train/checkpoints/latest_net_G.pth \
  --pretrained_netG resnet_9blocks \
  --teacher_netG resnet_9blocks \
  --student_netG resnet_9blocks \
  --restore_D_path logs/pix2pix/edges2shoes-r/train/checkpoints/latest_net_D.pth \
  --real_stat_path real_stat/edges2shoes-r_B.npz \
  --meta_path datasets/metas/edges2shoes-r/train1.meta
```

After training, I ran the script above, but it fails with an AssertionError in weight_transfer.py, line 14, in transfer_Conv2d:

```
assert isinstance(m1, nn.Conv2d) and isinstance(m2, (nn.Conv2d, SuperConv2d))
```

How can I solve this problem?

Could you provide some more information? What are the types of your m1 and m2?

I printed the types of m1 and m2. The output looks like this:

```
distiller [ResnetDistiller] was created
Load network at logs/pix2pix/edges2shoes-r/train/checkpoints/latest_net_G.pth
isinstance(netA, nn.DataParallel):  False
isinstance(netB, nn.DataParallel):  False
isinstance(m1, nn.Conv2d) and isinstance(m2, (nn.Conv2d, SuperConv2d)):  True
m1 type:  <class 'torch.nn.modules.conv.Conv2d'>
m2 type:  <class 'torch.nn.modules.conv.Conv2d'>
isinstance(m1, nn.Conv2d) and isinstance(m2, (nn.Conv2d, SuperConv2d)):  True
m1 type:  <class 'torch.nn.modules.conv.Conv2d'>
m2 type:  <class 'torch.nn.modules.conv.Conv2d'>
isinstance(m1, nn.Conv2d) and isinstance(m2, (nn.Conv2d, SuperConv2d)):  True
m1 type:  <class 'torch.nn.modules.conv.Conv2d'>
m2 type:  <class 'torch.nn.modules.conv.Conv2d'>
isinstance(m1, nn.Conv2d) and isinstance(m2, (nn.Conv2d, SuperConv2d)):  False
m1 type:  <class 'models.modules.resnet_architecture.resnet_generator.ResnetBlock'>
m2 type:  <class 'models.modules.resnet_architecture.resnet_generator.ResnetBlock'>
```
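The log pinpoints where the check breaks. The first three pairs are plain convolutions and pass, but the fourth pair is a ResnetBlock, which is not a Conv2d, so the assert in transfer_Conv2d fails. The sketch below reproduces that check in plain Python so it runs without PyTorch. Conv2d, SuperConv2d and ResnetBlock are hypothetical stand-ins for the real classes, and first_bad_pair is an illustrative helper, not part of the repository.

```python
# Hypothetical plain-Python stand-ins for nn.Conv2d, SuperConv2d and
# ResnetBlock, so this sketch runs without PyTorch or the repository.
class Conv2d:
    pass


class SuperConv2d:
    pass


class ResnetBlock:
    """Stands in for a composite block; it is not a Conv2d."""
    pass


def conv_pair_ok(m1, m2):
    # Mirrors the condition asserted in transfer_Conv2d.
    return isinstance(m1, Conv2d) and isinstance(m2, (Conv2d, SuperConv2d))


def first_bad_pair(modules_a, modules_b):
    """Walk two module lists in lockstep and return (index, m1, m2) for the
    first pair that would fail the assert, or None if every pair passes."""
    for i, (m1, m2) in enumerate(zip(modules_a, modules_b)):
        if not conv_pair_ok(m1, m2):
            return i, m1, m2
    return None


# Same shape as the log: three conv pairs, then a ResnetBlock pair.
teacher = [Conv2d(), Conv2d(), Conv2d(), ResnetBlock()]
student = [Conv2d(), Conv2d(), Conv2d(), ResnetBlock()]
idx, m1, m2 = first_bad_pair(teacher, student)
print(idx, type(m1).__name__, type(m2).__name__)  # 3 ResnetBlock ResnetBlock
```

Scanning for the first failing pair, rather than printing every pair, quickly shows which module the transfer code handed to the conv-only handler. This thread does not show how weight_transfer.py builds the two lists.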

I see. This is a minor bug in weight_transfer.py caused by a typo. I've fixed it in this commit. Could you pull the latest commit and try again?

I'll close this issue. Let me know if you run into any further problems!
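A general way to keep a weight-transfer routine from passing a block like ResnetBlock to a conv-only handler is to dispatch on module type and recurse into composite blocks. The sketch below shows that pattern with hypothetical plain-Python stand-ins (Conv2d, ResnetBlock, transfer, transfer_conv). It is not the repository's weight_transfer.py, and it is not necessarily what the referenced commit changed, because the thread only says the cause was a typo. It also copies weights verbatim and ignores any channel-width differences between teacher and student.

```python
# Hypothetical plain-Python stand-ins (not PyTorch, not weight_transfer.py).
class Conv2d:
    def __init__(self, weight):
        self.weight = list(weight)


class ResnetBlock:
    """A composite block: it holds sub-modules but no weight of its own."""
    def __init__(self, *parts):
        self.parts = list(parts)


def transfer_conv(m1, m2):
    # Leaf handler: like transfer_Conv2d, it only accepts conv modules.
    assert isinstance(m1, Conv2d) and isinstance(m2, Conv2d)
    m2.weight = list(m1.weight)  # verbatim copy; no channel selection


def transfer(m1, m2):
    """Copy weights from m1 into m2, routing each pair by module type so a
    composite block is recursed into instead of reaching transfer_conv."""
    if isinstance(m1, Conv2d):
        transfer_conv(m1, m2)
    elif isinstance(m1, ResnetBlock):
        assert isinstance(m2, ResnetBlock) and len(m1.parts) == len(m2.parts)
        for p1, p2 in zip(m1.parts, m2.parts):
            transfer(p1, p2)
    else:
        raise TypeError(f"unsupported module type: {type(m1).__name__}")


teacher = ResnetBlock(Conv2d([1.0, 2.0]), Conv2d([3.0]))
student = ResnetBlock(Conv2d([0.0, 0.0]), Conv2d([0.0]))
transfer(teacher, student)
print([p.weight for p in student.parts])  # [[1.0, 2.0], [3.0]]
```

With this design, only leaf layers ever reach the strict type assertion, and any module type the dispatcher does not recognize raises a TypeError with the offending class name.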