liufeng1990 / gradflow-check

Check gradient flow in PyTorch

Gradient flow check in PyTorch

Check that gradients flow properly through the network by recording the average gradient of each layer at every training iteration and plotting the results at the end. If the average gradients of the initial layers are zero, the network is probably too deep for the gradient to reach them (vanishing gradients).
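The recording step described above can be sketched roughly as follows. This is a minimal illustration, not the repo's actual code: the helper names `average_gradients` and `vanishing_layers`, the choice to skip bias parameters, and the `threshold` default are all assumptions made for the example.

```python
def average_gradients(named_parameters):
    """Return [(layer_name, mean absolute gradient)] for trainable weights.

    Expects an iterable of (name, parameter) pairs, such as the one returned
    by model.named_parameters(), after loss.backward() has been called.
    """
    stats = []
    for name, param in named_parameters:
        # Assumption: skip frozen parameters, biases, and parameters that
        # have not received a gradient yet.
        if not param.requires_grad or "bias" in name or param.grad is None:
            continue
        stats.append((name, float(param.grad.abs().mean())))
    return stats


def vanishing_layers(stats, threshold=1e-8):
    """Names of layers whose average gradient is at or below threshold.

    The threshold is a hypothetical default chosen for illustration.
    """
    return [name for name, avg in stats if avg <= threshold]
```

Calling `average_gradients(model.named_parameters())` right after `loss.backward()` gives one value per layer for the current iteration. If `vanishing_layers` keeps returning the first layers of the network, that matches the "bad gradient flow" case.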

Usage

loss = self.criterion(outputs, labels)
loss.backward()  # populate .grad on every parameter before plotting
plot_grad_flow(model.named_parameters())  # version 1
# OR
plot_grad_flow_v2(model.named_parameters())  # version 2

Result

Bad gradient flow:

[Figure: Bad gradient]

Good gradient flow:

[Figure: Good gradient]

Repo based on this PyTorch discuss post.
