chainer / chainer-pytorch-migration

Chainer/PyTorch Migration Library

Home page: https://chainer.github.io/migration-guide/#h.9wc1iaeyqb2c

Training a LinkAsTorchModel-wrapped Chainer model using PyTorch + Ignite?

msakai opened this issue.

The migration scenario at https://chainer.github.io/migration-guide/#h.y7ovx8en8cp7 says:

It might be easier to port in this order:

  1. Training script (optimizer / updater / evaluator / ...)
    • In order to use PyTorch optimizer to train a Chainer model, you will need cpm.LinkAsTorchModel.
  2. Dataset / preprocessing
    • Dataset is in general compatible between Chainer and PyTorch. This part can be delayed but also should be easy to do.
  3. Model
    • See the mapping of functions/modules below in this document.

However, in step 1, cpm.LinkAsTorchModel does not override Module.forward, so the wrapped module cannot be trained with PyTorch + Ignite directly. If it is passed to ignite.engine.create_supervised_trainer (or create_supervised_evaluator) and Engine.run is invoked, NotImplementedError is raised.

Is there an intended way to train a wrapped Chainer model using PyTorch (+ Ignite) in step 1?
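The error comes from ordinary torch.nn.Module behaviour. Module.__call__ dispatches to forward(), and the base-class forward() raises NotImplementedError. The minimal sketch below uses only PyTorch, with no Chainer or Ignite. The class names NoForward and WithForward are hypothetical. It reproduces the error and shows that defining forward() is what makes a module callable.

```python
import torch


class NoForward(torch.nn.Module):
    """Registers a parameter but, like the wrapper described above, defines no forward()."""

    def __init__(self):
        super().__init__()
        self.w = torch.nn.Parameter(torch.ones(3))


class WithForward(NoForward):
    """Same parameters, plus the forward() that Module.__call__ dispatches to."""

    def forward(self, x):
        return x * self.w


x = torch.arange(3.0)

raised = False
try:
    NoForward()(x)  # Module.__call__ -> self.forward -> base implementation raises
except NotImplementedError:
    raised = True

y = WithForward()(x)
print(raised, y.tolist())
```

The error type is the same whether the module is called directly or through Ignite's update function, which calls model(x) in the same way.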

Can you provide a minimal example to reproduce this? Thanks!

I was originally trying to port https://github.com/pfnet-research/chainer-formulanet/, but I made a smaller example here: https://gist.github.com/msakai/eb3cdeb3d39fa8393397bc562210edd7.

$ python train_mnist_pytorch.py
Device: @numpy
# unit: 1000
# Minibatch-size: 100
# epoch: 20

Traceback (most recent call last):
  File "train_mnist_pytorch.py", line 179, in <module>
    main()
  File "train_mnist_pytorch.py", line 175, in main
    trainer.run(train_loader, max_epochs=args.epoch)
  File "/home/sakai/.pyenv/versions/anaconda3-2019.07/lib/python3.7/site-packages/ignite/engine/engine.py", line 446, in run
    self._handle_exception(e)
  File "/home/sakai/.pyenv/versions/anaconda3-2019.07/lib/python3.7/site-packages/ignite/engine/engine.py", line 410, in _handle_exception
    raise e
  File "/home/sakai/.pyenv/versions/anaconda3-2019.07/lib/python3.7/site-packages/ignite/engine/engine.py", line 433, in run
    hours, mins, secs = self._run_once_on_dataset()
  File "/home/sakai/.pyenv/versions/anaconda3-2019.07/lib/python3.7/site-packages/ignite/engine/engine.py", line 399, in _run_once_on_dataset
    self._handle_exception(e)
  File "/home/sakai/.pyenv/versions/anaconda3-2019.07/lib/python3.7/site-packages/ignite/engine/engine.py", line 410, in _handle_exception
    raise e
  File "/home/sakai/.pyenv/versions/anaconda3-2019.07/lib/python3.7/site-packages/ignite/engine/engine.py", line 391, in _run_once_on_dataset
    self.state.output = self._process_function(self, batch)
  File "/home/sakai/.pyenv/versions/anaconda3-2019.07/lib/python3.7/site-packages/ignite/engine/__init__.py", line 49, in _update
    y_pred = model(x)
  File "/home/sakai/.pyenv/versions/anaconda3-2019.07/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/sakai/.pyenv/versions/anaconda3-2019.07/lib/python3.7/site-packages/torch/nn/modules/module.py", line 97, in forward
    raise NotImplementedError
NotImplementedError

My attempt is just to familiarize myself with PyTorch and with porting from Chainer.
So it's not a serious issue, and I'm not in a hurry.

Thank you very much for the prompt response. I believe it is not difficult to fix, so I will take a look ASAP.

The main problem here is that the computational graphs are being mixed. Ignite feeds the Chainer model with torch.Tensor objects and also tries to backpropagate a loss calculated in PyTorch.
We need to add a custom wrapper over Ignite to solve this, because the training loop for a Chainer model differs from the one for a torch model. I am working on the PR now.
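The custom wrapper described above can be sketched as an Ignite process function. In it, the forward pass, the loss and backward() all stay in Chainer, and only the parameter update goes through the PyTorch optimizer. This is a hypothetical helper: make_chainer_update and its arguments are illustrative names, not part of cpm or Ignite. It also assumes, as the migration guide implies for cpm.LinkAsTorchModel, that gradients written by Chainer's backward() are visible to a PyTorch optimizer built over the wrapped parameters.

```python
def make_chainer_update(link, loss_fn, optimizer, to_chainer):
    """Build an Ignite-style process function with signature (engine, batch) -> output.

    Hypothetical helper: forward, loss and backward run in Chainer, and only
    optimizer.step() runs in PyTorch. Assumes the torch optimizer sees the
    gradients that Chainer's backward() writes (shared parameter storage).
    """

    def _update(engine, batch):
        x, t = (to_chainer(a) for a in batch)  # torch.Tensor -> Chainer-compatible arrays
        link.cleargrads()                      # reset Chainer-side gradients
        loss = loss_fn(link(x), t)             # builds a Chainer graph, not a torch graph
        loss.backward()                        # Chainer backprop fills the parameter gradients
        optimizer.step()                       # PyTorch optimizer updates the shared parameters
        return float(loss.array)               # plain number for Ignite handlers/logging

    return _update
```

With real objects this could be wired up as ignite.engine.Engine(make_chainer_update(model, chainer.functions.softmax_cross_entropy, optimizer, to_chainer)). Here, model is the Chainer link and optimizer is a torch.optim optimizer over cpm.LinkAsTorchModel(model).parameters(). On CPU, to_chainer might be as simple as lambda t: t.cpu().numpy(). Returning a float keeps Chainer objects out of Ignite's handlers.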

PRs #6 and #7 address this issue.

There are some oddities in the code, and some custom functions need to be added too.
https://gist.github.com/emcastillo/8be64990940ce23147cc08584bea985b

The main issue is that Ignite internals and DataLoader assume a particular batch format, and the training loop differs from Chainer's. Ignite's Accuracy metric does not work because it expects torch.Tensor inputs.
Right now I solve some of these issues manually in the code above, but we need to discuss a clean solution that does not impact performance.
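One way to avoid the batch-format mismatch is to have the DataLoader produce the NumPy arrays a Chainer model expects, in the spirit of chainer.dataset.concat_examples. Below is a minimal CPU-only sketch. It assumes each dataset example is an (x, t) tuple, and numpy_collate is a hypothetical name.

```python
import numpy as np


def numpy_collate(examples):
    """Stack a list of (x, t) tuples into a tuple of NumPy arrays.

    Hypothetical collate_fn: it plays the role of chainer.dataset.concat_examples
    on CPU, so batches reach the Chainer model as ndarrays instead of the
    torch.Tensor objects produced by DataLoader's default collation.
    """
    # zip(*examples) regroups [(x0, t0), (x1, t1), ...] into (x0, x1, ...), (t0, t1, ...)
    return tuple(np.stack(column) for column in zip(*examples))


examples = [(np.zeros(4, dtype=np.float32), np.int32(i)) for i in range(3)]
x, t = numpy_collate(examples)
print(x.shape, t.tolist())
```

It could be passed as torch.utils.data.DataLoader(dataset, batch_size=100, collate_fn=numpy_collate). Ignite's default prepare_batch assumes tensors, so a custom prepare_batch (like the one passed to the evaluator below) would still be needed.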

Thanks for the PRs and the working code which is very helpful for me.

I think Ignite's Accuracy metric can be used if we supply an output_transform:

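import ignite.engine
import ignite.metrics
import chainer_pytorch_migration as cpm

# torched_model, loss_fn, prepare_batch and torch_device are defined elsewhere in the training script.

def output_transform(args):
    # The evaluator outputs (y_pred, y): y_pred is a chainer.Variable and y is the
    # target array from prepare_batch, so convert both to torch.Tensor for Accuracy.
    y_pred, y = args
    return cpm.astensor(y_pred.array), cpm.astensor(y)

evaluator = ignite.engine.create_supervised_evaluator(
    torched_model,
    metrics={
        'accuracy': ignite.metrics.Accuracy(output_transform=output_transform),
        'loss': ignite.metrics.Loss(loss_fn),
    },
    prepare_batch=prepare_batch,
    device=torch_device)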

Thanks for the pointer!