inner() missing 1 required positional argument: 'batch2'
vlasenkov opened this issue · comments
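Hi! I'm trying to run a bidirectional GRU over a masked batch (`MaskedBatch`). It fails with `TypeError: inner() missing 1 required positional argument: 'batch2'` (full traceback below). Is this a bug?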
The full script:
import torch
from torch import nn
import matchbox
from matchbox import MaskedBatch
import numpy as np
class BiGRU(nn.Module):
    """Bidirectional GRU made of two independent GRUCells.

    `fcell` is stepped over the sequence from first to last step and `rcell`
    from last to first; `forward` returns the final hidden state of each
    direction.

    Args:
        in_size: input feature size passed to both GRUCells.
        out_size: hidden size of both GRUCells.
    """

    def __init__(self, in_size, out_size):
        super().__init__()
        # Forward-direction cell.
        self.fcell = nn.GRUCell(in_size, out_size)
        # Reverse-direction cell (a separate GRUCell, so separate weights).
        self.rcell = nn.GRUCell(in_size, out_size)

    # matchbox.batch appears to rewrite this method's source (the traceback
    # shows it executing from a generated /tmp/.../matchbox_*.py file), so the
    # exact statement shapes below matter; keep them as written.
    @matchbox.batch
    def forward(self, x, h0=None):
        # Returns (hf, hr): final hidden states of the forward and reverse passes.
        # NOTE(review): assumes dim 1 of x is the time axis -- confirm.
        # When h0 is given, the same tensor seeds both directions.
        hf = x.batch_zeros(self.fcell.hidden_size) if h0 is None else h0
        # Per the traceback, the TypeError is raised inside this fcell call,
        # where GRUCell evaluates `hidden - newgate`.
        for xt in x.unbind(1):
            hf = self.fcell(xt, hf)
        hr = x.batch_zeros(self.rcell.hidden_size) if h0 is None else h0
        # Reverse pass: same time slices, visited last to first.
        for xt in reversed(x.unbind(1)):
            hr = self.rcell(xt, hr)
        return hf, hr
# Reproduction: run BiGRU(4 -> 5) on a random (2, 3, 4) masked batch.
model = BiGRU(4, 5)
features = torch.rand((2, 3, 4))
# Per-step mask with a trailing singleton axis -> shape (2, 3, 1);
# the second row's last step is masked out (0).
step_mask = torch.tensor(
    [
        [1, 1, 1],
        [1, 1, 0],
    ],
    dtype=torch.float32,
).unsqueeze(2)
x = MaskedBatch(data=features, dims=(True, False), mask=step_mask)
model(x)
Traceback:
Traceback (most recent call last):
File "test.py", line 49, in <module>
model(x)
File ".../dl/lib/python3.6/site-packages/torch/nn/modules/module.py", line 491, in __call__
result = self.forward(*input, **kwargs)
File "/tmp/tmpnu655rxr/matchbox_618c.py", line 10, in forward
matchbox.MaskedBatch, matchbox.TENSOR_TYPE)) else self.fcell(xt, hf
File ".../dl/lib/python3.6/site-packages/torch/nn/modules/module.py", line 491, in __call__
result = self.forward(*input, **kwargs)
File ".../dl/lib/python3.6/site-packages/torch/nn/modules/rnn.py", line 763, in forward
self.bias_ih, self.bias_hh,
File ".../dl/lib/python3.6/site-packages/torch/nn/_functions/rnn.py", line 64, in GRUCell
hy = newgate + inputgate * (hidden - newgate)
File ".../matchbox/matchbox/functional/elementwise.py", line 85, in inner
return replacement(self, other)
File ".../matchbox/matchbox/functional/elementwise.py", line 90, in <lambda>
TENSOR_TYPE.__sub__ = _inject_arith(TENSOR_TYPE.__sub__, lambda a, b: -b + a)
TypeError: inner() missing 1 required positional argument: 'batch2'
Cloned the repo from master. Commit hash: d8ec789.
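Should we treat `neg` as an `_elementwise_unary` operator? The last frame is matchbox's replacement for `TENSOR_TYPE.__sub__`, `lambda a, b: -b + a`. So the failing call is presumably `-b`. The tensor's `__neg__` appears to be wrapped as a binary elementwise operator that expects a second argument, `batch2`. Unary negation never supplies that argument, which would explain `inner() missing 1 required positional argument: 'batch2'`.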