pytorch / contrib

Implementations of ideas from recent papers

Where should bn_update() be called?

2033329616 opened this issue

If the model has batch normalization layers, where should I call bn_update()?

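for _ in range(100):
    opt.zero_grad()
    loss_fn(model(input), target).backward()
    opt.step()
opt.swap_swa_sgd()                  # swap the SWA-averaged weights into the model
opt.bn_update(train_loader, model)  # recompute BatchNorm statistics with the current weights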

Is this ordering correct?

Or is the following order the right one?

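for _ in range(100):
    opt.zero_grad()
    loss_fn(model(input), target).backward()
    opt.step()
opt.bn_update(train_loader, model)  # recompute BatchNorm statistics with the current weights
opt.swap_swa_sgd()                  # swap the SWA-averaged weights into the model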
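To compare the two orderings, it helps to recall that bn_update() re-estimates the BatchNorm statistics in one pass over the loader using whatever weights the model holds at that moment, while swap_swa_sgd() exchanges the model's weights with the SWA running averages. The sketch below is a hypothetical pure-Python toy, not torchcontrib code: `ToyModel`, `toy_bn_update` and `toy_swap` are illustrative names, a single scalar weight stands in for the parameters, and only a BatchNorm-style running mean is tracked.

```python
# Hypothetical toy stand-ins for opt.swap_swa_sgd() and opt.bn_update();
# NOT the torchcontrib implementation. One scalar weight replaces the model
# parameters, and only a BatchNorm-style running mean is tracked.


class ToyModel:
    def __init__(self, w):
        self.w = w                # weight currently loaded in the model
        self.running_mean = 0.0   # BatchNorm-style statistic of the activations

    def activations(self, batch):
        return [self.w * x for x in batch]


def toy_bn_update(loader, model):
    """Recompute the running mean in one pass over the data,
    using whatever weight the model holds right now."""
    acts = [a for batch in loader for a in model.activations(batch)]
    model.running_mean = sum(acts) / len(acts)


def toy_swap(model, swa_buffer):
    """Exchange the model's weight with the SWA-averaged weight."""
    model.w, swa_buffer["w"] = swa_buffer["w"], model.w


loader = [[1.0, 2.0], [3.0, 4.0]]        # toy data
flat = [x for batch in loader for x in batch]

# Order A: swap first, then recompute the statistics.
model_a, buf_a = ToyModel(w=2.0), {"w": 1.0}   # SGD weight 2.0, SWA average 1.0
toy_swap(model_a, buf_a)
toy_bn_update(loader, model_a)

# Order B: recompute the statistics first, then swap.
model_b, buf_b = ToyModel(w=2.0), {"w": 1.0}
toy_bn_update(loader, model_b)
toy_swap(model_b, buf_b)

for name, m in (("A", model_a), ("B", model_b)):
    # Mean activation the final weight actually produces on the data.
    true_mean = sum(m.activations(flat)) / len(flat)
    print(name, "weight:", m.w, "running_mean:", m.running_mean, "true mean:", true_mean)
# A weight: 1.0 running_mean: 2.5 true mean: 2.5
# B weight: 1.0 running_mean: 5.0 true mean: 2.5
```

In this toy setting, only the swap-then-update order leaves the running statistic describing the weight the model ends up with; in order B it still reflects the pre-swap SGD weight. How closely this mirrors the real optimizer should be checked against the torchcontrib documentation.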