ant-research / StructuredLM_RTDT

A library for building hierarchical text representation and corresponding downstream applications.

Repository from Github: https://github.com/ant-research/StructuredLM_RTDT

A potential issue found in the nn.MultiheadAttention module setup

frankaging opened this issue

Hi Authors!

Thanks for sharing this repo. I enjoyed reading your paper, and I am working on a related project. As I was going through the code, I found one potential issue with the current setup. I will (1) explain the issue, and (2) provide a simple test case that I ran on my end. Please help with verifying it.

Issue:

  • The nn.MultiheadAttention module inside the BinaryEncoder module is set up with batch_first=True; however, it seems we are passing in Q, K, V matrices whose first dimension is not the batch dimension (see the sketch right below).
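
For context, here is a minimal runnable sketch of how batch_first=True changes the expected layout (the 768-dim / 12-head configuration here is only an assumption for illustration, not taken from the repo):

    import torch
    import torch.nn as nn

    # With batch_first=True, nn.MultiheadAttention expects (batch, seq_len, embed_dim)
    # inputs; with the default batch_first=False it expects (seq_len, batch, embed_dim).
    attn = nn.MultiheadAttention(embed_dim=768, num_heads=12, batch_first=True)

    x = torch.randn(8, 4, 768)  # hypothetical batch of 8 candidates, 4 tokens each
    out, attn_w = attn(x, x, x)
    print(out.shape)     # torch.Size([8, 4, 768])
    print(attn_w.shape)  # torch.Size([8, 4, 4]) -- (batch, tgt_len, src_len)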

Code Analysis:
In r2d2.py, the encoder is called as follows:

        tasks_embedding = self.embedding(task_ids)  # (?, 2, dim)
        input_embedding = torch.cat([tasks_embedding, tensor_batch], dim=1)  # (?, 4, dim)
        outputs = self.tree_encoder(input_embedding.transpose(0, 1)).transpose(0, 1)  # (? * batch_size, 4, dim)

We can see that input_embedding definitely has batch_size as its first dimension, since it is the concatenation of the task embeddings from the nn.Embedding module with tensor_batch. Before self.tree_encoder is called, .transpose(0, 1) makes batch_size the second dimension of the input instead; the first dimension, in this case, is always 4.
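
To make this concrete, here is a small standalone repro (again assuming 768-dim embeddings and 12 heads) showing that a batch_first=True attention reads the transposed tensor's size-4 axis as the batch and the real batch axis as the sequence:

    import torch
    import torch.nn as nn

    attn = nn.MultiheadAttention(embed_dim=768, num_heads=12, batch_first=True)

    input_embedding = torch.randn(8, 4, 768)  # (batch_size, 2 task tokens + left + right, dim)
    src = input_embedding.transpose(0, 1)     # (4, 8, 768) -- what r2d2.py passes in

    # batch_first=True now treats dim 0 (size 4) as the batch and dim 1 (the real
    # batch of 8 candidates) as the sequence, so attention mixes across candidates.
    out, attn_w = attn(src, src, src)
    print(attn_w.shape)  # torch.Size([4, 8, 8]) -- matches the attn_w shape logged below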

Testing Done:
I simply added some logging inside TreeEncoderLayer:

    def forward(self, src, src_mask=None, pos_ids=None):
        """
        :param src: concatenation of task embeddings and representation for left and right.
                    src shape: (task_embeddings + left + right, batch_size, dim)
        :param src_mask:
        :param pos_ids:
        :return:
        """
        if len(pos_ids.shape) == 1:
            sz = src.shape[0]  # sz: batch_size
            pos_ids = pos_ids.unsqueeze(0).expand(sz, -1)  # (3, batch_size)
        position_embedding = self.position_embedding(pos_ids)
        print("pre: ", src.shape)
        print("pos_emb: ", position_embedding.shape)
        output = self.self_attn(src + position_embedding, src + position_embedding, src, attn_mask=src_mask)
        src2 = output[0]
        attn_weights = output[1]
        print("attn_w: ", attn_weights.shape)
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        print("post: ", src.shape)
        return src

And this is what I get:

pre:  torch.Size([4, 8, 768])
pos_emb:  torch.Size([4, 8, 768])
attn_w:  torch.Size([4, 8, 8])
post:  torch.Size([4, 8, 768])

Summary:
It seems that for r2d2.py, the self-attention is not over those 4 tokens (2 special prefix tokens plus the left and right children embeddings), but over the full collection of candidates with their children: the attention weights have shape (4, 8, 8), i.e., attention is computed across the 8 batch entries for each of the 4 token positions, rather than across the 4 tokens for each batch entry.

I saw that this issue is not present in r2d2_cuda.py, as

            outputs = self.tree_encoder(
                input_embedding)  # (? * batch_size, 4, dim)

This is great. I have not checked the rest of the code in r2d2_cuda.py, though. With this, I am wondering whether the numbers in either of your papers need to be updated because of this potential issue? Either way, I am not blocked by it, and I was inspired quite a lot by your codebase. Thanks!

Many thanks for your feedback. Actually, r2d2.py is used in the first paper, accepted by ACL 2021, and r2d2_cuda.py is used in the paper accepted by EMNLP. As r2d2.py is not used in our latest work, we didn't run a regression test on it, which may have caused some discrepancies. If you are interested in the first work, I've made a tag for it: https://github.com/alipay/StructuredLM_RTDT/tree/r2d2, which is the original code for the first paper. The current branch actually only supports the CUDA version (r2d2_cuda.py). Since r2d2.py is legacy code, we'll consider fixing the discrepancy or removing it directly. But the numbers in the papers were produced with the correct version. If you have a CUDA environment, I suggest you use the latest version (Fast-R2D2), which is almost 30 times faster than R2D2, with better downstream task performance.

Thanks for your quick response! Closing the issue, as this is not present in the tagged r2d2 version.

I've checked in a bug-fixed version of r2d2 on the latest branch. We will release a Fast-R2D2 model pretrained on wiki103 soon; hope that will be helpful to your work :)