torch / nngraph

Graph Computation for nn

Unexpected behavior in saving network as :float()

mbchang opened this issue · comments

Please see the issue posted in torch/torch7#711. I had accidentally submitted an issue for torch/nngraph there, and I'm not sure how to remove it. Thank you!

before loading the model, just execute the line:
require 'cunn'
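In context, the suggestion is that on a CUDA-capable machine, requiring cunn before loading registers the CUDA tensor types so the checkpoint can be deserialized. A minimal sketch (the filename `checkpoint.t7` is a hypothetical placeholder):

```lua
require 'nn'
require 'nngraph'
require 'cunn'  -- registers torch.CudaTensor so deserialization succeeds

-- load a checkpoint that was saved from a GPU session
local model = torch.load('checkpoint.t7')
```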

@soumith I do realize that if my computer has GPU capabilities, executing require 'cunn' lets me load the checkpoint, as mentioned in my post. However, I intend to load the checkpoint on a computer that does not have GPU capabilities, on which I can't install cunn in the first place. That point aside, I had expected that casting the network to Float type would remove any need for CUDA dependencies.

@soumith There is indeed a problem here.
I think the guilty guy is in this line. When we call :type(), the buffers are all cleared, but references to the old forward inputs are still there. The same might apply to the backward nodes.
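One way to check this claim is to scan the graph nodes after the cast and look for lingering CudaTensors. A sketch, assuming `model` is a gModule and using nngraph's `forwardnodes`/`data.input` fields (requires a CUDA session to actually observe stale tensors):

```lua
-- After model:float(), look for CUDA tensors still referenced by graph nodes.
for _, node in ipairs(model.forwardnodes) do
  local input = node.data.input
  if input then
    for i, t in ipairs(input) do
      if torch.type(t) == 'torch.CudaTensor' then
        print('stale CUDA input at node', node.id, 'slot', i)
      end
    end
  end
end
```

If any line is printed, torch.save would serialize those CudaTensors along with the model, which explains both the load failure on CPU-only machines and the inflated file size.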

Two workarounds for the moment: after converting to float and before saving, either do

  • :clearState(), or
  • run a forward/backward pass using float data.
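Both workarounds can be sketched as follows, assuming `model` is the converted gModule and `input` is a FloatTensor batch of the right shape (the filename is a hypothetical placeholder):

```lua
model:float()  -- cast parameters and buffers to FloatTensor

-- Workaround 1: drop all cached intermediate state, including stale
-- references to the old CUDA inputs
model:clearState()

-- Workaround 2 (alternative): overwrite the stale references by running
-- a pass with float data
-- local out = model:forward(input)
-- model:backward(input, out:clone():zero())

torch.save('model_float.t7', model)
```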

Another consequence is that the model is almost double the original size.

I think this should be addressed in nngraph though.

@fmassa I don't think that holds for type. What you are saying might actually be affecting clearState, though.
To check what you said, I added this assertion to tests, and it passes.
08d0b5d

@mbchang reproduced your issue. I am looking into it.

Thanks @soumith, that would be a great help!

@soumith the test passes because forward is run after the model conversion. Without the forward it would fail, I think.
But my workaround didn't work for networks containing cudnn modules that were converted to nn; I had to recreate the modules and copy the parameters by hand.
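The manual recreation described above might look like the following sketch for a convolution layer. The variable names are illustrative, and it assumes the cudnn layer exposes the same constructor fields as its nn counterpart (which cudnn.torch mirrors):

```lua
-- `old` is the cudnn.SpatialConvolution to replace (hypothetical handle)
local old = cudnnConv
local new = nn.SpatialConvolution(old.nInputPlane, old.nOutputPlane,
                                  old.kW, old.kH, old.dW, old.dH,
                                  old.padW, old.padH)

-- copy learned parameters over, casting to float
new.weight:copy(old.weight:float())
new.bias:copy(old.bias:float())
```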

Ohh I see, OK, that makes sense. Looking into it.

I just realized that the line my comment was pointing to is off. I was referring to this line, which I'll copy here to avoid further ambiguity:

child.data.input[mapindex] = x
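A minimal illustration (not nngraph's actual code) of why this assignment is the culprit: storing the forward input in the node's table keeps the original CUDA tensor alive, and a later type cast of the module's own buffers never touches it.

```lua
-- requires a CUDA-capable machine to run
local node = { data = { input = {} } }
local x = torch.CudaTensor(10)
node.data.input[1] = x  -- node now holds a reference to the CUDA tensor

-- Casting the model's parameters/buffers to float does not rewrite
-- node.data.input[1], so torch.save on the graph would still serialize
-- a CudaTensor.
```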

@mbchang @fmassa fixed it via #126 . Reinstall nngraph and it should be fixed.