TylerYep / torchinfo

View model summaries in PyTorch!

summary not working with a custom pre-processing layer

jaroslawjanas opened this issue

Describe the bug
I have a custom TextVectorization layer. It doesn't use any nn layers; it is just a dictionary of words that is used to fill in a torch.zeros vector. I want it baked into the model, so I put it as the first layer.

Unfortunately, it doesn't work with torchinfo.summary(model, input_data=["test"] * batch_size), which is bothersome.

Calling model.forward(["this is a test"]) works just fine, so I am somewhat confident that the issue is torchinfo not being able to handle my custom layer. The summary worked fine without that layer (with random int tokens as input data).

Code and Setup
TextVectorization

from collections import defaultdict

import torch
import torch.nn as nn


class TextVectorization(nn.Module):
    def __init__(self, max_vocabulary, max_tokens):
        super(TextVectorization, self).__init__()
        self.max_tokens = max_tokens
        self.max_vocabulary = max_vocabulary
        self.word_dictionary = dict()
        self.dictionary_size = 0

    def adapt(self, dataset):
        word_frequencies = defaultdict(int)

        for text in dataset:
            for word in text[0].split():
                word_frequencies[word] += 1

        # Sort the dictionary by word frequencies in descending order
        sorted_word_frequencies = dict(sorted(word_frequencies.items(),
                                              key=lambda item: item[1],
                                              reverse=True)
        )

        # Take the top N most frequent words
        most_frequent = list(sorted_word_frequencies.items())[:self.max_vocabulary]
        self.dictionary_size = len(most_frequent)

        # Note starting at 2 since 0 (padding) and 1 (missing) are reserved
        for word_value, (word, frequency) in enumerate(most_frequent, 2):
            self.word_dictionary[word] = word_value

    def vocabulary_size(self):
        return self.dictionary_size

    def dictionary(self):
        return self.word_dictionary

    def forward(self, batch_x):
        batch_text_vectors = torch.zeros((len(batch_x), self.max_tokens), dtype=torch.int32)

        for i, text in enumerate(batch_x):

            # Split the text and tokenize it
            words = text.split()[:self.max_tokens]

            for pos, word in enumerate(words):
                batch_text_vectors[i, pos] = self.word_dictionary.get(word, 1)

        return batch_text_vectors

Model

class TransformerModel(nn.Module):
    def __init__(self, max_tokens, vocab_size, embed_dim, num_heads, ff_dim, vectorize_layer):
        super(TransformerModel, self).__init__()
        self.vectorize_layer = vectorize_layer
        self.embedding_layer = TokenAndPositionEmbedding(
            max_tokens,
            vocab_size,
            embed_dim
        )
        self.transformer_block = TransformerBlock(embed_dim, num_heads, ff_dim)
        self.global_avg_pooling = nn.AdaptiveAvgPool1d(1)
        self.dropout = nn.Dropout(0.1)
        self.fc1 = nn.Linear(embed_dim, 20)
        self.fc2 = nn.Linear(20, 3)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = self.vectorize_layer(x)
        x = self.embedding_layer(x)
        x = self.transformer_block(x)
        x = self.global_avg_pooling(x.permute(0, 2, 1)).squeeze(2)
        x = self.dropout(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.softmax(x)
        return x

Summary

summary_samples = ["This is a test"] * batch_size
summary(model, input_data=summary_samples)

Runtime Error

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/usr/local/lib/python3.10/dist-packages/torchinfo/torchinfo.py in forward_pass(model, x, batch_dim, cache_forward_pass, device, mode, **kwargs)
    294             if isinstance(x, (list, tuple)):
--> 295                 _ = model(*x, **kwargs)
    296             elif isinstance(x, dict):

4 frames
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1517         else:
-> 1518             return self._call_impl(*args, **kwargs)
   1519 

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1567 
-> 1568             result = forward_call(*args, **kwargs)
   1569             if _global_forward_hooks or self._forward_hooks:

TypeError: TransformerModel.forward() takes 2 positional arguments but 257 were given

The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
<ipython-input-107-41c92f8997e3> in <cell line: 5>()
      3 summary_samples = ["This is a test"] * batch_size
      4 # print(np.shape(summary_samples))
----> 5 summary(model, input_data=summary_samples)
      6 
      7 

/usr/local/lib/python3.10/dist-packages/torchinfo/torchinfo.py in summary(model, input_size, input_data, batch_dim, cache_forward_pass, col_names, col_width, depth, device, dtypes, mode, row_settings, verbose, **kwargs)
    221         input_data, input_size, batch_dim, device, dtypes
    222     )
--> 223     summary_list = forward_pass(
    224         model, x, batch_dim, cache_forward_pass, device, model_mode, **kwargs
    225     )

/usr/local/lib/python3.10/dist-packages/torchinfo/torchinfo.py in forward_pass(model, x, batch_dim, cache_forward_pass, device, mode, **kwargs)
    302     except Exception as e:
    303         executed_layers = [layer for layer in summary_list if layer.executed]
--> 304         raise RuntimeError(
    305             "Failed to run torchinfo. See above stack traces for more details. "
    306             f"Executed layers up to: {executed_layers}"

RuntimeError: Failed to run torchinfo. See above stack traces for more details. Executed layers up to: []

Desktop (please complete the following information):

  • Colab
  • PyTorch = 2.1.0+cu118
  • torchinfo = 1.8.0

This is related to #254 and probably also #280. The code expects "tensor-like" input, not strings. Even if this isn't fixed, the error should definitely be caught earlier and stated more clearly. As it stands, process_input doesn't know what to do with this kind of input, and there are related issues coming from traverse_input_data.
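To illustrate where the "257 were given" figure comes from, here is a minimal sketch that mimics the list-unpacking branch visible in the traceback (`if isinstance(x, (list, tuple)): model(*x, **kwargs)`). `ToyModel` and `run_forward` are hypothetical stand-ins written for this example, not torchinfo or PyTorch code, and `batch_size = 256` is inferred from the error message (256 strings plus `self`).

```python
class ToyModel:
    """Stand-in for an nn.Module whose forward takes a single batch argument."""

    def forward(self, x):
        return len(x)

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)


def run_forward(model, x):
    # Same shape as the branch in the traceback: a list or tuple is
    # unpacked into separate positional arguments.
    if isinstance(x, (list, tuple)):
        return model(*x)
    return model(x)


batch_size = 256  # assumption: inferred from "257 were given"
samples = ["This is a test"] * batch_size

error_message = None
try:
    run_forward(ToyModel(), samples)
except TypeError as err:
    # Each string became its own argument: 256 strings + self = 257.
    error_message = str(err)
    print(error_message)

# Wrapping the batch in an outer list makes the unpacking yield one argument.
wrapped_result = run_forward(ToyModel(), [samples])
```

In this toy, wrapping the batch in an outer list gets past the unpacking step. Whether the same wrapping is enough for the real `summary` call is not verified here, since the input-processing code mentioned above may still reject strings.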

I would love to work on this. Does anyone have opinions on what should be done: new functionality to handle text input, or a better error message?

Either solution sounds good to me. The better error message sounds like a good place to start, and then handling text input would be a good followup. PRs are definitely welcome!
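As a starting point for the "better error message" option, something along these lines could work. This is only a rough sketch under stated assumptions: `find_string_leaf` and `check_input_data` are hypothetical helper names invented for this example, they are not part of torchinfo, and the wording of the message is just a suggestion.

```python
from collections.abc import Mapping


def find_string_leaf(data, path="input_data"):
    """Return the path of the first str found in a nested input, or None."""
    if isinstance(data, str):
        return path
    if isinstance(data, Mapping):
        for key, value in data.items():
            found = find_string_leaf(value, f"{path}[{key!r}]")
            if found is not None:
                return found
    elif isinstance(data, (list, tuple)):
        for index, value in enumerate(data):
            found = find_string_leaf(value, f"{path}[{index}]")
            if found is not None:
                return found
    return None


def check_input_data(data):
    """Hypothetical pre-check: fail early with a readable message."""
    location = find_string_leaf(data)
    if location is not None:
        raise TypeError(
            f"Unsupported text input at {location}: "
            "expected tensor-like data, got str."
        )


message = None
try:
    check_input_data(["This is a test"] * 4)
except TypeError as err:
    message = str(err)
    print(message)
```

Running a check like this before the forward pass would surface the problem at the input itself instead of as an argument-count error deep inside the model call.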