Training a LinkAsTorchModel-wrapped Chainer model using PyTorch + Ignite?
msakai opened this issue
Migration scenario https://chainer.github.io/migration-guide/#h.y7ovx8en8cp7 says:

> It might be easier to port in this order:
> 1. Training script (optimizer / updater / evaluator / ...)
>    - In order to use a PyTorch optimizer to train a Chainer model, you will need `cpm.LinkAsTorchModel`.
> 2. Dataset / preprocessing
>    - Datasets are in general compatible between Chainer and PyTorch. This part can be delayed but should also be easy to do.
> 3. Model
>    - See the mapping of functions/modules below in this document.
But in step 1, `cpm.LinkAsTorchModel` does not override `Module.forward`, so the wrapped module cannot be trained with PyTorch + Ignite directly (i.e. if it is passed to `ignite.engine.create_supervised_evaluator` and `Engine.run` is invoked, `NotImplementedError` is raised).
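For reference, this is how the error arises: `torch.nn.Module.__call__` delegates to `forward`, and the base implementation of `forward` raises `NotImplementedError`. Below is a minimal stdlib-only sketch that mimics that dispatch (these classes are illustrative stand-ins, not the real PyTorch classes):

```python
class Module:
    """Toy stand-in mimicking torch.nn.Module's call-to-forward dispatch."""

    def __call__(self, *args, **kwargs):
        # torch.nn.Module.__call__ eventually invokes self.forward(...)
        return self.forward(*args, **kwargs)

    def forward(self, *args, **kwargs):
        # The base class does not implement forward, matching the
        # traceback in this report.
        raise NotImplementedError


class WrappedModel(Module):
    """A wrapper that stores a model but never overrides forward,
    analogous to the situation described for cpm.LinkAsTorchModel."""

    def __init__(self, inner):
        self.inner = inner


model = WrappedModel(inner=None)
try:
    # Ignite's default supervised update step effectively does: y_pred = model(x)
    model(1.0)
except NotImplementedError:
    print("NotImplementedError: forward is not overridden")
```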
Is there an intended way to train a wrapped Chainer model using PyTorch (+ Ignite) in phase 1?
Can you provide a minimal example to reproduce this? thanks!
I originally was trying to port https://github.com/pfnet-research/chainer-formulanet/, but I made smaller example here https://gist.github.com/msakai/eb3cdeb3d39fa8393397bc562210edd7.
```
$ python train_mnist_pytorch.py
Device: @numpy
# unit: 1000
# Minibatch-size: 100
# epoch: 20
Traceback (most recent call last):
  File "train_mnist_pytorch.py", line 179, in <module>
    main()
  File "train_mnist_pytorch.py", line 175, in main
    trainer.run(train_loader, max_epochs=args.epoch)
  File "/home/sakai/.pyenv/versions/anaconda3-2019.07/lib/python3.7/site-packages/ignite/engine/engine.py", line 446, in run
    self._handle_exception(e)
  File "/home/sakai/.pyenv/versions/anaconda3-2019.07/lib/python3.7/site-packages/ignite/engine/engine.py", line 410, in _handle_exception
    raise e
  File "/home/sakai/.pyenv/versions/anaconda3-2019.07/lib/python3.7/site-packages/ignite/engine/engine.py", line 433, in run
    hours, mins, secs = self._run_once_on_dataset()
  File "/home/sakai/.pyenv/versions/anaconda3-2019.07/lib/python3.7/site-packages/ignite/engine/engine.py", line 399, in _run_once_on_dataset
    self._handle_exception(e)
  File "/home/sakai/.pyenv/versions/anaconda3-2019.07/lib/python3.7/site-packages/ignite/engine/engine.py", line 410, in _handle_exception
    raise e
  File "/home/sakai/.pyenv/versions/anaconda3-2019.07/lib/python3.7/site-packages/ignite/engine/engine.py", line 391, in _run_once_on_dataset
    self.state.output = self._process_function(self, batch)
  File "/home/sakai/.pyenv/versions/anaconda3-2019.07/lib/python3.7/site-packages/ignite/engine/__init__.py", line 49, in _update
    y_pred = model(x)
  File "/home/sakai/.pyenv/versions/anaconda3-2019.07/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/sakai/.pyenv/versions/anaconda3-2019.07/lib/python3.7/site-packages/torch/nn/modules/module.py", line 97, in forward
    raise NotImplementedError
NotImplementedError
```
My attempt was just to familiarize myself with PyTorch and with porting from Chainer, so it's not a serious issue and I'm not in a hurry.
Thank you very much for the prompt response. I believe it is not difficult to fix, so I will take a look ASAP.
The main problem here is that the computational graphs are being mixed: Ignite tries to feed the Chainer model with `torch.Tensor` objects and also to backpropagate a loss calculated in PyTorch.
We need to add a custom wrapper over Ignite to solve this issue, since the training loops for a Chainer model and a Torch model are different. I am working on the PR now.
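To illustrate the shape of such a wrapper, here is a library-free sketch of a custom Ignite-style update step that converts data at the framework boundary. `to_chainer` and `to_torch` are hypothetical stand-ins for the real conversion helpers (something like `cpm.asarray` / `cpm.astensor`); the sketch uses plain Python values so it runs without torch or chainer installed:

```python
# Hypothetical sketch of a custom update step that bridges frameworks.
# to_chainer / to_torch are placeholder identity functions standing in
# for real torch<->chainer conversion helpers.

def to_chainer(x):
    # placeholder: in practice this would convert a torch.Tensor
    # into a chainer-compatible array
    return x


def to_torch(x):
    # placeholder: in practice this would wrap a chainer array
    # as a torch.Tensor
    return x


def make_update_fn(model, loss_fn):
    """Build an Ignite-style process function: (engine, batch) -> loss."""
    def update(engine, batch):
        x, y = batch
        y_pred = model(to_chainer(x))        # run the chainer-side model
        loss = loss_fn(to_torch(y_pred), y)  # compute the loss on the torch side
        # a real implementation would also need to bridge backward()
        # across the two frameworks, which is the hard part
        return loss
    return update


# usage with dummy stand-ins for the model and loss function
update = make_update_fn(model=lambda x: x * 2, loss_fn=lambda p, t: abs(p - t))
print(update(None, (3, 6)))  # -> 0
```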
#6 and #7 address this issue.
There are some oddities in the code, and some custom functions need to be added.
https://gist.github.com/emcastillo/8be64990940ce23147cc08584bea985b
The main issue is that Ignite's internals and DataLoader assume a batch format and training flow different from Chainer's, and Ignite's accuracy metric does not work because it relies on `torch.Tensor` operations. Right now I work around some of these issues manually in the code above, but we need to discuss a clean solution that does not impact performance.
Thanks for the PRs and the working code, which is very helpful for me.
I think Ignite's accuracy metric can be used if we supply an `output_transform`:
```python
def output_transform(args):
    y_pred, y = args
    return cpm.astensor(y_pred.array), cpm.astensor(y)

evaluator = ignite.engine.create_supervised_evaluator(
    torched_model,
    metrics={
        'accuracy': ignite.metrics.Accuracy(output_transform),
        'loss': ignite.metrics.Loss(loss_fn),
    },
    prepare_batch=prepare_batch,
    device=torch_device)
```
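For context on why this helps: Ignite metrics apply `output_transform` to each `engine.state.output` before updating their internal state, so converting the Chainer outputs to `torch.Tensor` at that point is enough for `Accuracy` to work. A stdlib-only toy sketch of that pattern (not the real Ignite internals; `SimpleAccuracy` is an illustrative stand-in):

```python
class SimpleAccuracy:
    """Toy metric mimicking how Ignite metrics apply output_transform."""

    def __init__(self, output_transform=lambda x: x):
        self.output_transform = output_transform
        self.correct = 0
        self.total = 0

    def update(self, output):
        # the transform runs first, so framework conversion can happen here
        y_pred, y = self.output_transform(output)
        self.correct += sum(int(p == t) for p, t in zip(y_pred, y))
        self.total += len(y)

    def compute(self):
        return self.correct / self.total


# usage: the transform converts "framework-specific" outputs to plain lists
metric = SimpleAccuracy(output_transform=lambda out: (list(out[0]), list(out[1])))
metric.update(([1, 0, 1], [1, 1, 1]))
print(metric.compute())  # -> 2/3
```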
Thanks for the pointer!