ibm-aur-nlp / PubTabNet


How to trigger cell decoder based on output of structure decoder?

nishchay47b opened this issue

Hi,
Can you give some implementation reference for this line of the paper:
"When the structure decoder recognizes a new cell, the cell decoder is triggered and uses the hidden state of the structure decoder to compute the attention for recognizing the content of the new cell."
How is this type of training implemented? After the initial pre-training, do both decoders get trained simultaneously, or is control shifted from the structure decoder to the cell decoder and somehow returned once the cell text is recognized?
Thanks.

commented

Yes, the two decoders are jointly trained after the initial pre-training.

The implementation is to record the hidden states of the structure decoder when a <td> or > token is generated (these two tokens indicate the beginning of a new cell). Then the recorded hidden states are sent to the cell decoder to generate content.

So in one forward pass:
First, the structure decoder generates tokens, the hidden states for each <td> or > are stored, and the loss for the structure decoder is calculated.
Then all these hidden states, along with the encoder output, are passed to the cell decoder for attention calculation; the cell decoder then generates its outputs, and finally the loss for the cell decoder is calculated.
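To make that forward pass concrete, here is a minimal sketch in PyTorch for a single image. The `struct_decoder.step`, `init_hidden_state`, `fc`, and `cell_decoder` interfaces are assumptions for illustration, not the authors' code:

```python
import torch

# Hypothetical sketch of one joint forward pass over a single image.
# struct_tokens: (T,) ground-truth structure token ids (teacher forcing)
# cell_tokens:   list of (L_i,) ground-truth content token ids, one per cell
# new_cell_ids:  set of vocabulary ids for the <td> and > tokens
def joint_forward(encoder_out, struct_tokens, cell_tokens, new_cell_ids,
                  struct_decoder, cell_decoder, criterion):
    h, c = struct_decoder.init_hidden_state(encoder_out)
    struct_logits, trigger_states = [], []

    # 1) structure decoder: generate tokens, record h whenever the input is <td> or >
    for t in range(struct_tokens.size(0) - 1):
        h, c = struct_decoder.step(struct_tokens[t], encoder_out, h, c)
        struct_logits.append(struct_decoder.fc(h))
        if struct_tokens[t].item() in new_cell_ids:
            trigger_states.append(h)
    struct_loss = criterion(torch.stack(struct_logits), struct_tokens[1:])

    # 2) cell decoder: each recorded hidden state (one per cell) attends over
    #    the encoder output and decodes the content of that cell
    cell_loss = 0.0
    for h_cell, content in zip(trigger_states, cell_tokens):
        logits = cell_decoder(encoder_out, h_cell, content[:-1])  # (L_i-1, vocab)
        cell_loss = cell_loss + criterion(logits, content[1:])

    return struct_loss + cell_loss
```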

commented

So in one forward pass:
First, the structure decoder generates tokens, the hidden states for each <td> or > are stored, and the loss for the structure decoder is calculated.
Then all these hidden states, along with the encoder output, are passed to the cell decoder for attention calculation; the cell decoder then generates its outputs, and finally the loss for the cell decoder is calculated.

Correct!

Thanks for your quick response. I understand you guys are working on legalities to open-source the model and source code, but can you tell us which framework you used? Also, can you release the implementation of your WYGIWYS experiment, especially the weights, and did you train WYGIWYS with the combined HTML (structure and cell tokens)?

commented

Thanks for your quick response. I understand you guys are working on legalities to open-source the model and source code, but can you tell us which framework you used? Also, can you release the implementation of your WYGIWYS experiment, especially the weights, and did you train WYGIWYS with the combined HTML (structure and cell tokens)?

I used PyTorch, and the implementation is based on this nice tutorial on image captioning with an attention-based encoder-decoder. I implemented WYGIWYS following this tutorial too, adding an RNN to the encoder and using TBPTT to train the decoder. And yes, when training WYGIWYS, the output is the combined HTML.
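For reference, a WYGIWYS-style encoder adds an RNN on top of the CNN feature map, run over each row of the grid. The sketch below illustrates that idea only; the backbone choice, dimensions, and names are assumptions, not the released code:

```python
import torch.nn as nn
import torchvision

# Hedged sketch of a WYGIWYS-style encoder: CNN backbone + row-wise RNN.
class CNNRowEncoder(nn.Module):
    def __init__(self, feat_dim=512, hidden_dim=256):
        super().__init__()
        resnet = torchvision.models.resnet18(weights=None)        # trained from scratch
        self.cnn = nn.Sequential(*list(resnet.children())[:-2])   # keep the spatial map
        self.row_rnn = nn.GRU(feat_dim, hidden_dim,
                              batch_first=True, bidirectional=True)

    def forward(self, images):                        # (B, 3, H, W)
        fmap = self.cnn(images)                       # (B, feat_dim, H', W')
        B, C, Hp, Wp = fmap.shape
        rows = fmap.permute(0, 2, 3, 1).reshape(B * Hp, Wp, C)  # one sequence per row
        enc, _ = self.row_rnn(rows)                   # (B*H', W', 2*hidden_dim)
        return enc.reshape(B, Hp, Wp, -1)             # row-encoded feature grid
```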

Thanks a lot!
It would be really helpful if you could give any reference link around the idea of how, based on the output of the decoder, you select only certain hidden states.
Based on this thread and the paper, what I have understood is that this model has two decoders (d1 and d2) and one encoder. One decoder (d1) gets input only from the encoder, while the other (d2) gets input from the encoder and from the other decoder (d1). d2 must get hidden states from d1 only when d1 makes a particular type of prediction. Both decoders have different vocabularies; say d1 has "a, b, c, d" and d2 has "P, Q, R, S", and we want to pass a hidden state from d1 to d2 only when d1 predicts "b".
I am getting confused because when generating captions we apply something like beam search or a greedy method to choose the token, so how are we selecting only specific tokens in the training step? Another problem is that the hidden state will be (batch_size, decoder_dim), so are you selecting, say, the ith index of the batch which generated that particular token? In DecoderWithAttention:

```python
# inside the per-time-step loop of the tutorial's DecoderWithAttention.forward
batch_size_t = sum([l > t for l in decode_lengths])
attention_weighted_encoding, alpha = self.attention(encoder_out[:batch_size_t],
                                                    h[:batch_size_t])
gate = self.sigmoid(self.f_beta(h[:batch_size_t]))  # gating scalar, (batch_size_t, encoder_dim)
attention_weighted_encoding = gate * attention_weighted_encoding
h, c = self.decode_step(
    torch.cat([embeddings[:batch_size_t, t, :], attention_weighted_encoding], dim=1),
    (h[:batch_size_t], c[:batch_size_t]))  # (batch_size_t, decoder_dim)
preds = self.fc(self.dropout(h))  # (batch_size_t, vocab_size)
predictions[:batch_size_t, t, :] = preds
alphas[:batch_size_t, t, :] = alpha
```

Are you selecting the hidden state based on the highest score in "preds"? It has dimension (batch_size, vocab_size), so if the vocab index corresponding to <td> or > has the highest score, do you then choose that element of the batch and its corresponding hidden state?

Thanks a ton again; any reference on the connection between the two decoders would be really helpful.

commented

At training time, we use teacher forcing. So we use the ground truth token at t_k to predict t_k+1. For the structure decoder, when the input ground truth token at t_k is <td> or >, the hidden state h_k+1 is saved. The cell decoder takes the output of the encoder and the saved hidden states of the structure decoder to predict the content of each cell.

At inference time, I use beam search to select the tokens. Say the structure decoder predicts <td> or > at t_k, which becomes the input for t_k+1. In this case, the hidden state h_k+1 is saved and sent to the cell decoder.
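In other words, at training time no argmax over "preds" is needed: the ground-truth input tokens decide which hidden states are kept. A minimal sketch of that selection, assuming the hidden states of every step have been stacked into one tensor and using hypothetical token ids:

```python
import torch

# hidden_states: (batch_size, T, decoder_dim), where hidden_states[:, t] is h_{t+1}
# inputs:        (batch_size, T) ground-truth tokens fed at each step t (teacher forcing)
TD_ID, GT_ID = 4, 7  # hypothetical vocabulary ids of <td> and >

def gather_cell_states(hidden_states, inputs, image_idx):
    """Return the structure-decoder states that trigger the cell decoder for one image."""
    mask = (inputs[image_idx] == TD_ID) | (inputs[image_idx] == GT_ID)
    return hidden_states[image_idx][mask]  # (num_cells_in_image, decoder_dim)
```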

Thanks, this made a lot of things clear. While implementing this idea, I am facing another problem: in every batch, each sequence of structure tokens will have a different number of <td> or > tokens, because every image has a different number of cells. So after storing the hidden states, does the cell decoder train per image? I mean, is there a notion of a minimum number of cells that essentially forms a batch of cells inside a batch of images, with the cell decoder training on this batch of cells one image at a time?
Really sorry if the language is confusing.

commented

Yes, the cell decoder is trained per image. For a batch of images, I run the forward pass of the cell decoder on each image, then sum their losses. For each image, I consider all the cells in the image as a "batch".
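A rough sketch of that loss accumulation, with assumed shapes and a hypothetical `cell_decoder` call (padding handling omitted):

```python
import torch

# encoder_outs:            (B, num_pixels, encoder_dim)
# cell_states_per_image:   list of B tensors, each (num_cells_i, decoder_dim)
# cell_targets_per_image:  list of B tensors, each (num_cells_i, max_len), padded
def cell_decoder_loss(cell_decoder, criterion, encoder_outs,
                      cell_states_per_image, cell_targets_per_image):
    total = 0.0
    for enc, states, targets in zip(encoder_outs, cell_states_per_image,
                                    cell_targets_per_image):
        # all cells of this image form one "batch" for the cell decoder
        logits = cell_decoder(enc, states, targets[:, :-1])  # (num_cells_i, max_len-1, vocab)
        total = total + criterion(logits.flatten(0, 1), targets[:, 1:].flatten())
    return total
```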

I think I have managed to create the architecture, but I am still having problems in training due to limited hardware. Given that the cell decoder is trained per image, in order to apply attention in the cell decoder, are the image features replicated as many times as there are cells?
For example, if you have 50 cells in an image, your batch for the cell decoder becomes 50 and you have 50 hidden states from the structure decoder, but this is one image, so to apply attention using these hidden states a tensor of shape (50, 448, 448, 3) is created by replicating the image 50 times, which is then passed through the encoder.
Not sure if I am doing this right; any comment on this? This part is taking a lot of memory.

commented

Glad you are close. I did not replicate the image or the encoder output. In the attention of the cell decoder, the encoder output and the hidden states are first run through a linear layer to convert them to the same dimension. Then they can be added by broadcasting.
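A minimal sketch of that broadcasting trick (layer names and dimensions are assumptions): the encoder output is projected once and added to every cell's projected hidden state, so the image features are never replicated.

```python
import torch
import torch.nn as nn

class CellAttention(nn.Module):
    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super().__init__()
        self.enc_proj = nn.Linear(encoder_dim, attention_dim)   # project image features
        self.dec_proj = nn.Linear(decoder_dim, attention_dim)   # project cell hidden states
        self.score = nn.Linear(attention_dim, 1)

    def forward(self, encoder_out, cell_hidden):
        # encoder_out: (num_pixels, encoder_dim), cell_hidden: (num_cells, decoder_dim)
        enc = self.enc_proj(encoder_out).unsqueeze(0)   # (1, num_pixels, attention_dim)
        dec = self.dec_proj(cell_hidden).unsqueeze(1)   # (num_cells, 1, attention_dim)
        # broadcasting gives (num_cells, num_pixels, attention_dim): no replication needed
        alpha = torch.softmax(self.score(torch.tanh(enc + dec)).squeeze(-1), dim=1)
        context = (alpha.unsqueeze(-1) * encoder_out.unsqueeze(0)).sum(dim=1)
        return context, alpha   # (num_cells, encoder_dim), (num_cells, num_pixels)
```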

During pre-training of the structure decoder and encoder, did you fine-tune from the fifth child onwards, as discussed here, or was it something different?
I am implementing this in Keras/TensorFlow and missed an important detail: in PyTorch there is the notion of a child, which contains whole blocks, versus layers in Keras, where each is an individual layer. The weights got all messed up and I am getting only </td> in the output.

commented

The encoder (ResNet) is trained from scratch, and all layers are trained. I did not load a ResNet pre-trained on ImageNet, as I think it may not be very helpful for table images. I had sufficient data to just train it from scratch.
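In PyTorch terms, that setup would look roughly like the following; ResNet-18 is used here only as an example backbone:

```python
import torch.nn as nn
import torchvision

# Randomly initialised backbone, no ImageNet weights, nothing frozen.
resnet = torchvision.models.resnet18(weights=None)
encoder = nn.Sequential(*list(resnet.children())[:-2])  # keep the spatial feature map
for p in encoder.parameters():
    p.requires_grad = True  # all layers are trained
```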

@nishchay47b Wanted to check if you were able to complete the implementation. If yes, how are the results? I have also started to implement the EDD architecture. I'm referring to the https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning repository.

Can you share how you implemented the dual decoder?

I have used the tutorial you mentioned and this example to implement EDD in TF/Keras. However, I am not able to fully train and evaluate due to hardware limitations, which is why I cannot confidently comment on whether my implementation will work or not. The structure part is partially working, while the cell part is yet to be trained. @zhxgj has graciously described the fine details here, and the implementation section of the paper is a great resource.

@nishchay47b Thank you for the details. I'm using an 11GB GPU; is that enough to train the model with a smaller batch size?