torch / nngraph

Graph Computation for nn

Wrong "Number of gradients do not match my graph" error

fidlej opened this issue · comments

I have two outputs, and I pass in two gradOutputs, but nngraph complains:

#outputs   =1   
#gradients =2
nngraph/gmodule.lua:251: Number of gradients do not match my graph

Example code to reproduce the error:

require 'nngraph'

local in1 = nn.Sigmoid()()
local splitTable = nn.SplitTable(1)({in1})
-- the SplitTable node counts as a single graph output here
local module = nn.gModule({in1}, {splitTable})

local input = torch.randn(2, 3)
local output = module:forward(input)
assert(#output == 2, "we have two outputs")
-- fails: the graph declares 1 output but receives 2 gradOutputs
module:backward(input, {torch.randn(3), torch.randn(3)})

The graph module cannot know whether you want to pass a table of tensors into a single downstream module or split it into separate graph outputs. Handling all of these cases automatically would require too much boundary handling.
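For contrast, here is a minimal sketch of the other interpretation (nn.CAddTable is used only as an illustrative consumer and is not part of the original report): the table produced by SplitTable is fed whole into a single downstream module, so treating the SplitTable node as one graph output is exactly what is wanted.

require 'nngraph'

local in1 = nn.Sigmoid()()
local splitTable = nn.SplitTable(1)({in1})
-- CAddTable receives SplitTable's whole table output as its single input
local summed = nn.CAddTable()({splitTable})
local mod = nn.gModule({in1}, {summed})

local input = torch.randn(2, 3)
local output = mod:forward(input)   -- a single tensor of size 3
-- one graph output, so backward takes a single gradOutput tensor
mod:backward(input, torch.randn(3))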

Instead, there is split, which declares explicitly how many outputs a node has and handles this case properly. I think this solves the problem, right?

require 'nngraph'

local in1 = nn.Sigmoid()()
local splitTable = nn.SplitTable(1)({in1})
-- split(2) turns the single SplitTable node into two explicit output nodes
local o1, o2 = splitTable:split(2)
local mod = nn.gModule({in1}, {o1, o2})

local input = torch.randn(2, 3)
local output = mod:forward(input)
assert(#output == 2, "we have two outputs")
mod:backward(input, {torch.randn(3), torch.randn(3)})
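As a quick sanity check (a sketch building on the snippet above, not part of the original report), the backward pass of the split version returns a gradInput with the same shape as the input:

local gradInput = mod:backward(input, {torch.randn(3), torch.randn(3)})
assert(gradInput:isSameSizeAs(input), "gradInput has the same shape as the input")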

Good idea. I like that the number of passed inputs and outputs is checked.
I will specify the number of graph outputs explicitly.