About the Star Opt
tarv33 opened this issue · comments
JC commented
If the two convs are replaced with a single conv, it might be a little faster:
class Block(nn.Module):
    """Residual block that multiplies an activated branch by a linear branch.

    The two pointwise expansion convolutions (formerly ``f1`` and ``f2``) are
    fused into one convolution ``f`` with twice the output channels; its
    output is split in half along dim 1, so a single conv call replaces two.

    Args:
        dim: Channel count passed to the first and last layers; the residual
            addition in ``forward`` requires input and output to match.
        mlp_ratio: Expansion factor; each branch has ``mlp_ratio * dim``
            channels.
        drop_path: Rate handed to ``DropPath``; when not greater than 0 an
            ``nn.Identity`` is used instead.
    """

    def __init__(self, dim: int, mlp_ratio: int = 3, drop_path: float = 0.0):
        super().__init__()
        hidden = mlp_ratio * dim

        # Positional args are presumably (in, out, kernel, stride, padding);
        # groups=dim makes this a depthwise conv -- confirm against ConvBN.
        self.dwconv = ConvBN(dim, dim, 7, 1, (7 - 1) // 2, groups=dim, with_bn=True)

        # Fused replacement for the former pair
        #     f1 = ConvBN(dim, hidden, 1, with_bn=False)
        #     f2 = ConvBN(dim, hidden, 1, with_bn=False)
        # It must remain a dense (groups=1) convolution: with groups=2 each
        # half of the output would only see half of the input channels, which
        # is not what two independent full convolutions compute (and would
        # additionally require ``dim`` to be even).
        self.f = ConvBN(dim, 2 * hidden, 1, with_bn=False)

        self.g = ConvBN(hidden, dim, 1, with_bn=True)
        self.dwconv2 = ConvBN(dim, dim, 7, 1, (7 - 1) // 2, groups=dim, with_bn=False)
        self.act = nn.ReLU6()
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        # Size of each chunk taken from the fused conv output in forward().
        self.c = hidden

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block to ``x`` and add the result back onto ``x``."""
        residual = x
        x = self.dwconv(x)
        # First ``c`` channels play the role of f1(x), the remaining ``c``
        # channels the role of f2(x).
        x1, x2 = torch.split(self.f(x), self.c, dim=1)
        x = self.act(x1) * x2
        x = self.dwconv2(self.g(x))
        return residual + self.drop_path(x)