chainer / chainer-pytorch-migration

Chainer/PyTorch Migration Library

Training LinkAsTorchModel-wrapped Chainer model using PyTorch + Ignite?

msakai opened this issue · comments

Migration scenario says:

It might be easier to port in this order:

  1. Training script (optimizer / updater / evaluator / ...)
    • In order to use PyTorch optimizer to train a Chainer model, you will need cpm.LinkAsTorchModel.
  2. Dataset / preprocessing
    • Dataset is in general compatible between Chainer and PyTorch. This part can be delayed but also should be easy to do.
  3. Model
    • See the mapping of functions/modules below in this document.

However, in step 1, cpm.LinkAsTorchModel does not override Module.forward, so the wrapped module cannot be trained directly with PyTorch + Ignite. For example, if it is passed to ignite.engine.create_supervised_evaluator and invoked, NotImplementedError is raised.
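
The dispatch behind that error can be sketched in plain Python. This is only a simplified stand-in, not the real torch.nn.Module or cpm.LinkAsTorchModel: the classes Module, WrapperWithoutForward and WrapperWithForward and the helper try_call are all hypothetical names made up for illustration.

```python
class Module:
    """Simplified stand-in for torch.nn.Module (not the real class)."""

    def __call__(self, *inputs, **kwargs):
        # Calling the module dispatches to forward(), as seen in the traceback.
        return self.forward(*inputs, **kwargs)

    def forward(self, *inputs, **kwargs):
        # The base class leaves forward() unimplemented.
        raise NotImplementedError


class WrapperWithoutForward(Module):
    """Hypothetical wrapper that holds a link but defines no forward()."""

    def __init__(self, link):
        self.link = link


class WrapperWithForward(WrapperWithoutForward):
    """Hypothetical variant that delegates forward() to the wrapped callable."""

    def forward(self, *inputs, **kwargs):
        return self.link(*inputs, **kwargs)


def try_call(module, x):
    # Return the module output, or a marker string if forward() is missing.
    try:
        return module(x)
    except NotImplementedError:
        return "NotImplementedError"


def double(x):
    # Stand-in for a Chainer link; any callable works for this sketch.
    return 2 * x


result_without = try_call(WrapperWithoutForward(double), 3)
result_with = try_call(WrapperWithForward(double), 3)
```

The sketch covers only the call dispatch. It says nothing about converting between torch.Tensor and Chainer arrays or about backpropagation, which a real wrapper would also have to handle.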

Is there an intended way to train a wrapped Chainer model using PyTorch (+ Ignite) in phase 1?

Can you provide a minimal example to reproduce this? thanks!

I was originally trying to port, but I made a smaller example here

$ python
Device: @numpy
# unit: 1000
# Minibatch-size: 100
# epoch: 20

Traceback (most recent call last):
  File "", line 179, in <module>
  File "", line 175, in main, max_epochs=args.epoch)
  File "/home/sakai/.pyenv/versions/anaconda3-2019.07/lib/python3.7/site-packages/ignite/engine/", line 446, in run
  File "/home/sakai/.pyenv/versions/anaconda3-2019.07/lib/python3.7/site-packages/ignite/engine/", line 410, in _handle_exception
    raise e
  File "/home/sakai/.pyenv/versions/anaconda3-2019.07/lib/python3.7/site-packages/ignite/engine/", line 433, in run
    hours, mins, secs = self._run_once_on_dataset()
  File "/home/sakai/.pyenv/versions/anaconda3-2019.07/lib/python3.7/site-packages/ignite/engine/", line 399, in _run_once_on_dataset
  File "/home/sakai/.pyenv/versions/anaconda3-2019.07/lib/python3.7/site-packages/ignite/engine/", line 410, in _handle_exception
    raise e
  File "/home/sakai/.pyenv/versions/anaconda3-2019.07/lib/python3.7/site-packages/ignite/engine/", line 391, in _run_once_on_dataset
    self.state.output = self._process_function(self, batch)
  File "/home/sakai/.pyenv/versions/anaconda3-2019.07/lib/python3.7/site-packages/ignite/engine/", line 49, in _update
    y_pred = model(x)
  File "/home/sakai/.pyenv/versions/anaconda3-2019.07/lib/python3.7/site-packages/torch/nn/modules/", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/sakai/.pyenv/versions/anaconda3-2019.07/lib/python3.7/site-packages/torch/nn/modules/", line 97, in forward
    raise NotImplementedError

My attempt is just for familiarizing myself with PyTorch and with porting from Chainer.
So it's not a serious issue and I'm not in a hurry.

Thank you very much for the prompt response. I believe it is not difficult to fix, so I will take a look ASAP.

The main problem here is that the two computational graphs are being mixed: Ignite feeds the Chainer model torch.Tensor objects and also tries to backpropagate a loss computed in PyTorch.
Because the training loops for a Chainer model and a PyTorch model differ, we need a custom wrapper over Ignite to solve this. I am working on the PR now.

#6 and #7 are used to solve this issue.

There are some oddities in the code, and some custom functions still need to be added.

The main issue is that Ignite's internals and DataLoader assume a batch format and a training procedure that differ from Chainer's. Ignite's accuracy metric also does not work, because it relies on the torch.Tensor API.
Right now I work around some of these issues manually in the code above, but we need to discuss a clean solution that does not impact performance.
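
One plausible reading of the batch-format mismatch, sketched in plain Python: a Chainer iterator yields a batch as a list of (x, t) examples, whereas PyTorch's DataLoader with default collation yields a pair of already-stacked batches. The function name to_columns is hypothetical, and plain lists stand in for arrays and tensors here.

```python
def to_columns(batch):
    """Turn a list of (x, t) examples into an (xs, ts) pair of lists."""
    # zip(*batch) transposes rows of examples into per-field columns.
    xs, ts = zip(*batch)
    return list(xs), list(ts)


# Chainer-style batch: one (input, label) tuple per example.
chainer_style_batch = [([0.1, 0.2], 0), ([0.3, 0.4], 1), ([0.5, 0.6], 0)]

# DataLoader-style batch: all inputs together, all labels together.
xs, ts = to_columns(chainer_style_batch)
```

In real code the columns would be stacked into arrays (chainer.dataset.concat_examples does this on the Chainer side) rather than kept as Python lists.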

Thanks for the PRs and the working code which is very helpful for me.

I think Ignite accuracy can be used if we supply output_transform.

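import chainer_pytorch_migration as cpm
import ignite

def output_transform(args):
    # Convert the Chainer prediction and the label to torch.Tensor for Ignite.
    y_pred, y = args
    return cpm.astensor(y_pred.array), cpm.astensor(y)

evaluator = ignite.engine.create_supervised_evaluator(
    model,  # assumed name of the wrapped model from the training script
    metrics={
        'accuracy': ignite.metrics.Accuracy(output_transform),
        'loss': ignite.metrics.Loss(loss_fn),
    })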

Thanks for the pointer!