dido1998 / Recurrent-Independent-Mechanisms

Implementation of the paper Recurrent Independent Mechanisms (https://arxiv.org/pdf/1909.10893.pdf)

GroupLinearLayer should add "device" parameter

ildefons opened this issue

Hi RIM dev team,

The code fails when device = 'cuda'.
This can be easily solved by adding an extra "device" parameter to all "Group" classes.

Thank you for the great RIM implementation,

Hi @ildefons, could you share the error message you get with device='cuda'?

Error message:

RuntimeError Traceback (most recent call last)
1 for x in xs:
2 print(1)
----> 3 hs, cs = rim_model(x, hs, cs)

~\anaconda3\envs\eg2\lib\site-packages\torch\nn\modules\module.py in __call__(self, *input, **kwargs)
548 result = self._slow_forward(*input, **kwargs)
549 else:
--> 550 result = self.forward(*input, **kwargs)
551 for hook in self._forward_hooks.values():
552 hook_result = hook(self, input, result)

~\OneDrive\Documentos\YK\eg\Recurrent-Independent-Mechanisms\RIM.py in forward(self, x, hs, cs)
250 # Compute input attention
--> 251 inputs, mask = self.input_attention_mask(x, hs)
252 h_old = hs * 1.0
253 if cs is not None:

~\OneDrive\Documentos\YK\eg\Recurrent-Independent-Mechanisms\RIM.py in input_attention_mask(self, x, h)
177 key_layer = self.key(x)
178 value_layer = self.value(x)
--> 179 query_layer = self.query(h)
181 key_layer = self.transpose_for_scores(key_layer, self.num_input_heads, self.input_key_size)

~\anaconda3\envs\eg2\lib\site-packages\torch\nn\modules\module.py in __call__(self, *input, **kwargs)
548 result = self._slow_forward(*input, **kwargs)
549 else:
--> 550 result = self.forward(*input, **kwargs)
551 for hook in self._forward_hooks.values():
552 hook_result = hook(self, input, result)

~\OneDrive\Documentos\YK\eg\Recurrent-Independent-Mechanisms\RIM.py in forward(self, x)
31 x = x.permute(1,0,2)
---> 33 x = torch.bmm(x,self.w)
34 return x.permute(1,0,2)

RuntimeError: Expected object of device type cuda but got device type cpu for argument #2 'mat2' in call to _th_bmm
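
The traceback shows that `self.w` in `GroupLinearLayer.forward` is still on the CPU while `x` is on the GPU, so `torch.bmm` refuses to multiply them. Below is a minimal sketch of one possible fix, not the repository's actual code. The `forward` body follows the lines visible in the traceback; the constructor arguments (`din`, `dout`, `num_blocks`), the weight shape, the initialization scale, and the optional `device` argument are assumptions made for illustration. Registering the weight as an `nn.Parameter` means `model.to(device)` moves it together with the rest of the model, and the `device` argument mirrors the parameter proposed above.

```python
import torch
import torch.nn as nn


class GroupLinearLayer(nn.Module):
    """Block-wise linear layer: one (din, dout) weight matrix per block."""

    def __init__(self, din, dout, num_blocks, device=None):
        super().__init__()
        # Assumption: the weight has shape (num_blocks, din, dout).
        # Wrapping it in nn.Parameter registers it with the module, so
        # .to(device) / .cuda() on the parent model also moves this tensor.
        # The optional `device` argument is hypothetical; it mirrors the
        # "device" parameter suggested in this issue.
        self.w = nn.Parameter(
            0.01 * torch.randn(num_blocks, din, dout, device=device)
        )

    def forward(self, x):
        # (batch, num_blocks, din) -> (num_blocks, batch, din)
        x = x.permute(1, 0, 2)
        # Batched matmul; both operands must be on the same device.
        x = torch.bmm(x, self.w)
        # (num_blocks, batch, dout) -> (batch, num_blocks, dout)
        return x.permute(1, 0, 2)


# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
layer = GroupLinearLayer(din=4, dout=3, num_blocks=2).to(device)
x = torch.randn(5, 2, 4, device=device)
out = layer(x)
print(out.shape)  # torch.Size([5, 2, 3])
```

With the weight registered this way, calling `.to(device)` once on the top-level model should be enough, and the explicit `device` argument becomes optional rather than required.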