ml-jku / DeepRC

DeepRC: Immune repertoire classification with attention-based deep massive multiple instance learning

How to extract attention weights for each sequence

Albert-Shuai opened this issue

Hi,

I am wondering how to extract the attention weights for each sequence in a sample, so that the sequences can be ranked by their importance. Suppose I have trained a model named "model" based on the code in example_single_task_cnn.py. Thanks!

Hi! Sorry for the late response. You can find the attention weights as the tensor "attention_weights" here:

attention_weights = mb_attention_weights[start_i:start_i+n_seqs]

You would have to pass those tensors into the evaluation function somehow.

A quick-and-dirty hack would be to just set the attention weights as an attribute, e.g. by modifying the part between lines 412-434 as follows:

        # Compute representation per bag (N times shape (d_v,))
        mb_emb_seqs_after_attention = []
        mb_attention_weights_list = []
        start_i = 0
        for n_seqs in n_sequences_per_bag:
            # Get attention weights for single bag (shape: (n_sequences_per_bag, 1))
            attention_weights = mb_attention_weights[start_i:start_i+n_seqs]
            # Store the raw (pre-softmax) attention weights; .detach() keeps the
            # stored tensors out of the autograd graph
            mb_attention_weights_list.append(attention_weights.detach())
            # Get sequence embedding h() for single bag (shape: (n_sequences_per_bag, d_v))
            emb_seqs = mb_emb_seqs[start_i:start_i+n_seqs]
            # Calculate attention activations (softmax over n_sequences_per_bag) (shape: (n_sequences_per_bag, 1))
            attention_weights = torch.softmax(attention_weights, dim=0)
            # Apply attention weights to sequence features (shape: (n_sequences_per_bag, d_v))
            emb_seqs_after_attention = emb_seqs * attention_weights
            # Compute weighted sum over sequence features after attention (shape: (d_v,))
            mb_emb_seqs_after_attention.append(emb_seqs_after_attention.sum(dim=0))
            start_i += n_seqs
        
        # Stack representations of bags (shape (N, d_v))
        emb_seqs_after_attention = torch.stack(mb_emb_seqs_after_attention, dim=0)
        
        # Calculate predictions (shape (N, n_outputs))
        predictions = self.output_nn(emb_seqs_after_attention)
        
        # Quick-and-dirty: expose the per-bag attention weights as a model attribute
        self.mb_attention_weights_list = mb_attention_weights_list
        return predictions

Later you could access the attribute model.mb_attention_weights_list in the evaluation function.
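For example, here is a minimal sketch of how the stored weights could be used to rank the sequences of each bag, assuming the modification above has been applied and the model's forward pass has just been run on a minibatch (the helper name rank_sequences_by_attention is hypothetical, not part of the repository):

    import torch

    # Hypothetical helper: rank the sequences within each bag by attention weight.
    # Assumes model.mb_attention_weights_list was populated by a forward pass with
    # the modification above, holding one (n_sequences_per_bag, 1) tensor per bag.
    def rank_sequences_by_attention(model):
        rankings = []
        for raw_weights in model.mb_attention_weights_list:
            # Stored weights are pre-softmax, so normalize them per bag first
            weights = torch.softmax(raw_weights, dim=0).squeeze(-1)
            # Sequence indices, sorted from most to least important
            rankings.append(torch.argsort(weights, descending=True))
        return rankings

Since softmax is strictly monotonic, ranking the raw weights directly would give the same order within a bag; the softmax only matters if you want to interpret the magnitudes as normalized importances.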

I hope that helps. Let me know if you need further help.
Best wishes, Michael