cpm.ignite.add_trainer_extension not compatible with DataParallel
msakai opened this issue · comments
Masahiro Sakai commented
When I try to use `cpm.ignite.add_trainer_extension` with a model wrapped in `torch.nn.DataParallel`, it raises `AttributeError`.
import numpy as np
import torch
import torch.nn.functional as F
import ignite
from chainer.training import extensions
import chainer_pytorch_migration as cpm
import chainer_pytorch_migration.ignite
# Minimal reproduction: wrap a small linear model in DataParallel, build an
# Ignite supervised trainer around it, then attach a Chainer extension via
# chainer_pytorch_migration.
model = torch.nn.Linear(3,1)
# The DataParallel wrapper is the ingredient under test; per the issue report,
# the failure occurs when the optimizer's target is a DataParallel model.
model = torch.nn.DataParallel(model, device_ids=[0])
# Toy data: 10 samples with 3 features each, plus 10 int64 targets.
X = torch.randn(10, 3)
y = torch.rand(10).to(torch.int64)
dataset = torch.utils.data.TensorDataset(X, y)
train_loader = torch.utils.data.DataLoader(dataset, shuffle=True, batch_size=2)
optimizer = torch.optim.Adam(model.parameters())
trainer = ignite.engine.create_supervised_trainer(
model, optimizer, F.nll_loss)
# The traceback shows ExtensionOptimizerAdapter reading `optimizer.target`
# and wrapping it in cpm.TorchModule, so the model is attached here by hand.
optimizer.target = model
# NOTE(review): presumably the output directory that Chainer extensions
# expect on the trainer -- confirm against chainer_pytorch_migration docs.
trainer.out = "result"
# This call raises the AttributeError reported in the traceback: it fails
# inside add_trainer_extension, before trainer.run() is ever reached.
cpm.ignite.add_trainer_extension(trainer, optimizer, extensions.ExponentialShift(
"lr", rate=1 / 3.0), trigger=(1, 'epoch'))
trainer.run(train_loader, max_epochs=2)
$ python3 test_dataparallel.py
Traceback (most recent call last):
  File "test_dataparallel.py", line 25, in <module>
    "lr", rate=1 / 3.0), trigger=(1, 'epoch'))
  File "/home/user_2114/.local/lib/python3.6/site-packages/chainer_pytorch_migration/ignite/extensions.py", line 83, in add_trainer_extension
    engines[id(engine)] = ExtensionTrainerAdapter(engine, optimizer)
  File "/home/user_2114/.local/lib/python3.6/site-packages/chainer_pytorch_migration/ignite/extensions.py", line 198, in __init__
    self.optimizer = ExtensionOptimizerAdapter(optimizer)
  File "/home/user_2114/.local/lib/python3.6/site-packages/chainer_pytorch_migration/ignite/extensions.py", line 340, in __init__
    self.target = cpm.TorchModule(optimizer.target)
  File "/home/user_2114/.local/lib/python3.6/site-packages/chainer_pytorch_migration/links.py", line 37, in __init__
    setattr(self, name, TorchModule(child))
  File "/usr/local/lib/python3.6/site-packages/chainer/link.py", line 912, in __setattr__
    'cannot register a new link %s: attribute exists' % name)
AttributeError: cannot register a new link module: attribute exists