Distill Problem
saijo0404 opened this issue · comments
I tried to train a pix2pix model on the edges2shoes-r dataset using train_full.sh.
```bash
#!/usr/bin/env bash
python distill.py --dataroot database/edges2shoes-r \
  --distiller resnet \
  --log_dir logs/pix2pix/edges2shoes-r/distill \
  --batch_size 4 \
  --restore_teacher_G_path logs/pix2pix/edges2shoes-r/train/checkpoints/latest_net_G.pth \
  --restore_pretrained_G_path logs/pix2pix/edges2shoes-r/train/checkpoints/latest_net_G.pth \
  --pretrained_netG resnet_9blocks \
  --teacher_netG resnet_9blocks \
  --student_netG resnet_9blocks \
  --restore_D_path logs/pix2pix/edges2shoes-r/train/checkpoints/latest_net_D.pth \
  --real_stat_path real_stat/edges2shoes-r_B.npz \
  --meta_path datasets/metas/edges2shoes-r/train1.meta
```
After training, I ran this script, but I got an AssertionError from `weight_transfer.py`, line 14, in `transfer_Conv2d`:

```python
assert isinstance(m1, nn.Conv2d) and isinstance(m2, (nn.Conv2d, SuperConv2d))
```

How can I solve this problem?
Could you provide some more information? What are the types of your `m1` and `m2`?
I tried printing the types of `m1` and `m2`; the output looks like this:

```
distiller [ResnetDistiller] was created
Load network at logs/pix2pix/edges2shoes-r/train/checkpoints/latest_net_G.pth
isinstance(netA, nn.DataParallel): False
isinstance(netB, nn.DataParallel): False
isinstance(m1, nn.Conv2d) and isinstance(m2, (nn.Conv2d, SuperConv2d)): True
m1 type: <class 'torch.nn.modules.conv.Conv2d'>
m2 type: <class 'torch.nn.modules.conv.Conv2d'>
isinstance(m1, nn.Conv2d) and isinstance(m2, (nn.Conv2d, SuperConv2d)): True
m1 type: <class 'torch.nn.modules.conv.Conv2d'>
m2 type: <class 'torch.nn.modules.conv.Conv2d'>
isinstance(m1, nn.Conv2d) and isinstance(m2, (nn.Conv2d, SuperConv2d)): True
m1 type: <class 'torch.nn.modules.conv.Conv2d'>
m2 type: <class 'torch.nn.modules.conv.Conv2d'>
isinstance(m1, nn.Conv2d) and isinstance(m2, (nn.Conv2d, SuperConv2d)): False
m1 type: <class 'models.modules.resnet_architecture.resnet_generator.ResnetBlock'>
m2 type: <class 'models.modules.resnet_architecture.resnet_generator.ResnetBlock'>
```
I see. This is a minor bug in `weight_transfer.py` caused by a typo. I've fixed it in this commit. Could you pull the latest commit and try again?
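For context, the transfer step is only supposed to assert `Conv2d` on leaf convolution layers; container modules such as `ResnetBlock` should be recursed into, not matched directly. Here is a minimal sketch of that intended behavior (hypothetical helper names, not the repo's actual `weight_transfer.py`; it assumes a channel-pruned student whose convs are slices of the teacher's):

```python
import torch
import torch.nn as nn

def transfer_conv2d(m1, m2):
    # Leaf case: copy the teacher conv's weights into the (possibly
    # narrower) student conv, slicing to the student's channel counts.
    assert isinstance(m1, nn.Conv2d) and isinstance(m2, nn.Conv2d)
    oc, ic = m2.weight.shape[:2]
    m2.weight.data.copy_(m1.weight.data[:oc, :ic])
    if m2.bias is not None:
        m2.bias.data.copy_(m1.bias.data[:oc])

def transfer_weights(net1, net2):
    # Walk the two networks in lockstep. Conv2d children are transferred
    # directly; anything else (Sequential, ResnetBlock, ...) is treated as
    # a container and recursed into, so the Conv2d assertion never fires
    # on a non-leaf module.
    for c1, c2 in zip(net1.children(), net2.children()):
        if isinstance(c1, nn.Conv2d):
            transfer_conv2d(c1, c2)
        else:
            transfer_weights(c1, c2)
```

The AssertionError above is what you get when a non-leaf module like `ResnetBlock` reaches the leaf-level check instead of being descended into.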
I will close this issue. Let me know if there are any further issues!