torch / nngraph

Graph Computation for nn

Error when switching between float and cuda

bamos opened this issue:

Hi, the following example illustrates the issue. It creates a gModule, converts it to CUDA, and then converts it back to float.
After the conversion back to float, I think some internal state isn't updated correctly,
because the final forward call fails with the error below.

I'm using the latest master branch of:

  • torch f62a95d3184fd730acdaa4754647b338d7686301
  • cutorch a7147d00e61a5e182a277995f5d1e99ec3bdf0f8
  • nn bc056eeb09f83aaba354d44b985b1819b6b6ee4a
  • cunn 3827fcd820d5d0d90cb37a443c403b47009cb7d4
  • nngraph d0c239b
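
require 'cutorch'
require 'nn'
require 'cunn'
require 'nngraph'

torch.manualSeed(1)

-- Build a small graph: input -> Linear(3, 1) -> ReLU
input = nn.Identity()()
L1 = nn.ReLU()(nn.Linear(3, 1)(input))
g = nn.gModule({input}, {L1})

x = torch.randn(3)     -- x is a DoubleTensor (torch's default tensor type)
g:forward(x)           -- works: a freshly built g also uses DoubleTensors

g:cuda()               -- convert the module to CUDA
g:forward(x:cuda())    -- works: input and module are both CUDA

g:float()              -- convert the module to float
g:forward(x)           -- raises the error shown below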

Output

th> g:forward(x)
 0.1432
[torch.DoubleTensor of size 1]

                                                                      [0.0001s]
th> g:cuda()
nn.gModule
                                                                      [0.0596s]
th> g:forward(x:cuda())
 0.1432
[torch.CudaTensor of size 1]

                                                                      [0.0003s]
th> g:float()
nn.gModule
                                                                      [0.0004s]
th> g:forward(x)
/home/bamos/torch/install/share/lua/5.1/nn/Linear.lua:51: expected arguments: *FloatTensor~1D* [FloatTensor~1D] [float] FloatTensor~2D FloatTensor~1D | *FloatTensor~1D* float [FloatTensor~1D] float FloatTensor~2D FloatTensor~1D
stack traceback:
        [C]: in function 'addmv'
        /home/bamos/torch/install/share/lua/5.1/nn/Linear.lua:51: in function 'func'
        /home/bamos/torch/install/share/lua/5.1/nngraph/gmodule.lua:311: in function 'neteval'
        /home/bamos/torch/install/share/lua/5.1/nngraph/gmodule.lua:346: in function 'forward'
        [string "_RESULT={g:forward(x)}"]:1: in main chunk
        [C]: in function 'xpcall'
        /home/bamos/torch/install/share/lua/5.1/trepl/init.lua:630: in function 'repl'
        ...amos/torch/install/lib/luarocks/rocks/trepl/scm-1/bin/th:185: in main chunk
        [C]: at 0x00406670

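@bamos the problem is that your input x is a DoubleTensor, whereas after g:float() the network's parameters are FloatTensors. The first forward call only worked because a freshly built network also defaults to double. Cast x to float (x:float()) and it will work just fine.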
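A minimal CPU-only sketch of the suggested fix, assuming torch7 with the nn and nngraph packages installed. The CUDA round trip from the report is left out so it runs without cutorch/cunn. It rebuilds the same graph, converts it with g:float(), and casts the input with x:float() before calling forward:

```lua
-- Assumes torch7 with the nn and nngraph packages; no GPU needed.
require 'torch'
require 'nn'
require 'nngraph'

torch.manualSeed(1)

-- Same graph as in the report: input -> Linear(3, 1) -> ReLU
local input = nn.Identity()()
local L1 = nn.ReLU()(nn.Linear(3, 1)(input))
local g = nn.gModule({input}, {L1})

local x = torch.randn(3)   -- DoubleTensor: torch's default tensor type
g:float()                  -- weights and buffers become FloatTensors

-- Cast the input so its type matches the module's parameters
local out = g:forward(x:float())
print(out:type())          -- torch.FloatTensor
```

Another option, if you want new tensors to be float from the start, is torch.setdefaulttensortype('torch.FloatTensor'), which makes torch.randn return FloatTensors. That setting is global, so an explicit cast at the call site may be easier to reason about.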

Thanks @fmassa! Sorry for the noise...

No problem! :)