TylerYep / torchinfo

View model summaries in PyTorch!

TypeError: unsupported operand type(s) for +: 'int' and 'NoneType'

YasmineXXX opened this issue

My model's forward method takes the following inputs:

def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        video_query_tokens=None,
        frame_hidden_state=None,
        frame_atts=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        is_decoder=False,
    ):

However, I only want to test the case where video_query_tokens, frame_hidden_state, frame_atts, and return_dict are set to specific values:

video_query_tokens = torch.zeros(batch_size, 32, 768).to(device)
frame_hidden_state = torch.randn(batch_size, 256, 768, dtype=torch.float32).to(device)
frame_atts = torch.ones(frame_hidden_state.size()[:-1], dtype=torch.long).to(device)
return_dict = True

To do this, I used the following statement:

summary(model.video_Qformer.bert, input_data=[None, None, None, None, video_query_tokens, frame_hidden_state, frame_atts, None, None, None, None, True, False])

However, I received an error message:

╭───────────────────── Traceback (most recent call last) ──────────────────────╮
│ /mnt/wfs/mmchongqingwfssz/user_autumnywang/baseline/Video-LLaMA/qformer_mode │
│ l_print.py:81 in <module>                                                    │
│                                                                              │
│   78 video_query_tokens = torch.zeros(batch_size, 32, 768).to(device)        │
│   79 frame_hidden_state = torch.randn(batch_size, 256, 768, dtype=torch.floa │
│   80 frame_atts = torch.ones(frame_hidden_state.size()[:-1], dtype=torch.lon │
│ ❱ 81 summary(model.video_Qformer.bert, input_data=[None, None, None, None, v │
│   82 # summary(model.video_Qformer.bert, input_size=[None, None, None, None, │
│   83 # summary(model, input_size=(16, 3, 8, 224, 224))                       │
│   84                                                                         │
│                                                                              │
│ /opt/conda/lib/python3.9/site-packages/torchinfo/torchinfo.py:228 in summary │
│                                                                              │
│   225 │   )                                                                  │
│   226 │   formatting = FormattingOptions(depth, verbose, columns, col_width, │
│   227 │   results = ModelStatistics(                                         │
│ ❱ 228 │   │   summary_list, correct_input_size, get_total_memory_used(x), fo │
│   229 │   )                                                                  │
│   230 │   if verbose > Verbosity.QUIET:                                      │
│   231 │   │   print(results)                                                 │
│                                                                              │
│ /opt/conda/lib/python3.9/site-packages/torchinfo/torchinfo.py:503 in         │
│ get_total_memory_used                                                        │
│                                                                              │
│   500                                                                        │
│   501 def get_total_memory_used(data: CORRECTED_INPUT_DATA_TYPE) -> int:     │
│   502"""Calculates the total memory of all tensors stored in data."""   │
│ ❱ 503result = traverse_input_data(                                      │
│   504 │   │   data,                                                          │
│   505 │   │   action_fn=lambda data: sys.getsizeof(                          │
│   506 │   │   │   data.untyped_storage()                                     │
│                                                                              │
│ /opt/conda/lib/python3.9/site-packages/torchinfo/torchinfo.py:447 in         │
│ traverse_input_data                                                          │
│                                                                              │
│   444 │   │   )                                                              │
│   445 │   elif isinstance(data, Iterable) and not isinstance(data, str):     │
│   446 │   │   aggregate = aggregate_fn(data)                                 │
│ ❱ 447 │   │   result = aggregate(                                            │
│   448 │   │   │   [traverse_input_data(d, action_fn, aggregate_fn) for d in  │
│   449 │   │   )                                                              │
│   450 │   else:                                                              │
╰──────────────────────────────────────────────────────────────────────────────╯
TypeError: unsupported operand type(s) for +: 'int' and 'NoneType'
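From the traceback, it looks like the None placeholders in my input_data list fall through to the final else branch of traverse_input_data and are returned unchanged, so the sum() inside get_total_memory_used ends up adding an int and None. If it helps, the same error seems to reproduce in isolation (a minimal sketch, assuming get_total_memory_used can be imported from torchinfo.torchinfo as shown in the traceback):

import torch
from torchinfo.torchinfo import get_total_memory_used

# The tensor entry becomes an int (its storage size), while the None entry
# appears to be returned as-is, so the aggregating sum() mixes int and None.
get_total_memory_used([torch.zeros(2, 3), None])
# TypeError: unsupported operand type(s) for +: 'int' and 'NoneType'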

How can I solve this case?
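For now, would it be reasonable to drop the None placeholders entirely and pass only the real tensors, for example as a dict in input_data, with return_dict forwarded through **kwargs? A rough sketch of what I mean (assuming summary accepts a mapping for input_data and passes extra keyword arguments on to forward; the keys are the parameter names from my forward signature above):

summary(
    model.video_Qformer.bert,
    input_data={
        "video_query_tokens": video_query_tokens,
        "frame_hidden_state": frame_hidden_state,
        "frame_atts": frame_atts,
    },
    return_dict=True,  # extra keyword argument, passed through to forward()
)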