chainer / chainer-pytorch-migration

Chainer/PyTorch Migration Library

Home Page: https://chainer.github.io/migration-guide/#h.9wc1iaeyqb2c

cpm.ignite.add_trainer_extension not compatible with DataParallel

msakai opened this issue

When I try to use cpm.ignite.add_trainer_extension with a model wrapped in torch.nn.DataParallel, the add_trainer_extension call itself raises an AttributeError, before training starts.
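For context, torch.nn.DataParallel does not replace the model it wraps. It keeps the model and registers it as a submodule under the attribute name "module". The following minimal sketch assumes a recent PyTorch and omits device_ids so that it also runs on a CPU-only machine. It shows that the wrapper's only direct child is named "module", which is the same name that appears in the last line of the traceback below.

```python
import torch

# The same plain layer the repro script uses.
inner = torch.nn.Linear(3, 1)

# DataParallel stores the original network as a submodule named "module".
wrapped = torch.nn.DataParallel(inner)

child_names = [name for name, _ in wrapped.named_children()]
print(child_names)              # ['module']
print(wrapped.module is inner)  # True
```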

import numpy as np
import torch
import torch.nn.functional as F
import ignite

from chainer.training import extensions
import chainer_pytorch_migration as cpm
import chainer_pytorch_migration.ignite

model = torch.nn.Linear(3, 1)
model = torch.nn.DataParallel(model, device_ids=[0])

X = torch.randn(10, 3)
y = torch.rand(10).to(torch.int64)
dataset = torch.utils.data.TensorDataset(X, y)
train_loader = torch.utils.data.DataLoader(dataset, shuffle=True, batch_size=2)

optimizer = torch.optim.Adam(model.parameters())
trainer = ignite.engine.create_supervised_trainer(
    model, optimizer, F.nll_loss)

optimizer.target = model
trainer.out = "result"
cpm.ignite.add_trainer_extension(trainer, optimizer, extensions.ExponentialShift(
    "lr", rate=1 / 3.0), trigger=(1, 'epoch'))

trainer.run(train_loader, max_epochs=2)
$ python3 test_dataparallel.py 
Traceback (most recent call last):
  File "test_dataparallel.py", line 25, in <module>
    "lr", rate=1 / 3.0), trigger=(1, 'epoch'))
  File "/home/user_2114/.local/lib/python3.6/site-packages/chainer_pytorch_migration/ignite/extensions.py", line 83, in add_trainer_extension
    engines[id(engine)] = ExtensionTrainerAdapter(engine, optimizer)
  File "/home/user_2114/.local/lib/python3.6/site-packages/chainer_pytorch_migration/ignite/extensions.py", line 198, in __init__
    self.optimizer = ExtensionOptimizerAdapter(optimizer)
  File "/home/user_2114/.local/lib/python3.6/site-packages/chainer_pytorch_migration/ignite/extensions.py", line 340, in __init__
    self.target = cpm.TorchModule(optimizer.target)
  File "/home/user_2114/.local/lib/python3.6/site-packages/chainer_pytorch_migration/links.py", line 37, in __init__
    setattr(self, name, TorchModule(child))
  File "/usr/local/lib/python3.6/site-packages/chainer/link.py", line 912, in __setattr__
    'cannot register a new link %s: attribute exists' % name)
AttributeError: cannot register a new link module: attribute exists
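The last frames fit together as follows. cpm.TorchModule creates one child link per named child of the wrapped module (links.py line 37), and chainer refuses to register a child link under a name that is already an attribute of the link. DataParallel's only child is named "module", so the error suggests that TorchModule already has an attribute called module.

The following stdlib-only sketch reproduces that collision with hypothetical stand-in classes (FakeModule, MiniChain, MiniTorchModule, try_wrap). These are not the real chainer, torch, or cpm classes. The pre-existing module attribute on the wrapper is an assumption inferred from the error message.

```python
# Hypothetical stand-ins only: none of these are the real chainer, torch,
# or chainer_pytorch_migration classes.


class FakeModule:
    """Stand-in for torch.nn.Module: only a name -> child mapping."""

    def __init__(self, **children):
        self._kids = dict(children)

    def named_children(self):
        # Like torch.nn.Module.named_children(): yields (name, child) pairs.
        return iter(self._kids.items())


class MiniChain:
    """Stand-in for chainer.Chain's child-link registration.

    Simplification: behaves as if always inside init_scope().
    """

    def __init__(self):
        object.__setattr__(self, "_children", [])

    def __setattr__(self, name, value):
        if isinstance(value, MiniChain):
            # Refuse to register a child link under a name that already
            # resolves to an attribute (what the error message reports).
            if hasattr(self, name):
                raise AttributeError(
                    "cannot register a new link %s: attribute exists" % name)
            self._children.append(name)
        object.__setattr__(self, name, value)


class MiniTorchModule(MiniChain):
    """Stand-in for cpm.TorchModule.

    Assumption: the wrapper exposes the wrapped object as an attribute
    named `module`, which is what the new child link collides with.
    """

    def __init__(self, wrapped):
        super().__init__()
        object.__setattr__(self, "_module", wrapped)
        # One child link per named child (cf. links.py line 37 above).
        for name, child in wrapped.named_children():
            setattr(self, name, MiniTorchModule(child))

    @property
    def module(self):
        return self._module


def try_wrap(torch_like):
    """Return None if wrapping succeeds, else the AttributeError message."""
    try:
        MiniTorchModule(torch_like)
    except AttributeError as exc:
        return str(exc)
    return None


plain = FakeModule()                        # like torch.nn.Linear(3, 1)
nested = FakeModule(fc=FakeModule())        # a child with an ordinary name
parallel = FakeModule(module=FakeModule())  # like DataParallel(Linear(3, 1))

print(try_wrap(plain))     # None
print(try_wrap(nested))    # None
print(try_wrap(parallel))  # cannot register a new link module: attribute exists
```

If this reading is right, one workaround to try (not verified here) is to give the extension the unwrapped network. In the repro script, that means writing optimizer.target = model.module instead of optimizer.target = model. DataParallel shares the same parameter objects as model.module, so the optimizer is unaffected. However, any extension that saves or inspects the target would then see the inner network rather than the DataParallel wrapper.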