wojzaremba / lstm

replicate(x_inp, batch_size)

skaae opened this issue

I'm having trouble understanding replicate.

-- Stacks replicated, shifted versions of x_inp
-- into a single matrix of size x_inp:size(1) x batch_size.
local function replicate(x_inp, batch_size)
   local s = x_inp:size(1)
   local x = torch.zeros(torch.floor(s / batch_size), batch_size)
   for i = 1, batch_size do
     local start = torch.round((i - 1) * s / batch_size) + 1
     local finish = start + x:size(1) - 1
     x:sub(1, x:size(1), i, i):copy(x_inp:sub(start, finish))
   end
   return x
end

From the comment I expected the output to be a matrix of size number_of_words by batch_size.

But the output is torch.floor(s / batch_size) by batch_size.

If I load the first 5 lines from ptb.train.txt, x is:

x:view(14, 8) -- change view for printing
  1   2   3   4   5   6   7   8
  9  10  11  12  13  14  15  16
 17  18  19  20  21  22  23  24
 25  26  27  28  29  30  31  32
 33  34  35  36  37  38  39  28
 25  40  27  41  42  43  27  44
 33  45  46  47  25  48  27  28
 29  30  49  50  42  43  51  52
 53  54  55  56  36  37  38  43
 57  58  59  60  25  36  61  43
 62  63  64  65  66  67  68  69
 70  71  36  72  73  43  74  75
 76  36  47  43  77  78  65  79
 80  81  28  29  82  83  84  25
[torch.DoubleTensor of size 14x8]

And the output from replicate is:

replicate(x, 20)
  1   7  12  18  23  29  35  28  43  46  29  51  56  58  61  66  36  75  78  28
  2   8  13  19  24  30  36  25  27  47  30  52  36  59  43  67  72  76  65  29
  3   9  14  20  25  31  37  40  44  25  49  53  37  60  62  68  73  36  79  82
  4  10  15  21  26  32  38  27  33  48  50  54  38  25  63  69  43  47  80  83
  5  11  16  22  27  33  39  41  45  27  42  55  43  36  64  70  74  43  81  84
[torch.DoubleTensor of size 5x20]

Why is every second column shifted by one? E.g. 5→7, 11→12, 16→18, 22→23, etc.
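For what it's worth, the jump seems to come from the torch.round in the start computation. The 14x8 view above implies s = 112, so the stride s / batch_size = 5.6 is not an integer, and the rounded start offsets sometimes advance by 6 instead of 5, skipping a word between columns. A quick sketch of the column start indices (assuming s = 112 and batch_size = 20, as above):

local s, batch_size = 112, 20
for i = 1, batch_size do
   -- same start computation as in replicate
   local start = torch.round((i - 1) * s / batch_size) + 1
   io.write(start, ' ')
end
print()  -- 1 7 12 18 23 29 35 40 ... : a gap of 6 skips a word, since each column is only 5 long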

yup, it's off. Thx.

local function replicate(data, batchSize)
   -- Each column gets the i-th contiguous chunk of the data,
   -- so no words are skipped between columns.
   local nBatches = torch.floor(data:size(1) / batchSize)
   local x = torch.zeros(nBatches, batchSize)
   for i = 1, batchSize do
      local start = (i - 1) * nBatches + 1
      local finish = i * nBatches
      x:sub(1, nBatches, i, i):copy(data:sub(start, finish))
   end
   return x
end
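A quick sanity check of the corrected version (assuming the same 112-word input as above, so nBatches = 5 and the trailing 112 - 5 * 20 = 12 words are dropped):

local data = torch.range(1, 112)  -- stand-in for the 112 word ids above
local out = replicate(data, 20)   -- 5x20; column i holds words (i-1)*5+1 .. i*5
print(out:sub(1, 5, 1, 2))        -- column 1 is 1..5, column 2 is 6..10: contiguous, no skips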