Resuming from a snapshot taken in the middle of an epoch
msakai opened this issue · comments
Masahiro Sakai commented
It seems that, if training is resumed from a snapshot taken in the middle of an epoch, the remaining iterations of that epoch are not executed.
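In the following example, I expect `ITERATION_COMPLETED` to be printed 9 times after resumption (one epoch is 10 iterations — 100 samples with batch size 10 — and the snapshot was taken after iteration 1), but it is not printed at all.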
import torch
import torch.nn.functional as F
import ignite
import ignite.contrib.handlers
from chainer.training import extensions
import chainer_pytorch_migration as cpm
import chainer_pytorch_migration.ignite
# Toy problem used only to have something to iterate over: a 3-in/1-out
# linear model and 100 random samples.  Statement order is kept as-is so the
# global torch RNG is consumed in the same sequence (model init, then X, then y).
device = torch.device('cpu')
model = torch.nn.Linear(3, 1).to(device)

X = torch.randn(100, 3)
# `high` is exclusive in torch.randint, so every label here is 0.
y = torch.randint(high=1, size=(100,), dtype=torch.int64)

dataset = torch.utils.data.TensorDataset(X, y)
def create_trainer():
    """Build a fresh ignite trainer wired up with a Chainer snapshot extension.

    Each call creates a new Adam optimizer over the module-level ``model`` and a
    new supervised trainer, prints a line after every completed iteration, and
    attaches a Chainer ``snapshot`` extension triggered every iteration.

    Returns:
        tuple: ``(trainer, optimizer)``.
    """
    optimizer = torch.optim.Adam(model.parameters())
    trainer = ignite.engine.create_supervised_trainer(
        model, optimizer, F.nll_loss, device=device)

    def _announce_iteration(engine):
        # One line per completed iteration; counting these lines shows how
        # many iterations actually ran (e.g. after resuming).
        print("ITERATION_COMPLETED", flush=True)

    # Equivalent to decorating the handler with @trainer.on(...); registered
    # before the snapshot extension, as in the original ordering.
    trainer.add_event_handler(
        ignite.engine.Events.ITERATION_COMPLETED, _announce_iteration)

    # Chainer-style attributes set for the extension bridge: the model as
    # `optimizer.target` and "result" as the output directory for snapshots.
    # NOTE(review): presumably required by add_trainer_extension — confirm.
    optimizer.target = model
    trainer.out = "result"

    snapshot = extensions.snapshot(
        filename='snapshot_iteration-{.updater.iteration}')
    # Snapshot after every single iteration so a mid-epoch resume point exists.
    cpm.ignite.add_trainer_extension(
        trainer, optimizer, snapshot, trigger=(1, 'iteration'))
    return trainer, optimizer
def _make_loader():
    """Return a new shuffled DataLoader over ``dataset`` with batch size 10."""
    return torch.utils.data.DataLoader(dataset, shuffle=True, batch_size=10)


# First run: train one full epoch, writing a snapshot after every iteration.
trainer, optimizer = create_trainer()
train_loader = _make_loader()
trainer.run(train_loader, max_epochs=1)

# Second run: rebuild trainer/optimizer from scratch, restore state from the
# snapshot written after iteration 1, then run the same epoch again.
trainer, optimizer = create_trainer()
cpm.ignite.load_chainer_snapshot(trainer, optimizer, "result/snapshot_iteration-1")
print("resumed", flush=True)
train_loader = _make_loader()
trainer.run(train_loader, max_epochs=1)
$ python test_resume.py
ITERATION_COMPLETED
ITERATION_COMPLETED
ITERATION_COMPLETED
ITERATION_COMPLETED
ITERATION_COMPLETED
ITERATION_COMPLETED
ITERATION_COMPLETED
ITERATION_COMPLETED
ITERATION_COMPLETED
ITERATION_COMPLETED
resumed
emcastillo commented
Sent a PR to fix this.