Difference of spectral norm conv2d in DBlock/DBlockOptimized
Leiwx52 opened this issue · comments
Hi~ thanks for your great contribution!

I have a question about the implementation of `DBlock` and `DBlockOptimized`. When I was trying to visualize the spectral norm of some layers, I noticed that you have two implementations of SN: the class `SNConv2d` in `torch_mimicry/modules/spectral_norm.py` and the function `SNConv2d` in `torch_mimicry/modules/layers.py`. If I read the code correctly, `DBlock` and `DBlockOptimized` use the class `SNConv2d` with `spectral_norm=True` by default. However, I felt a little confused when I ran the following code:
- Simply running the following lines:

  ```python
  from torch_mimicry.modules.resblocks import DBlockOptimized, DBlock, GBlock

  if __name__ == "__main__":
      block = DBlockOptimized(3, 128)
      print(block)
  ```
  the output was:

  ```
  DBlockOptimized(
    (c1): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (c2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (c_sc): Conv2d(3, 128, kernel_size=(1, 1), stride=(1, 1))
    (activation): ReLU(inplace=True)
  )
  ```
  I checked the `Conv2d` in that block, and it seems SN was applied via PyTorch's official spectral norm, since the `Conv2d` has the attribute `weight_orig`, which comes from the official PyTorch implementation.

- Copied the whole file `torch_mimicry/modules/resblocks.py` and added the following lines:

  ```python
  ###### ... contents of resblocks.py ... ######

  ## adding the following lines
  if __name__ == "__main__":
      block = DBlockOptimized(3, 128)
      print(block)
  ```
  the output was:

  ```
  DBlockOptimized(
    (c1): SNConv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (c2): SNConv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (c_sc): SNConv2d(3, 128, kernel_size=(1, 1), stride=(1, 1))
    (activation): ReLU(inplace=True)
  )
  ```
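For reference, the `weight_orig` attribute seen in the first case is created by PyTorch's built-in `torch.nn.utils.spectral_norm` wrapper. Here is a minimal sketch (my own code, not from torch_mimicry) of detecting the wrapper and computing the exact spectral norm of the unnormalized kernel:

```python
import torch
import torch.nn as nn

# Wrap a plain conv with PyTorch's official spectral-norm hook; this is
# what produces the `weight_orig` attribute.
conv = nn.utils.spectral_norm(nn.Conv2d(3, 128, 3, stride=1, padding=1))
print(hasattr(conv, "weight_orig"))  # True: official wrapper detected

# Exact spectral norm (largest singular value) of the flattened original
# kernel, as opposed to the wrapper's power-iteration estimate.
w = conv.weight_orig.detach().flatten(1)
sigma = torch.linalg.svdvals(w)[0]
print(float(sigma))
```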
In this second case, the SN is implemented by the author's own class `SNConv2d`.
Since the class `SNConv2d` offers a more convenient way to inspect the spectral norm (it has a related function to do this directly), I wonder what caused this difference. Also, when I call `SNGANDiscriminator32` and print its modules, the results show that every SN conv2d is implemented with PyTorch's official SN function/wrapper. If you've already figured this out, please tell me how to switch to the class `SNConv2d`.
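In case it helps, one workaround I can think of is to walk an already built model, strip the official wrapper, and rebuild each conv with a class-based implementation. This is only a rough sketch, not torch_mimicry's actual API; `MySNConv2d` is a hypothetical stand-in for the repo's class `SNConv2d`, whose real constructor may differ:

```python
import torch.nn as nn

class MySNConv2d(nn.Conv2d):
    """Hypothetical stand-in for the class-based SNConv2d."""
    pass

def swap_sn_convs(model):
    # Recursively replace convs wrapped by the official spectral_norm
    # (detected via the `weight_orig` attribute) with the class-based version.
    for name, child in model.named_children():
        if isinstance(child, nn.Conv2d) and hasattr(child, "weight_orig"):
            nn.utils.remove_spectral_norm(child)  # restore a plain `weight`
            new = MySNConv2d(child.in_channels, child.out_channels,
                             child.kernel_size, stride=child.stride,
                             padding=child.padding,
                             bias=child.bias is not None)
            new.load_state_dict(child.state_dict())  # keep trained weights
            setattr(model, name, new)
        else:
            swap_sn_convs(child)
    return model

model = nn.Sequential(nn.utils.spectral_norm(nn.Conv2d(3, 128, 3, padding=1)))
swap_sn_convs(model)
print(model)
```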
THANKS!