MixStyle layers for DigitsDG dataset
zwenyu opened this issue
Wenyu Zhang commented
Hi, for the CNN used in DigitsDG experiments in the paper, on which layers is MixStyle applied? Is the network also trained for 150 epochs? Thanks.
Kaiyang commented
The network architecture looks like this:
class Convolution(nn.Module):
    """A 3x3 convolution (stride 1, padding 1) followed by an in-place ReLU.

    With stride 1 and padding 1, a 3x3 kernel keeps the spatial size of the
    input unchanged; only the channel count goes from ``c_in`` to ``c_out``.
    """

    def __init__(self, c_in, c_out):
        super().__init__()
        self.conv = nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Convolve ``x``, then rectify the result."""
        features = self.conv(x)
        return self.relu(features)
class ConvNet(Backbone):
    """CNN + MixStyle.

    Four stages, each a ``Convolution`` (3x3 conv + ReLU) followed by 2x2 max
    pooling. MixStyle can optionally be applied right after the pooling of
    stage ``'conv1'``, ``'conv2'`` and/or ``'conv3'``. Any other name in
    ``mixstyle_layers`` (including ``'conv4'``) is ignored.

    Args:
        c_hidden: Number of output channels of every convolution stage.
        mixstyle_layers: Iterable of stage names after which MixStyle is
            applied. ``None`` (the default) means MixStyle is never applied.
    """

    def __init__(self, c_hidden=64, mixstyle_layers=None):
        super().__init__()
        # Use None instead of a mutable ``[]`` default, which would be shared
        # across every instance. Store a copy so that later mutation of the
        # caller's list cannot silently change this model's behaviour.
        mixstyle_layers = [] if mixstyle_layers is None else list(mixstyle_layers)

        self.conv1 = Convolution(3, c_hidden)
        self.conv2 = Convolution(c_hidden, c_hidden)
        self.conv3 = Convolution(c_hidden, c_hidden)
        self.conv4 = Convolution(c_hidden, c_hidden)

        # A single MixStyle module, shared by every insertion point.
        self.mixstyle = MixStyle(p=0.5, alpha=0.1)
        self.mixstyle_layers = mixstyle_layers
        print('Insert MixStyle after the following layers: {}'.format(mixstyle_layers))

        # A 32x32 input is halved by each of the four max-pools
        # (32 -> 16 -> 8 -> 4 -> 2), leaving a 2x2 map with c_hidden channels.
        self._out_features = 2**2 * c_hidden

    def _check_input(self, x):
        """Reject inputs whose spatial size is not 32x32.

        ``_out_features`` is only correct for 32x32 inputs. The check uses
        ``assert``, so it is skipped when Python runs with ``-O``.
        """
        H, W = x.shape[2:]
        assert H == 32 and W == 32, \
            'Input to network must be 32x32, ' \
            'but got {}x{}'.format(H, W)

    def forward(self, x):
        """Extract features from ``x``.

        Returns a tensor of shape ``(batch, self._out_features)``, formed by
        flattening the final pooled feature map of each sample.
        """
        self._check_input(x)

        x = self.conv1(x)
        x = F.max_pool2d(x, 2)
        if 'conv1' in self.mixstyle_layers:
            x = self.mixstyle(x)

        x = self.conv2(x)
        x = F.max_pool2d(x, 2)
        if 'conv2' in self.mixstyle_layers:
            x = self.mixstyle(x)

        x = self.conv3(x)
        x = F.max_pool2d(x, 2)
        if 'conv3' in self.mixstyle_layers:
            x = self.mixstyle(x)

        # No MixStyle insertion point after the last stage.
        x = self.conv4(x)
        x = F.max_pool2d(x, 2)

        return x.view(x.size(0), -1)
MixStyle is inserted only after `conv1`, i.e. the model is built with `mixstyle_layers=['conv1']`.
The training hyperparameters follow this config: https://github.com/KaiyangZhou/Dassl.pytorch/blob/master/configs/trainers/dg/vanilla/digits_dg.yaml
Wenyu Zhang commented
Thanks!