rikdz / GraphWriter

Code for "Text Generation from Knowledge Graphs with Graph Transformers"

Multiple .cuda() calls might cause inconsistent device usage

rhythmswing opened this issue · comments

Hi,

I've noticed that in several .py files under /models/, such as last_graph.py and attention.py, some modules create a new tensor (mostly masks) and call .cuda() on it. Could this cause device inconsistency issues?

For example, I might want to specify a non-default GPU device, or even the CPU, in the input arguments.
At line 246 of attention.py, attention might be on cuda:1 while torch.sqrt(self._key_dim) is on cuda:0, which raises an error.

Would it be better to use .to(attention.get_device()) when attention.get_device() > -1 (i.e., when a GPU is actually in use)?
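
A minimal sketch of the device-agnostic pattern being suggested; `scaled_attention` and its arguments are hypothetical names for illustration, not the repository's actual function signature:

```python
import torch

def scaled_attention(attention: torch.Tensor, key_dim: int) -> torch.Tensor:
    # Instead of torch.sqrt(torch.tensor(key_dim)).cuda(), create the scale
    # on the same device as `attention`, so the code works on cpu, cuda:0,
    # cuda:1, etc. without hard-coding a device.
    scale = torch.sqrt(torch.tensor(float(key_dim), device=attention.device))
    return attention / scale
```

The same idea applies to the mask tensors: constructing them with `device=attention.device` (or calling `.to(attention.device)`) keeps every intermediate tensor on whatever device the inputs happen to live on.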

This is a good suggestion, thanks! I will update this soon.

Actually, I've already fixed it in my local code. Mind if I help?

Can you make a pull request? I will review it.

Let me try that, thanks.