deep-learning-with-pytorch / dlwpt-code

Code for the book Deep Learning with PyTorch by Eli Stevens, Luca Antiga, and Thomas Viehmann.

Home page: https://www.manning.com/books/deep-learning-with-pytorch

Undermining batchnorm

OverLordGoldDragon opened this issue

The textbook takes a crack at batch norm in section 8.5.4, citing a paper that "eliminates" the need for it by training a 10,000-layer network without BN through proper initialization, and claiming that batch norm mainly helps convergence speed and isn't a regularizer.

I'm afraid this isn't true; no initialization scheme can replace normalization and confer all of its benefits, including better generalization. What unfortunately doesn't seem to make it into the mainstream can be found in the research papers cited below.

TL;DR: (1) a smoother loss landscape that guides optimization away from premature optima and flat regions; (2) decoupling of weight length and direction; (3) layer-to-layer scale invariance. The added stability permits dynamic learning rates and other advanced techniques that would otherwise break down.
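To make (3) concrete, here's a minimal PyTorch sketch (my own toy example, not from the book): a linear layer followed by batch norm is invariant to rescaling that layer's weights, because BN renormalizes whatever scale the layer outputs.

```python
# Toy demo of layer-to-layer scale invariance under BN (illustrative only).
import torch
import torch.nn as nn

torch.manual_seed(0)
lin = nn.Linear(16, 8)
bn = nn.BatchNorm1d(8)
bn.train()                        # normalize with the current batch statistics

x = torch.randn(32, 16)
out_ref = bn(lin(x))

with torch.no_grad():             # rescale the layer by a large positive constant
    lin.weight.mul_(10.0)
    lin.bias.mul_(10.0)
out_scaled = bn(lin(x))

# BN undoes the rescaling (up to its small eps term), so the outputs match.
print(torch.allclose(out_ref, out_scaled, atol=1e-3))  # True
```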

Thank you for your feedback. This was not intended to take a crack at BatchNorm. It is an important tool, which is why it is one of the three regularization techniques we mention in this very brief tour d'horizon.

Note that the citation is a footnote to section 8.5.3, specifically to the paragraph on the importance of initialization, and it serves to support the point that good initialization is important, not that batch norm is less important.

In 8.5.4, the comment on batch norm specifically interprets the results of the little experiment in figure 8.13.
If we want to make fine points about the text, I would invite you to carefully observe the word "statistical" in the paragraph:

The weight decay and dropout regularizations, which have a more rigorous statistical estimation interpretation as regularization than batch norm, have a much narrower gap between the two accuracies. Batch norm, which serves more as a convergence helper, lets us train the network to nearly 100% training accuracy, so we interpret the first two as regularization.

The benefits you name in the TL;DR seem to be more about the convergence of training than about statistical regularization. So while the characterization "convergence helper" might not reflect the full importance of batch norm (perhaps "almost indispensable convergence helper"?), it seems broadly consistent with batch norm being more about convergence.
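For concreteness, this is roughly where the three knobs enter a PyTorch setup (a sketch only, with made-up sizes, not the exact model behind figure 8.13):

```python
# Rough sketch: weight decay lives on the optimizer, dropout and batch norm
# are layers in the model. Sizes are illustrative (32x32 RGB inputs assumed).
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.BatchNorm2d(16),           # batch norm: normalizes activations over the batch
    nn.ReLU(),
    nn.Dropout(p=0.4),            # dropout: randomly zeroes activations during training
    nn.Flatten(),
    nn.Linear(16 * 32 * 32, 2),
)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2,
                            weight_decay=1e-4)  # weight decay: L2 penalty on the weights

out = model(torch.randn(8, 3, 32, 32))          # forward pass works: (8, 2) logits
```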

The paragraph was written before Daneshmand et al. came out, and I hadn't seen the paper; it is very interesting, thank you for linking it! At first glance I would read that, too, as being more about the properties of the function class within reach of the training procedure, but that could be my own bias.

I'm not sure I have seen better generalization claimed as a direct benefit of batch norm much; most people I talk to find the move from training to eval one of the more awkward things about BatchNorm. One thing I would find plausible is that BatchNorm, and the larger learning rates it allows, help us reach the interpolating regime (Belkin et al.'s "Reconciling modern machine learning practice and the classical bias-variance trade-off" and the group's other papers have several very interesting thoughts there), and that this is how it helps generalization, but this is speculation.
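To spell out the train/eval point, a minimal sketch (my own, not from the book):

```python
# In train() BatchNorm normalizes with the current batch statistics; in eval()
# it uses the running estimates, so the same input gives different outputs.
import torch
import torch.nn as nn

bn = nn.BatchNorm1d(4)
x = torch.randn(8, 4) * 3 + 1     # statistics far from the (0, 1) running-stat init

bn.train()
out_train = bn(x)                 # uses this batch's mean/var, updates running stats

bn.eval()
out_eval = bn(x)                  # uses running_mean/running_var instead

print(torch.allclose(out_train, out_eval))  # False: the two modes disagree
```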

All this is not to say that the section couldn't have been written more clearly. If you read it as trying to undermine BatchNorm rather than as highlighting the (in my view) somewhat different nature of the regularization we aim at with weight decay and dropout on the one hand and batch norm on the other, then it hasn't conveyed its message clearly enough.

@t-vi That's the thing: if an experienced user gets the wrong message, a beginner is likely to take this as "batchnorm is just speedy". My posting this is solely because of the line:

so we interpret the first two as regularization

Now, about BN and my TL;DR. Much research has shown that in deep NNs we have multiple equally good local optima (a whole continuum, in fact) that aren't far off the global optimum. But the global optimum isn't even our goal: the global optimum on the training set is likely to yield horrible test performance (unless the training set is enormous and closely resembles the test set). So local optima are basically always preferable. But not all local training optima yield equally good test performance. What the papers I've linked (and others) have shown is that by smoothing the loss landscape the way BN does, we reduce the likelihood of optimization landing in local optima that generalize poorly.

So the "statistical estimation" claim too is mistaken - to contrary, BN has a much stronger regularizing effect in this regard than weight decay does, as weight decay simply biases the optimizer away from some points on the loss surface, rather than changing the loss surface itself.

And again, consider cyclic learning rates that include large jumps in LR between epochs; I dare anyone to pull this off on a variety of deep networks without normalization. Yet CLR has been shown to improve generalization, so normalization also combines nicely with other regularizers.
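For reference, the kind of schedule I mean looks roughly like this in PyTorch (a toy sketch; the model and the LR range are made up, not taken from any of the cited papers):

```python
# Triangular cyclic LR swinging over an order of magnitude, on a model with BN.
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(32, 64), nn.BatchNorm1d(64), nn.ReLU(), nn.Linear(64, 10))
opt = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
sched = torch.optim.lr_scheduler.CyclicLR(opt, base_lr=1e-2, max_lr=1e-1, step_size_up=50)
loss_fn = nn.CrossEntropyLoss()

for step in range(200):                             # toy loop on random data
    x, y = torch.randn(16, 32), torch.randint(0, 10, (16,))
    opt.zero_grad()
    loss_fn(model(x), y).backward()
    opt.step()
    sched.step()                                    # LR cycles every batch; BN keeps training stable
```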

I guess the main thing we're disagreeing on is what counts as statistical regularization and what counts as building things into the Ansatz space / loss geometry, what is an inherent benefit and what is an enablement of benefits. This philosophical nuance probably won't be resolved here.

Theoretical justifications may differ, but the empirical evidence on BN's benefits to generalization is less ambiguous. Also, I slipped a bit on weight decay; I had AdamW's loss-decoupled approach in mind, but that's another topic. Feel free to also close the other thread.
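(For anyone reading along, the AdamW distinction in one toy snippet of my own: with the data gradient exactly zero, Adam's coupled weight_decay is pushed through the adaptive scaling and takes a step of roughly lr, while AdamW's decoupled decay shrinks the weight proportionally to lr * weight_decay.)

```python
# Toy illustration of coupled vs. decoupled weight decay (not from the book).
import torch

def one_step(opt_cls, **kw):
    w = torch.nn.Parameter(torch.tensor([1.0]))
    opt = opt_cls([w], lr=0.1, **kw)
    w.grad = torch.zeros_like(w)        # no data gradient; only the decay acts
    opt.step()
    return w.item()

print(one_step(torch.optim.Adam, weight_decay=1e-2))   # ~0.90: decay folded into the gradient,
                                                       # then rescaled by Adam to a full-size step
print(one_step(torch.optim.AdamW, weight_decay=1e-2))  # ~0.999: weight shrunk by lr * wd * w
```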