chainer / chainer-pytorch-migration

Chainer/PyTorch Migration Library

Home Page:https://chainer.github.io/migration-guide/#h.9wc1iaeyqb2c

Resumption from snapshot taken in the middle of the epoch

msakai opened this issue

It seems that if training is resumed from a snapshot taken in the middle of an epoch, the remaining iterations of that epoch are not executed.

In the following example, I expect ITERATION_COMPLETED to be printed 9 times after resumption, but it is not printed at all.

import torch
import torch.nn.functional as F
import ignite
import ignite.contrib.handlers
from chainer.training import extensions
import chainer_pytorch_migration as cpm
import chainer_pytorch_migration.ignite


device = torch.device('cpu')

model = torch.nn.Linear(3, 1).to(device)

X = torch.randn(100, 3)
y = torch.randint(high=1, size=(100,)).to(torch.int64)
dataset = torch.utils.data.TensorDataset(X, y)


def create_trainer():
    optimizer = torch.optim.Adam(model.parameters())
    trainer = ignite.engine.create_supervised_trainer(
        model, optimizer, F.nll_loss, device=device)

    @trainer.on(ignite.engine.Events.ITERATION_COMPLETED)
    def print_ITERATION_COMPLETED(engine):
        print("ITERATION_COMPLETED", flush=True)

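    # Attributes the Chainer extension relies on: the model as
    # optimizer.target and the output directory as trainer.out.
    optimizer.target = model
    trainer.out = "result"
    # Write a snapshot after every iteration, named by iteration count.
    snapshot = extensions.snapshot(filename='snapshot_iteration-{.updater.iteration}')
    cpm.ignite.add_trainer_extension(trainer, optimizer, snapshot, trigger=(1, 'iteration'))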

    return trainer, optimizer


# First run: one full epoch (100 samples / batch_size 10 = 10 iterations).
trainer, optimizer = create_trainer()
train_loader = torch.utils.data.DataLoader(dataset, shuffle=True, batch_size=10)
trainer.run(train_loader, max_epochs=1)

# Second run: rebuild the trainer and resume from the snapshot taken
# after iteration 1, so 9 iterations of the epoch should remain.
trainer, optimizer = create_trainer()
cpm.ignite.load_chainer_snapshot(trainer, optimizer, "result/snapshot_iteration-1")
print("resumed", flush=True)
train_loader = torch.utils.data.DataLoader(dataset, shuffle=True, batch_size=10)
trainer.run(train_loader, max_epochs=1)

$ python test_resume.py
ITERATION_COMPLETED
ITERATION_COMPLETED
ITERATION_COMPLETED
ITERATION_COMPLETED
ITERATION_COMPLETED
ITERATION_COMPLETED
ITERATION_COMPLETED
ITERATION_COMPLETED
ITERATION_COMPLETED
ITERATION_COMPLETED
resumed
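
For reference, the expected count of 9 follows from simple arithmetic on the numbers in the script above. The sketch below is only an illustration of that expectation. The helper name `expected_remaining_iterations` is hypothetical and is not part of ignite or chainer-pytorch-migration, and it assumes the data loader keeps the last partial batch (the DataLoader default).

```python
def expected_remaining_iterations(resumed_iteration, epoch_length, max_epochs):
    """Iterations still to run after resuming at `resumed_iteration`.

    Hypothetical helper for illustration only; it is not a library function.
    """
    total_iterations = epoch_length * max_epochs
    # Never report a negative count if the snapshot is already past the end.
    return max(total_iterations - resumed_iteration, 0)


# Numbers taken from the script in this report.
num_samples, batch_size, max_epochs = 100, 10, 1

# Ceiling division: batches per epoch when the last partial batch is kept.
epoch_length = -(-num_samples // batch_size)

# The snapshot "snapshot_iteration-1" was taken after iteration 1.
print(expected_remaining_iterations(1, epoch_length, max_epochs))  # prints 9
```

With the observed behaviour, the resumed run instead prints nothing after "resumed", i.e. it behaves as if no iterations remained.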

Sent a PR to fix this.