rikdz / GraphWriter

Code for "Text Generation from Knowledge Graphs with Graph Transformers"

A question about list_encode

hustcxx opened this issue

    def forward(self, batch, pad=True):
        batch, phlens, batch_lens = batch
        batch_lens = tuple(batch_lens.tolist())  # list-to-tuple
        _, enc = self.seqenc((batch, phlens))
        enc = enc[:, 2:]
        enc = torch.cat([enc[:, i] for i in range(enc.size(1))], 1)
        m = max(batch_lens)
        encs = [self.pad(x, m) for x in enc.split(batch_lens)]
        out = torch.stack(encs, 0)
        return out

I don't understand why the code uses enc = enc[:,2:] and enc = torch.cat([enc[:,i] for i in range(enc.size(1))],1). Can anybody explain these two lines to me? Thanks.

Sorry this code is so messy. I think the point of these two lines is to take the final hidden state from the top layer of the encoder BiLSTM and concatenate both directions into a single vector. What they actually seem to do is concatenate the final hidden states of every layer except the first. With layers == 2, as used in this codebase, the two behaviors are the same.
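Here is a small standalone check of that claim. It is only a sketch: it assumes enc is the LSTM's final hidden state h_n transposed to shape (batch, layers * 2, hidden), which is what the enc[:, 2:] indexing suggests but which this thread does not confirm. The helper names (issue_slicing, top_layer_state, final_states) and all sizes are made up for illustration and are not part of GraphWriter.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

def issue_slicing(enc):
    # The two lines from the question: drop the first two entries on dim 1
    # (both directions of layer 0), then concatenate whatever is left.
    enc = enc[:, 2:]
    return torch.cat([enc[:, i] for i in range(enc.size(1))], 1)

def top_layer_state(enc, num_layers):
    # The presumably intended result: forward and backward final states
    # of the top layer only, joined into one vector per sequence.
    n, _, hidden = enc.shape
    # PyTorch orders h_n layer by layer: (l0_fwd, l0_bwd, l1_fwd, l1_bwd, ...)
    per_layer = enc.reshape(n, num_layers, 2, hidden)
    return per_layer[:, -1].reshape(n, 2 * hidden)

def final_states(num_layers, batch=3, steps=5, dim=4, hidden=6):
    # Hypothetical encoder with arbitrary sizes, standing in for self.seqenc.
    lstm = nn.LSTM(dim, hidden, num_layers=num_layers,
                   bidirectional=True, batch_first=True)
    x = torch.randn(batch, steps, dim)
    _, (h_n, _) = lstm(x)          # h_n: (num_layers * 2, batch, hidden)
    return h_n.transpose(0, 1)     # assumed layout: (batch, num_layers * 2, hidden)

enc2 = final_states(num_layers=2)
enc3 = final_states(num_layers=3)

# With two layers, both approaches give the same (batch, 2 * hidden) tensor.
same_for_two_layers = torch.equal(issue_slicing(enc2), top_layer_state(enc2, 2))

# With three layers, the slicing keeps layers 1 and 2, so the output widens.
shape_for_three_layers = tuple(issue_slicing(enc3).shape)
print(same_for_two_layers, shape_for_three_layers)
```

With more than two layers the slicing version returns a wider vector than the top-layer version, so indexing the last two entries (enc[:, -2:]) would be the safer way to express the intent if the layer count ever changes.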

Thanks. I realized that a few days ago, after reading how the layers parameter is used.