F.conv2d
Lucien66 opened this issue
Lucien commented
In my network, I adopt a dynamic structure built on F.conv2d. Do I need to define the counting rule for F.conv2d myself?
Ligeng Zhu commented
Can you share a code example?
Lucien commented
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Dynamic_conv2d(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=1,
                 dilation=1, groups=1, if_bias=True, K=5, init_weight=False):
        super(Dynamic_conv2d, self).__init__()
        assert in_planes % groups == 0
        self.in_planes = in_planes
        self.out_planes = out_planes
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.if_bias = if_bias
        self.K = K
        # K candidate kernels; the forward pass mixes them per sample.
        self.weight = nn.Parameter(
            torch.randn(K, out_planes, in_planes // groups, kernel_size, kernel_size),
            requires_grad=True)
        if self.if_bias:
            self.bias = nn.Parameter(torch.zeros(K, out_planes), requires_grad=True)
        else:
            self.bias = None
        if init_weight:
            self._initialize_weights()

    def _initialize_weights(self):
        for i in range(self.K):
            nn.init.kaiming_uniform_(self.weight[i])
            if self.if_bias:
                nn.init.constant_(self.bias[i], 0)

    def forward(self, inputs):
        x = inputs['x']                        # (batch, in_planes, H, W)
        softmax_attention = inputs['weights']  # (batch, K), one mixing weight per kernel
        batch_size, in_planes, height, width = x.size()
        # Fold the batch into the channel dimension so each sample can be
        # convolved with its own aggregated kernel via a grouped convolution.
        x = x.contiguous().view(1, -1, height, width)
        weight = self.weight.view(self.K, -1)
        # Per-sample weighted sum of the K kernels.
        aggregate_weight = torch.mm(softmax_attention, weight).view(
            -1, self.in_planes // self.groups, self.kernel_size, self.kernel_size)
        if self.bias is not None:
            aggregate_bias = torch.mm(softmax_attention, self.bias).view(-1)
            output = F.conv2d(x, weight=aggregate_weight, bias=aggregate_bias,
                              stride=self.stride, padding=self.padding,
                              dilation=self.dilation, groups=self.groups * batch_size)
        else:
            output = F.conv2d(x, weight=aggregate_weight, bias=None,
                              stride=self.stride, padding=self.padding,
                              dilation=self.dilation, groups=self.groups * batch_size)
        output = output.view(batch_size, self.out_planes, output.size(-2), output.size(-1))
        return output
```
`inputs['weights']` comes from the output of another network.
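For reference, a minimal call might look like this (the shapes and the `torch.softmax` stand-in for the attention network are assumptions, not my actual setup):

```python
conv = Dynamic_conv2d(in_planes=16, out_planes=32, kernel_size=3, K=5)
x = torch.randn(4, 16, 28, 28)
# Stand-in for the attention branch: per-sample mixing weights over K kernels.
attention = torch.softmax(torch.randn(4, 5), dim=1)  # (batch, K)
out = conv({'x': x, 'weights': attention})
print(out.shape)  # torch.Size([4, 32, 28, 28])
```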
Ligeng Zhu commented
THOP currently does not register counting rules for functionals such as F.conv2d; it only supports defining rules for nn.Module subclasses.
You may want to check the matmul example in thop/vision/basic_hooks.py (line 140 at commit d6d8ec0) and its registration here:
https://github.com/Lyken17/pytorch-OpCounter/blob/master/thop/profile.py#L53
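So for your case, you can wrap the F.conv2d call in an nn.Module (as you already do with Dynamic_conv2d) and pass a custom rule through `custom_ops`. A rough sketch; the MAC formula below is an assumption (it counts only the grouped convolution and ignores the kernel-aggregation matmul), so adapt it as needed:

```python
import torch
from thop import profile

def count_dynamic_conv2d(m, x, y):
    # y: (batch, out_planes, H_out, W_out); MACs per output element =
    # kernel area times input channels per group.
    kernel_macs = m.kernel_size * m.kernel_size * (m.in_planes // m.groups)
    total_ops = y.numel() * kernel_macs
    m.total_ops += torch.DoubleTensor([int(total_ops)])

model = Dynamic_conv2d(in_planes=16, out_planes=32, kernel_size=3, K=5)
dummy = {'x': torch.randn(1, 16, 28, 28),
         'weights': torch.softmax(torch.randn(1, 5), dim=1)}
macs, params = profile(model, inputs=(dummy,),
                       custom_ops={Dynamic_conv2d: count_dynamic_conv2d})
```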