facebookresearch / ConvNeXt

Code release for ConvNeXt model

Large Variance in feature maps

nightsnack opened this issue · comments

Hi, I find that the variance of each stage's feature maps grows quickly as the depth increases. Although exploding variance has been a long-standing problem since ResNet's residual connections, the variance explosion in ConvNeXt seems quite serious. See the picture below. I printed the variance of x in each stage's forward function and got this.
image

Any idea why? It could soon exceed fp16's upper limit as the model gets deeper.
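
For reference, this is roughly how such per-stage variance statistics can be collected with forward hooks. The `model.stages` attribute name below follows timm's ConvNeXt implementation and is an assumption, not this repo's exact code, but the idea is the same.

```python
import torch
from timm import create_model  # assumes timm's ConvNeXt port; the repo's own model exposes stages similarly

model = create_model("convnext_tiny", pretrained=False).eval()

def report_variance(name):
    def hook(module, inputs, output):
        # variance of the stage's output feature map
        print(f"{name}: var = {output.var().item():.2f}")
    return hook

# `model.stages` is the attribute name in timm's implementation (an assumption here)
for i, stage in enumerate(model.stages):
    stage.register_forward_hook(report_variance(f"stage {i}"))

with torch.no_grad():
    model(torch.randn(1, 3, 224, 224))
```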

I notice you use trunc_normal_ to initialize the Conv2d modules here (line 104), which is seldom seen in CNN initialization. Truncated normal is common for initializing linear layers in Transformers, but in traditional CNNs people usually use Kaiming normal to stabilize the output.

Is there any connection between this trunc_normal initialization and the exploding variance above?
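
For comparison, here is a minimal sketch of the two initialization schemes being discussed: the truncated-normal version mirrors the style of the repo's `_init_weights`, and the Kaiming version is the conventional CNN alternative. Treat it as an illustration, not the authors' exact code.

```python
import torch.nn as nn
from timm.models.layers import trunc_normal_

def init_trunc_normal(m):
    # ConvNeXt-style init: truncated normal for both Conv2d and Linear weights
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

def init_kaiming(m):
    # conventional CNN init: Kaiming (He) normal for convolution weights
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

# usage: model.apply(init_kaiming)  # to test the alternative initialization
```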

Thanks for sharing your findings. Very interesting that the variance explodes with depth. Can you share your exact formula for calculating the variance?

We tried using PyTorch's default initialization for Conv2d layers but observed nearly no difference in training curves, so we are not sure whether the two are related.

exact formula for calculating the variance?

Here it is
image

Yes, Kaiming init has no influence on the exploding variance problem, but I did find something related. If you change the pre-norm in the downsampling layers to post-norm (like Swin V2 did), the variance drops to around 10. I'm not sure whether this post-norm would cause any accuracy degradation.
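
To make the pre-norm vs. post-norm change concrete, here is a rough sketch of the two orderings for a downsampling layer, using a channels-first LayerNorm similar to the one in this repo. It's an illustration of the idea, not a drop-in patch.

```python
import torch.nn as nn
import torch.nn.functional as F

class LayerNorm2d(nn.LayerNorm):
    # LayerNorm over the channel dimension of (N, C, H, W) tensors,
    # analogous to the "channels_first" LayerNorm in the ConvNeXt repo
    def forward(self, x):
        x = x.permute(0, 2, 3, 1)  # N, H, W, C
        x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        return x.permute(0, 3, 1, 2)  # back to N, C, H, W

def downsample_prenorm(in_dim, out_dim):
    # original ConvNeXt ordering: normalize, then strided conv
    return nn.Sequential(LayerNorm2d(in_dim, eps=1e-6),
                         nn.Conv2d(in_dim, out_dim, kernel_size=2, stride=2))

def downsample_postnorm(in_dim, out_dim):
    # Swin V2-style ordering: strided conv first, then normalize the result,
    # which keeps the variance of the downsampled features bounded
    return nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=2, stride=2),
                         LayerNorm2d(out_dim, eps=1e-6))
```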

Is large variance necessarily bad? Can't it be interpreted as the model learning more diverse representations, hence more variance?


I'm using ConvNeXt-Tiny as the backbone in a multiple-instance-learning project. I visualize a bag's features (1000 instances) as a greyscale image (values normalized to 0~255); the shape of the image is (1000, 768). The grey parts have values close to zero, and only a few channels take large negative or positive values (black & white). This is very different from ResNet (image below).
image
image
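
In case anyone wants to reproduce this kind of visualization, here is a rough sketch of mapping a (1000, 768) bag-feature matrix to a 0~255 greyscale image. The min/max rescaling is one reasonable choice and may differ from what was used above.

```python
import numpy as np
from PIL import Image

def features_to_greyscale(features):
    """Map a (num_instances, feature_dim) array to a uint8 greyscale image."""
    f = np.asarray(features, dtype=np.float32)
    # linearly rescale so the minimum maps to 0 and the maximum to 255
    f = (f - f.min()) / (f.max() - f.min() + 1e-8) * 255.0
    return Image.fromarray(f.astype(np.uint8), mode="L")

# e.g. features from model.forward_features(bag) with shape (1000, 768)
# features_to_greyscale(features).save("bag_features.png")
```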


Does this mean that the final classification is based only on a few black & white channels, and the grey channels don't contribute at all?

Interesting. Are you sure this is properly normalized independently across architectures? ConvNeXt might have a different absolute scale in its feature maps, and sometimes the visualization might not tell the full story.


Interesting. Are you sure this is properly normalized independently across architectures? ConvNeXt might have a different absolute scale in its feature maps, and sometimes the visualization might not tell the full story.

The output values of self.forward_features(x) are between -10 and 10. I concatenate them vertically and normalize to 0~255 to save as an image. This is what the original outputs look like:
image