Atcold / pytorch-CortexNet

PyTorch implementation of the CortexNet predictive model

Home Page: http://tinyurl.com/CortexNet/



Does the prednet accept batches?

Sahaj09 opened this issue.

Does the prednet accept batches during training and testing?

The input is given as:

```python
input_sequence = Variable(torch.rand(T, 1, 1, 4 * 2 ** L, 6 * 2 ** L))
```

I assumed the input format is (time steps, batch size, channels, height, width). Am I wrong?
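To make the assumed layout concrete, here is a minimal sketch, assuming (as the question does) that dimension 0 is time and dimension 1 is batch. The sizes `T`, `B`, `C` and `L` are hypothetical example values, not taken from the repo; `L` only drives the `2 ** L` spatial scaling from the question's formula. Plain tensors are used because `Variable` is no longer needed in recent PyTorch.

```python
import torch

# Hypothetical example sizes, chosen only for illustration
T, B, C, L = 5, 4, 1, 2          # time steps, batch size, channels, scale exponent
H, W = 4 * 2 ** L, 6 * 2 ** L    # 16 x 24, following the question's formula

# Assumed layout: (time, batch, channels, height, width)
input_sequence = torch.rand(T, B, C, H, W)

# Indexing dimension 0 selects one time step for every sequence in the batch
frame_t = input_sequence[0]
print(frame_t.shape)  # torch.Size([4, 1, 16, 24])
```

Under this layout, slicing along time leaves a regular image batch of shape (B, C, H, W).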

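Both MatchNet and TempoNet expect one element at a time, i.e. a single time step x[t]; the sequence is fed to the model through a loop over t, as in the training code: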

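```python
from torch.autograd import Variable as V  # assumed alias for Variable used by this snippet

def compute_loss(x_, next_x, y_, state_):
    # Feed one time step plus the recurrent state to the model
    (x_hat, state_), (_, idx) = model(V(x_), state_)
    ...
    return ce_loss_, mse_loss_, state_, x_hat.data

# Step along the time dimension (dim 0), passing frame t, frame t + 1 and label t
for t in range(0, min(args.big_t, x.size(0)) - 1):
    ce_loss, mse_loss, state, x_hat_data = compute_loss(x[t], x[t + 1], y[t], state)
```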
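The same pattern can be sketched with a toy model: the Python loop handles time, while the batch dimension travels through every call. `ToyRecurrentPredictor` is a hypothetical stand-in, not MatchNet or TempoNet, and the sizes are arbitrary; the sketch assumes the real models accept a batched x[t] of shape (B, C, H, W) in the same way.

```python
import torch
import torch.nn as nn


class ToyRecurrentPredictor(nn.Module):
    """Hypothetical recurrent frame predictor (NOT the CortexNet model).

    Takes one time step of shape (B, C, H, W) plus a state and returns a
    prediction of the next frame together with the updated state.
    """

    def __init__(self, channels: int):
        super().__init__()
        # Mix the current frame with the previous state; padding keeps H and W
        self.conv = nn.Conv2d(2 * channels, channels, kernel_size=3, padding=1)

    def forward(self, x, state):
        if state is None:
            state = torch.zeros_like(x)  # empty state on the first step
        state = torch.tanh(self.conv(torch.cat([x, state], dim=1)))
        return state, state  # (next-frame prediction, new state)


T, B, C, H, W = 5, 4, 1, 16, 24      # hypothetical example sizes
x = torch.rand(T, B, C, H, W)        # (time, batch, channels, height, width)
model = ToyRecurrentPredictor(C)
mse = nn.MSELoss()

state = None
losses = []
for t in range(T - 1):
    # x[t] holds frame t for every sequence in the batch: shape (B, C, H, W)
    x_hat, state = model(x[t], state)
    losses.append(mse(x_hat, x[t + 1]).item())

print(len(losses), x_hat.shape)  # 4 torch.Size([4, 1, 16, 24])
```

Keeping time in an explicit loop lets the recurrent state be carried from one step to the next, while batching costs nothing extra because every operation inside `forward` already works on the leading batch dimension.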