torch / nngraph

Graph Computation for nn



Error for single input gModule

tudor-berariu opened this issue

I am trying to write a nn.gModule that works for an arbitrary number of inputs, but I get an error when there is only one. The following code works fine with inputsNo > 1, but fails for inputsNo = 1.

require("nn")
require("nngraph")

local inputsNo = 1        -- it works for inputsNo >= 2, but not for inputsNo = 1

local inputs, linInputs = {}, {}

for i = 1, inputsNo do
   inputs[i] = nn.Identity()()
   linInputs[i] = nn.Linear(i, i)(inputs[i])
end

local output = nn.JoinTable(1,1)(linInputs)
local rnn = nn.gModule(inputs, {output})

local X = {}
for i = 1, inputsNo do  X[i] = torch.rand(i)  end

rnn:forward(X)

Applying nngraph.nest to the table of nodes passed to nn.JoinTable, i.e. nn.JoinTable(1,1)(nngraph.nest(linInputs)), solved my problem.
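Below is a minimal sketch of that workaround. It assumes that in nngraph a node with a single parent receives that parent's output directly instead of a one-element table. That would explain why nn.JoinTable only fails when inputsNo = 1. It also assumes that a single-input gModule expects the tensor itself in forward, while a multi-input gModule expects a table of tensors. The helper names buildModel and run are hypothetical and exist only for illustration.

```lua
require("nn")
require("nngraph")

-- Hypothetical helper: builds a gModule with `inputsNo` inputs,
-- where input i is a vector of size i fed through nn.Linear(i, i).
local function buildModel(inputsNo)
   local inputs, linInputs = {}, {}
   for i = 1, inputsNo do
      inputs[i] = nn.Identity()()
      linInputs[i] = nn.Linear(i, i)(inputs[i])
   end
   -- nngraph.nest turns the (possibly one-element) list of nodes into a
   -- node whose output is always a table, so JoinTable gets a table of tensors.
   local output = nn.JoinTable(1, 1)(nngraph.nest(linInputs))
   return nn.gModule(inputs, {output})
end

-- Hypothetical helper: builds the model and runs one forward pass.
local function run(inputsNo)
   local model = buildModel(inputsNo)
   local X = {}
   for i = 1, inputsNo do X[i] = torch.rand(i) end
   -- Assumption: a single-input gModule takes the tensor itself,
   -- a multi-input gModule takes a table of tensors.
   local input = (inputsNo == 1) and X[1] or X
   return model:forward(input)
end

local y1 = run(1)  -- concatenation of one vector of size 1
local y3 = run(3)  -- concatenation of vectors of sizes 1, 2 and 3
print(y1:nElement(), y3:nElement())
```

The same graph-building code is used for every inputsNo. Only the table wrapping differs, and nngraph.nest handles it, so no special case is needed when the graph is built.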