salesforce / matchbox

Write PyTorch code at the level of individual examples, then run it efficiently on minibatches.

inner() missing 1 required positional argument: 'batch2'

vlasenkov opened this issue · comments

Hi! I'm trying to run a bidirectional GRU with masking, but it fails with the error below. Is this a bug?

The full script:

import torch
from torch import nn
import matchbox
from matchbox import MaskedBatch
import numpy as np

class BiGRU(nn.Module):
    
    def __init__(self, in_size, out_size):
        super().__init__()
        self.fcell = nn.GRUCell(in_size, out_size)
        self.rcell = nn.GRUCell(in_size, out_size)

    @matchbox.batch
    def forward(self, x, h0=None):
        # Forward direction: step through the sequence left to right.
        hf = x.batch_zeros(self.fcell.hidden_size) if h0 is None else h0
        for xt in x.unbind(1):
            hf = self.fcell(xt, hf)
        # Reverse direction: step through the sequence right to left.
        hr = x.batch_zeros(self.rcell.hidden_size) if h0 is None else h0
        for xt in reversed(x.unbind(1)):
            hr = self.rcell(xt, hr)
        return hf, hr

model = BiGRU(4, 5)

# A batch of 2 sequences with feature size 4; max length is 3, and the
# second sequence is masked down to length 2.
x = MaskedBatch(
    data=torch.rand((2, 3, 4)),
    dims=(True, False),
    mask=torch.tensor([
        [1, 1, 1],
        [1, 1, 0],
    ], dtype=torch.float32)[:, :, np.newaxis]
)

model(x)

Traceback:

Traceback (most recent call last):
  File "test.py", line 49, in <module>
    model(x)
  File ".../dl/lib/python3.6/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/tmp/tmpnu655rxr/matchbox_618c.py", line 10, in forward
    matchbox.MaskedBatch, matchbox.TENSOR_TYPE)) else self.fcell(xt, hf
  File ".../dl/lib/python3.6/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File ".../dl/lib/python3.6/site-packages/torch/nn/modules/rnn.py", line 763, in forward
    self.bias_ih, self.bias_hh,
  File ".../dl/lib/python3.6/site-packages/torch/nn/_functions/rnn.py", line 64, in GRUCell
    hy = newgate + inputgate * (hidden - newgate)
  File ".../matchbox/matchbox/functional/elementwise.py", line85, in inner
    return replacement(self, other)
  File ".../matchbox/matchbox/functional/elementwise.py", line90, in <lambda>
    TENSOR_TYPE.__sub__ = _inject_arith(TENSOR_TYPE.__sub__, lambda a, b: -b + a)
TypeError: inner() missing 1 required positional argument: 'batch2'

Cloned the repo from master. Commit hash: d8ec789.

Shall we treat neg as an _elementwise_unary operator?
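
If so, here is a minimal sketch of what a masked neg could look like. This is illustrative only, not matchbox's actual _elementwise_unary machinery: it assumes MaskedBatch exposes data, mask and dims attributes matching the constructor keywords used in the script above, and it patches __neg__ directly rather than going through the library's own registration helpers.

from matchbox import MaskedBatch

def _masked_neg(batch):
    # Elementwise negation that leaves the mask and dims untouched.
    if isinstance(batch, MaskedBatch):
        # Masked-out positions hold zeros and -0 == 0, so only the
        # payload tensor needs negating.
        return MaskedBatch(data=-batch.data, mask=batch.mask, dims=batch.dims)
    return -batch

# Hypothetical hook point; a real fix would presumably register this
# through matchbox's own _elementwise_unary wrapper in elementwise.py.
MaskedBatch.__neg__ = _masked_neg

With something along those lines, the -b produced by the injected lambda a, b: -b + a in _inject_arith would have a masked implementation to dispatch to, which looks like where the traceback above bottoms out.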