"Expecting only one start" error when using JoinTable
shaayaansayed opened this issue
There seems to be a bug in the way I'm writing my LSTM architecture.
```lua
require 'nn'
require 'nngraph'

local LSTM = {}

function LSTM.create(input_size, input2_size, output_size, rnn_size)
  local inputs = {}
  local outputs = {}
  table.insert(inputs, nn.Identity()())  -- inputs[1]: input to layer 1
  table.insert(inputs, nn.Identity()())  -- inputs[2]: indices for the LookupTable
  for L = 1, 2 do
    table.insert(inputs, nn.Identity()())  -- prev_c for layer L
    table.insert(inputs, nn.Identity()())  -- prev_h for layer L
  end

  local x, x_p, x2
  for L = 1, 2 do
    local input_size_L
    if L == 1 then
      x_p = inputs[1]
      input_size_L = input_size
    else
      -- previous layer's hidden state, joined column-wise with the embedding
      x = outputs[(L-2)*2]
      x2 = nn.LookupTable(input2_size+1, rnn_size)(inputs[2])
      x_p = nn.JoinTable(2)({x, x2})
      input_size_L = rnn_size*2
    end
    local prev_c = inputs[L*2+1]
    local prev_h = inputs[L*2+2]

    -- all four gate pre-activations computed in one pass
    local i2h = nn.Linear(input_size_L, 4 * rnn_size)(x_p)
    local h2h = nn.Linear(rnn_size, 4 * rnn_size)(prev_h)
    local all_input_sums = nn.CAddTable()({i2h, h2h})
    local reshaped = nn.Reshape(4, rnn_size)(all_input_sums)
    local n1, n2, n3, n4 = nn.SplitTable(2)(reshaped):split(4)
    local in_gate = nn.Sigmoid()(n1)
    local forget_gate = nn.Sigmoid()(n2)
    local out_gate = nn.Sigmoid()(n3)
    local in_transform = nn.Tanh()(n4)

    local next_c = nn.CAddTable()({
      nn.CMulTable()({forget_gate, prev_c}),
      nn.CMulTable()({in_gate, in_transform})
    })
    local next_h = nn.CMulTable()({out_gate, nn.Tanh()(next_c)})

    table.insert(outputs, next_c)
    table.insert(outputs, next_h)
  end

  local last_h = outputs[#outputs]
  local proj = nn.Linear(rnn_size, output_size)(last_h)
  local logsoft = nn.LogSoftMax()(proj)
  table.insert(outputs, logsoft)
  return nn.gModule(inputs, outputs)
end

return LSTM
```
The problem seems to occur when I use `nn.JoinTable`. What I'm trying to do is send `inputs[2]` through a LookupTable and concatenate the result (column-wise) with the hidden-state output of the previous layer.
The hidden-state output is `batchSize x rnnSize`, and the LookupTable output is also `batchSize x rnnSize`, so concatenating them along dimension 2 should give a `batchSize x 2*rnnSize` tensor.
Regardless, I'm getting an "expecting only one start" error, which probably means I'm doing something fundamentally wrong.
Thanks!
Hi shaayaansayed, what was the problem here? I'm getting the same error and can't work out from the source code what throws it.
I don't remember exactly how I resolved the issue, but looking at my old code: make sure your inputs to JoinTable are valid (in particular, not `nil`), have matching dimensions, and so on.
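One way to act on this advice is to check the table before handing it to `nn.JoinTable`. The sketch below is plain Lua (5.1 or later, including LuaJIT), with strings standing in for graph nodes. The helper `assertNoNil` is hypothetical and is not part of `nn` or `nngraph`. It takes the expected count explicitly because Lua's `#` operator is not reliable on a table that has a `nil` hole.

```lua
-- Hypothetical helper (not part of nn or nngraph): verify that the first
-- `expected` entries of `inputs` are non-nil before building a node from them.
-- The count is passed explicitly because `#` is unreliable on tables with holes.
local function assertNoNil(expected, inputs)
  for i = 1, expected do
    if inputs[i] == nil then
      error(string.format("input %d of %d is nil", i, expected), 2)
    end
  end
  return inputs
end

-- A table with a missing first entry is rejected with the offending index.
local ok, err = pcall(assertNoNil, 2, {nil, "x2"})
print(ok, err)  -- false, then a message containing "input 1 of 2 is nil"

-- A complete table is passed through unchanged.
print(#assertNoNil(2, {"x", "x2"}))  -- 2
```

In the question's code this would be used as `nn.JoinTable(2)(assertNoNil(2, {x, x2}))`, so a missing input fails at graph-construction time with a message that names which entry is `nil`.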
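One of my inputs to JoinTable in the code was `x = outputs[(L-2)*2]`, which at `L=2` is `outputs[0]`. Lua tables are 1-indexed (the first `table.insert` puts an element at index 1), so `outputs[0]` is `nil`. That is an invalid input, and I think that was the error.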
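The indexing can be checked in plain Lua, without Torch. Each layer appends `next_c` and then `next_h` to `outputs`, so layer `k`'s hidden state sits at `outputs[2*k]`. When building layer `L`, the previous layer's hidden state is therefore at `outputs[(L-1)*2]`. That index is a suggested correction, not something stated in the thread, and the sketch uses strings in place of graph nodes.

```lua
-- Mimic the order in which the question's loop fills `outputs`:
-- each layer appends next_c, then next_h.
local outputs = {}
local buggy, fixed
for L = 1, 2 do
  if L > 1 then
    -- Read the previous layer's output before this layer appends its own.
    buggy = outputs[(L - 2) * 2]  -- L = 2 gives outputs[0], which is nil
    fixed = outputs[(L - 1) * 2]  -- L = 2 gives outputs[2], layer 1's next_h
  end
  table.insert(outputs, "c" .. L)  -- stands in for next_c of layer L
  table.insert(outputs, "h" .. L)  -- stands in for next_h of layer L
end

print(tostring(buggy), fixed)  -- prints nil and h1
```

With `x = outputs[(L-1)*2]`, `x` should be the `batchSize x rnnSize` hidden state of layer 1. Joined with the LookupTable output, it gives the `batchSize x 2*rnnSize` input that `input_size_L = rnn_size*2` expects.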