A potential issue in the nn.MultiheadAttention module setup
frankaging opened this issue
Hi Authors!
Thanks for sharing this repo. I enjoyed reading your paper, and I am working on a related project. While going through the code, I found one potential issue with the current setup. I will (1) explain the issue and (2) provide a simple test case that I ran on my end. Please help verify it.
Issue:
- The `nn.MultiheadAttention` module inside the `BinaryEncoder` module is set with `batch_first=True`; however, it seems like we are passing in Q, K, V matrices whose first dimension is not the batch dimension.
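As a minimal standalone sketch (not the repo's code; the sizes are hypothetical, chosen to match the shapes logged later in this issue), this is how `nn.MultiheadAttention` with `batch_first=True` interprets a sequence-first tensor:

```python
import torch
import torch.nn as nn

embed_dim, num_heads = 768, 12  # hypothetical sizes for illustration
attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

# Layout produced by the transpose in r2d2.py: (4 tokens, batch_size=8, dim)
x = torch.randn(4, 8, embed_dim)

# With batch_first=True, the module reads x as (batch=4, seq_len=8, dim),
# so attention mixes the 8 batch items rather than the 4 tokens.
out, weights = attn(x, x, x)
print(out.shape)      # torch.Size([4, 8, 768])
print(weights.shape)  # torch.Size([4, 8, 8]) -> averaged weights over an 8-"token" sequence
```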
Code Analysis:
In `r2d2.py`, the encoder is called as follows:

```python
tasks_embedding = self.embedding(task_ids)  # (?, 2, dim)
input_embedding = torch.cat([tasks_embedding, tensor_batch], dim=1)  # (?, 4, dim)
outputs = self.tree_encoder(input_embedding.transpose(0, 1)).transpose(0, 1)  # (? * batch_size, 4, dim)
```
We can see that `input_embedding` clearly has `batch_size` as its first dimension, since it concatenates `tensor_batch` with the embeddings from the `nn.Embedding` module. Before `self.tree_encoder` is called, however, `.transpose(0, 1)` moves `batch_size` to the second dimension of the input; the first dimension is then always 4.
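As a standalone shape check (my own sketch with assumed sizes: 8 candidate pairs and hidden size 768; the tensors below are random placeholders standing in for the real embeddings):

```python
import torch

batch_size, dim = 8, 768  # assumed sizes for illustration
tasks_embedding = torch.randn(batch_size, 2, dim)  # stand-in for self.embedding(task_ids)
tensor_batch = torch.randn(batch_size, 2, dim)     # stand-in for the left/right child representations

input_embedding = torch.cat([tasks_embedding, tensor_batch], dim=1)
print(input_embedding.shape)                  # torch.Size([8, 4, 768]) -> batch first
print(input_embedding.transpose(0, 1).shape)  # torch.Size([4, 8, 768]) -> what the
                                              # batch_first=True encoder actually receives
```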
Testing Done:
I simply added some logging inside `TreeEncoderLayer`:
```python
def forward(self, src, src_mask=None, pos_ids=None):
    """
    :param src: concatenation of task embeddings and representation for left and right.
        src shape: (task_embeddings + left + right, batch_size, dim)
    :param src_mask:
    :param pos_ids:
    :return:
    """
    if len(pos_ids.shape) == 1:
        sz = src.shape[0]  # sz: batch_size
        pos_ids = pos_ids.unsqueeze(0).expand(sz, -1)  # (3, batch_size)
    position_embedding = self.position_embedding(pos_ids)
    print("pre: ", src.shape)
    print("pos_emb: ", position_embedding.shape)
    output = self.self_attn(src + position_embedding, src + position_embedding, src, attn_mask=src_mask)
    src2 = output[0]
    attn_weights = output[1]
    print("attn_w: ", attn_weights.shape)
    src = src + self.dropout1(src2)
    src = self.norm1(src)
    src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
    src = src + self.dropout2(src2)
    src = self.norm2(src)
    print("post: ", src.shape)
    return src
```
And this is what I get:

```
pre: torch.Size([4, 8, 768])
pos_emb: torch.Size([4, 8, 768])
attn_w: torch.Size([4, 8, 8])
post: torch.Size([4, 8, 768])
```
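For comparison, here is a hedged sketch (again with made-up sizes) of what the same `(4, 8, 768)` input yields when the attention module is sequence-first (`batch_first=False`): the averaged attention weights come out as `(batch_size, 4, 4)`, i.e. attention over the 4 tokens as presumably intended:

```python
import torch
import torch.nn as nn

attn = nn.MultiheadAttention(768, 12, batch_first=False)  # hypothetical num_heads
x = torch.randn(4, 8, 768)  # (seq_len=4, batch_size=8, dim)

out, weights = attn(x, x, x)
print(out.shape)      # torch.Size([4, 8, 768])
print(weights.shape)  # torch.Size([8, 4, 4]) -> attention over the 4 tokens, per batch element
```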
Summary:
It seems like for `r2d2.py`, the self-attention is not over those 4 tokens (2 special prefix tokens + the left and right child embeddings), but over the full collection of candidates and their children.
I see that this issue is not present in `r2d2_cuda.py`, which calls:

```python
outputs = self.tree_encoder(input_embedding)  # (? * batch_size, 4, dim)
```
This is great. I have not checked the rest of the code for `r2d2_cuda.py`, though. With this, I am wondering whether the numbers from either of your papers need to be updated because of this potential issue. Either way, I am not blocked by it, and I was inspired quite a lot by your codebase. Thanks!
Many thanks for your feedback. Actually, r2d2.py is used in the first paper, accepted at ACL 2021, and r2d2_cuda.py is used in the paper accepted at EMNLP. Since r2d2.py is not used in our latest work, we didn't run a regression test on it, which may have introduced some discrepancies. If you are interested in the first work, I've made a tag for it: https://github.com/alipay/StructuredLM_RTDT/tree/r2d2, which is the original code for the first paper. The current branch only supports the CUDA version (r2d2_cuda.py). Since r2d2.py is legacy code, we'll consider fixing the discrepancy or removing it directly. The numbers in the papers were produced with the correct version. If you have a CUDA environment, I suggest you use the latest version (Fast-R2D2), which is almost 30 times faster than R2D2, with better downstream task performance.
Thanks for your quick response! Closing the issue, as the problem is not present in the tagged r2d2 branch.
I've checked in a bug-fixed version of r2d2 in the latest branch. We will release a Fast-R2D2 model pretrained on wiki103 soon; hope that will be helpful to your work :)