KaiyangZhou / mixstyle-release

Domain Generalization with MixStyle (ICLR'21)

MixStyle layers for DigitsDG dataset

zwenyu opened this issue

Hi, for the CNN used in DigitsDG experiments in the paper, on which layers is MixStyle applied? Is the network also trained for 150 epochs? Thanks.

The network architecture looks like this:

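import torch.nn as nn
import torch.nn.functional as F

# Backbone and MixStyle are defined elsewhere in the codebase (not shown here).


class Convolution(nn.Module):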

    def __init__(self, c_in, c_out):
        super().__init__()
        self.conv = nn.Conv2d(c_in, c_out, 3, stride=1, padding=1)
        self.relu = nn.ReLU(True)

    def forward(self, x):
        return self.relu(self.conv(x))


class ConvNet(Backbone):
    """CNN + MixStyle."""

    def __init__(self, c_hidden=64, mixstyle_layers=()):
        super().__init__()
        self.conv1 = Convolution(3, c_hidden)
        self.conv2 = Convolution(c_hidden, c_hidden)
        self.conv3 = Convolution(c_hidden, c_hidden)
        self.conv4 = Convolution(c_hidden, c_hidden)

        self.mixstyle = MixStyle(p=0.5, alpha=0.1)
        self.mixstyle_layers = mixstyle_layers
        print('Insert MixStyle after the following layers: {}'.format(mixstyle_layers))

        # Four 2x2 max-pools shrink a 32x32 input to 2x2 per channel.
        self._out_features = 2**2 * c_hidden

    def _check_input(self, x):
        H, W = x.shape[2:]
        assert H == 32 and W == 32, \
            'Input to network must be 32x32, ' \
            'but got {}x{}'.format(H, W)

    def forward(self, x):
        self._check_input(x)
        x = self.conv1(x)
        x = F.max_pool2d(x, 2)
        if 'conv1' in self.mixstyle_layers:
            x = self.mixstyle(x)
        x = self.conv2(x)
        x = F.max_pool2d(x, 2)
        if 'conv2' in self.mixstyle_layers:
            x = self.mixstyle(x)
        x = self.conv3(x)
        x = F.max_pool2d(x, 2)
        if 'conv3' in self.mixstyle_layers:
            x = self.mixstyle(x)
        x = self.conv4(x)
        x = F.max_pool2d(x, 2)
        return x.view(x.size(0), -1)

MixStyle is inserted only after conv1 (following its max-pooling step), i.e. mixstyle_layers=['conv1'].
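
To make the placement concrete, here is a small sketch that traces feature-map shapes through the forward pass posted above, assuming the default c_hidden=64 and the 32x32 input that _check_input enforces. The helper names (conv_out, pool_out, trace_shapes) are made up for illustration and are not part of the repository.

```python
def conv_out(size, kernel=3, stride=1, padding=1):
    """Spatial size after a convolution (standard floor formula)."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, kernel=2):
    """Spatial size after non-overlapping max pooling (stride == kernel)."""
    return size // kernel


def trace_shapes(size=32, c_hidden=64, mixstyle_layers=("conv1",)):
    """Return (stage, channels, height, width) tuples mirroring ConvNet.forward."""
    stages = []
    for name in ("conv1", "conv2", "conv3", "conv4"):
        size = pool_out(conv_out(size))
        stages.append((name + "+pool", c_hidden, size, size))
        # The posted forward() has no MixStyle branch after conv4.
        if name != "conv4" and name in mixstyle_layers:
            stages.append(("mixstyle", c_hidden, size, size))
    return stages


stages = trace_shapes()
for stage in stages:
    print(stage)

# Flattened feature size returned by forward().
_, channels, height, width = stages[-1]
out_features = channels * height * width
print(out_features)
```

Under these assumptions, MixStyle after conv1 operates on 64 x 16 x 16 feature maps, and the flattened output has 256 dimensions, which agrees with 2**2 * c_hidden in the code above.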

The training parameters follow this config: https://github.com/KaiyangZhou/Dassl.pytorch/blob/master/configs/trainers/dg/vanilla/digits_dg.yaml
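
For context on the MixStyle(p=0.5, alpha=0.1) module used in the ConvNet code above: as described in the MixStyle paper, the layer normalizes each instance's feature map with its own channel-wise mean and standard deviation, then re-scales it with statistics interpolated between that instance and another one from the batch. The interpolation weight is drawn from Beta(alpha, alpha), and the mixing is applied with probability p. Below is a dependency-free sketch of that idea on a single channel stored as a list of floats. It is not the repository's implementation, which works on batched tensors and is only active during training, and all function names here are hypothetical.

```python
import random
import statistics


def channel_stats(values, eps=1e-6):
    """Mean and (eps-stabilized) standard deviation of one channel."""
    mu = statistics.fmean(values)
    var = statistics.pvariance(values, mu)
    return mu, (var + eps) ** 0.5


def mix_style(x, y, lam, eps=1e-6):
    """Re-style channel x with statistics mixed between x and y.

    lam = 1 keeps x's own style; lam = 0 adopts y's style entirely.
    """
    mu_x, sig_x = channel_stats(x, eps)
    mu_y, sig_y = channel_stats(y, eps)
    mu_mix = lam * mu_x + (1 - lam) * mu_y
    sig_mix = lam * sig_x + (1 - lam) * sig_y
    # Normalize x with its own statistics, then apply the mixed ones.
    return [sig_mix * (v - mu_x) / sig_x + mu_mix for v in x]


def maybe_mix_style(x, y, p=0.5, alpha=0.1, rng=random):
    """Apply mix_style with probability p, drawing lam ~ Beta(alpha, alpha)."""
    if rng.random() > p:
        return list(x)
    lam = rng.betavariate(alpha, alpha)
    return mix_style(x, y, lam)


x = [1.0, 2.0, 3.0, 4.0]
y = [10.0, 20.0, 30.0, 40.0]
print(mix_style(x, y, lam=0.5))
print(maybe_mix_style(x, y, rng=random.Random(0)))
```

With a small alpha such as 0.1, Beta(alpha, alpha) concentrates near 0 and 1, so most draws either leave the style nearly unchanged or swap it almost entirely.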

Thanks!