How to extract attention weights for each sequence
Albert-Shuai opened this issue
Hi,
I am wondering how to extract the attention weights for each sequence in a sample, so that we can rank the sequences by their importance. Suppose I train a model named model based on the code in example_single_task_cnn.py. Thanks!
Hi! Sorry for the late response. You can find the attention weights in the tensor "attention_weights" here:
DeepRC/deeprc/architectures.py, line 417 (commit 108d08d)
You would have to pass those tensors into the evaluation function somehow.
A quick-and-dirty hack would be to simply store the attention weights as an attribute of the model, e.g. by modifying lines 412-434 like this:
```python
# Compute representation per bag (N times shape (d_v,))
mb_emb_seqs_after_attention = []
# Collects the attention values of each bag (list of length N)
mb_attention_weights_list = []
start_i = 0
for n_seqs in n_sequences_per_bag:
    # Get attention weights for single bag (shape: (n_sequences_per_bag, 1))
    attention_weights = mb_attention_weights[start_i:start_i+n_seqs]
    # Store a detached copy of the raw (pre-softmax) attention values;
    # softmax is monotonic, so the ranking within a bag is the same either way
    mb_attention_weights_list.append(attention_weights.detach())
    # Get sequence embedding h() for single bag (shape: (n_sequences_per_bag, d_v))
    emb_seqs = mb_emb_seqs[start_i:start_i+n_seqs]
    # Calculate attention activations (softmax over n_sequences_per_bag) (shape: (n_sequences_per_bag, 1))
    attention_weights = torch.softmax(attention_weights, dim=0)
    # Apply attention weights to sequence features (shape: (n_sequences_per_bag, d_v))
    emb_seqs_after_attention = emb_seqs * attention_weights
    # Compute weighted sum over sequence features after attention (shape: (d_v,))
    mb_emb_seqs_after_attention.append(emb_seqs_after_attention.sum(dim=0))
    start_i += n_seqs
# Stack representations of bags (shape (N, d_v))
emb_seqs_after_attention = torch.stack(mb_emb_seqs_after_attention, dim=0)
# Calculate predictions (shape (N, n_outputs))
predictions = self.output_nn(emb_seqs_after_attention)
# Expose the per-bag attention values so they can be read after the forward pass
self.mb_attention_weights_list = mb_attention_weights_list
return predictions
```
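To illustrate what the loop computes, here is a framework-free sketch of the same per-bag attention pooling in plain Python (no PyTorch). It assumes the per-sequence scores and embeddings are given as flat lists plus the bag sizes; softmax and attention_pool are hypothetical helpers, not DeepRC functions. It also shows that the raw scores (which the loop stores, since the append happens before the softmax) order the sequences of a bag the same way as the softmax weights do.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention_pool(flat_scores, flat_embeddings, n_sequences_per_bag):
    """Split flat per-sequence scores/embeddings into bags, apply a softmax
    within each bag and return (pooled_vectors, per_bag_weights).
    Assumes every bag contains at least one sequence."""
    pooled, weights_per_bag = [], []
    start_i = 0
    for n_seqs in n_sequences_per_bag:
        scores = flat_scores[start_i:start_i + n_seqs]
        embs = flat_embeddings[start_i:start_i + n_seqs]
        # Softmax over the sequences of this bag only
        weights = softmax(scores)
        d_v = len(embs[0])
        # Weighted sum over the sequences of this bag -> one vector of length d_v
        pooled.append([sum(w * e[j] for w, e in zip(weights, embs)) for j in range(d_v)])
        weights_per_bag.append(weights)
        start_i += n_seqs
    return pooled, weights_per_bag


pooled, weights = attention_pool(
    flat_scores=[0.0, 0.0, 1.0, 2.0, 3.0],
    flat_embeddings=[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 2.0], [3.0, 3.0]],
    n_sequences_per_bag=[2, 3],
)
# pooled[0] == [0.5, 0.5]  (equal scores -> equal weights in the first bag)
```

Because softmax is applied separately per bag, weights are only comparable between sequences of the same bag, not across bags.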
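After a forward pass, you can then access the attribute model.mb_attention_weights_list in the evaluation function. Since the attribute is overwritten on every forward pass, read it before the next minibatch is processed.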
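To rank the sequences by importance, one option is to sort each bag's stored attention values in descending order. A minimal sketch in plain Python, assuming each tensor in model.mb_attention_weights_list has first been converted to a flat list (e.g. with t.squeeze(-1).tolist() in PyTorch); rank_sequences and rank_all_bags are hypothetical helpers, not part of DeepRC:

```python
def rank_sequences(bag_scores, top_k=None):
    """Return (sequence_index, score) pairs of one bag, highest score first.
    bag_scores: flat list of per-sequence attention values for one bag."""
    order = sorted(range(len(bag_scores)), key=lambda i: bag_scores[i], reverse=True)
    if top_k is not None:
        order = order[:top_k]
    return [(i, bag_scores[i]) for i in order]


def rank_all_bags(attention_weights_list, top_k=None):
    """Rank the sequences of every bag separately.
    attention_weights_list: one flat list of attention values per bag
    (assumed to be converted from model.mb_attention_weights_list)."""
    return [rank_sequences(scores, top_k) for scores in attention_weights_list]


rankings = rank_all_bags([[0.1, 2.0, -1.0], [0.5, 0.7]], top_k=2)
# rankings == [[(1, 2.0), (0, 0.1)], [(1, 0.7), (0, 0.5)]]
```

Staying in PyTorch, torch.argsort(t.squeeze(-1), descending=True) would give the same order per bag. The returned indices are positions within the bag as passed to the model; mapping them back to the actual sequences depends on how your data loader orders (and possibly subsamples) the sequences of each sample.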
I hope that helps. Let me know if you need further help.
Best wishes, Michael