torch / nngraph

Graph Computation for nn

"Expecting only one start" error when using JoinTable

shaayaansayed opened this issue

There seems to be a bug in the way I'm writing my LSTM architecture.

local LSTM = {}
function LSTM.create(input_size, input2_size, output_size, rnn_size)
    local inputs = {}
    local outputs = {}

    table.insert(inputs, nn.Identity()())
    table.insert(inputs, nn.Identity()())
    for L=1,2 do
        table.insert(inputs, nn.Identity()())
        table.insert(inputs, nn.Identity()())
    end

    local x, x_p, x2
    for L=1,2 do
        if L == 1 then
            x_p = inputs[1]
            input_size_L = input_size
        else 
            x = outputs[(L-2)*2]
            x2 = nn.LookupTable(input2_size+1, rnn_size)(inputs[2])
            x_p = nn.JoinTable(2)({x, x2})
            input_size_L = rnn_size*2
        end

        prev_c = inputs[L*2+1]
        prev_h = inputs[L*2+2]

        local i2h = nn.Linear(input_size_L, 4 * rnn_size)(x_p)
        local h2h = nn.Linear(rnn_size, 4 * rnn_size)(prev_h)
        local all_input_sums = nn.CAddTable()({i2h, h2h})

        local reshaped = nn.Reshape(4, rnn_size)(all_input_sums)
        local n1, n2, n3, n4 = nn.SplitTable(2)(reshaped):split(4)

        local in_gate = nn.Sigmoid()(n1)
        local forget_gate = nn.Sigmoid()(n2)
        local out_gate = nn.Sigmoid()(n3)

        local in_transform = nn.Tanh()(n4)

        local next_c           = nn.CAddTable()({
            nn.CMulTable()({forget_gate, prev_c}),
            nn.CMulTable()({in_gate,     in_transform})
          })

        local next_h = nn.CMulTable()({out_gate, nn.Tanh()(next_c)})

        table.insert(outputs, next_c)
        table.insert(outputs, next_h)
    end

    local last_h = outputs[#outputs]
    local proj = nn.Linear(rnn_size, output_size)(last_h)
    local logsoft = nn.LogSoftMax()(proj)
    table.insert(outputs, logsoft)

    return nn.gModule(inputs, outputs)
end

return LSTM

The problem seems to occur when I use nn.JoinTable. What I'm trying to do here is send inputs[2] through an nn.LookupTable and concatenate the result, column-wise, with the hidden state output of the previous layer.

The hidden state output will be batchSize x rnnSize, and the LookupTable output will also be batchSize x rnnSize. If I concatenate these tensors along the second dimension, the new tensor will be batchSize x 2*rnnSize.
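
The shape arithmetic above can be sketched in plain Lua, with nested tables standing in for tensors so that it runs without Torch. The helpers `join_columns` and `filled` are hypothetical names made up for this illustration; they are not part of nn or nngraph.

```lua
-- Build a rows x cols nested table filled with a constant (stand-in for a tensor).
local function filled(rows, cols, value)
    local t = {}
    for i = 1, rows do
        t[i] = {}
        for j = 1, cols do
            t[i][j] = value
        end
    end
    return t
end

-- Concatenate two nested tables column-wise (along the second dimension).
local function join_columns(a, b)
    assert(#a == #b, "batch sizes must match")
    local out = {}
    for i = 1, #a do
        local row = {}
        for _, v in ipairs(a[i]) do row[#row + 1] = v end
        for _, v in ipairs(b[i]) do row[#row + 1] = v end
        out[i] = row
    end
    return out
end

local batch_size, rnn_size = 3, 4
local h = filled(batch_size, rnn_size, 0)    -- hidden state: 3 x 4
local emb = filled(batch_size, rnn_size, 1)  -- lookup output: 3 x 4
local joined = join_columns(h, emb)          -- joined: 3 x 8

print(#joined, #joined[1])
```

In Torch itself, calling nn.JoinTable(2):forward({h, emb}) on two batchSize x rnnSize tensors should give the same batchSize x 2*rnnSize shape, so the dimensions in the question are consistent and the error has to come from somewhere else.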

Regardless, I'm getting an "expecting only one start" error, which probably means I'm doing something fundamentally wrong.

Thanks!

commented

Hi shaayaansayed, what was the problem here? I'm getting the same error and can't figure out from the source code what throws it.

I don't remember exactly how I resolved the issue, but looking at my old code, make sure that every input to JoinTable is a valid node and that the inputs have matching dimensions.

One of my inputs to JoinTable in the code was:

x = outputs[(L-2)*2], which at L=2 is outputs[0]. Lua tables are 1-indexed, so outputs[0] is nil, which is an invalid input, and I think that was the error.
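
A plain-Lua sketch of the indexing problem described above, with strings standing in for the nngraph nodes (the placeholder names are hypothetical). It assumes, as in the posted code, that each layer appends next_c and then next_h to outputs, so the hidden state of layer L-1 sits at index (L-1)*2.

```lua
-- Stand-ins for the two nodes inserted at the end of layer 1.
local outputs = {}
table.insert(outputs, "next_c_layer1")  -- stored at outputs[1]
table.insert(outputs, "next_h_layer1")  -- stored at outputs[2]

local L = 2
local bad_index = (L - 2) * 2   -- evaluates to 0; Lua tables start at 1
local good_index = (L - 1) * 2  -- evaluates to 2; hidden state of layer L-1

print(bad_index, outputs[bad_index])    -- index 0 holds nil
print(good_index, outputs[good_index])  -- index 2 holds the layer-1 hidden state

-- A guard like this fails early instead of passing nil on to JoinTable.
local x = outputs[good_index]
assert(x ~= nil, "input to JoinTable is nil; check the index into outputs")
```

With this indexing, the line in the else branch would read x = outputs[(L-1)*2]. Whether that change alone resolves the nngraph error in the original code was not tested here.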