TylerYep / torchinfo

View model summaries in PyTorch!

get_total_memory_used fails to handle list of str

minostauros opened this issue

```python
def get_total_memory_used(data: CORRECTED_INPUT_DATA_TYPE) -> int:
```

```python
>>> get_total_memory_used(["abc", "def"])
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python3.9/dist-packages/torchinfo/torchinfo.py", line 503, in get_total_memory_used
    result = traverse_input_data(
  File "/usr/local/lib/python3.9/dist-packages/torchinfo/torchinfo.py", line 447, in traverse_input_data
    result = aggregate(
TypeError: unsupported operand type(s) for +: 'int' and 'str'
```

action_fn is not applied to str leaves, so sys.getsizeof never gets the size of the strings and the raw str values reach the sum aggregator.

This happens when one of a model's inputs is a list of strings, e.g. in language models.
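A minimal sketch of the failure mode, independent of torchinfo (the variable names are illustrative, not torchinfo's): the aggregate receives a mixed list of int sizes and raw strings, and summing it fails.

```python
import sys

# action_fn converts tensors to int byte counts, but str leaves pass
# through untouched, so the aggregate sees a mixed list:
leaves = [64, "abc", "def"]  # one tensor size plus two raw strings

try:
    total = sum(leaves)
except TypeError as err:
    total = None
    print(err)  # unsupported operand type(s) for +: 'int' and 'str'

# Sizing the stray strings first restores a well-typed sum of ints:
fixed = sum(x if isinstance(x, int) else sys.getsizeof(x) for x in leaves)
```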

Possible dirty workaround:

```python
def get_total_memory_used(data: CORRECTED_INPUT_DATA_TYPE) -> int:
    """Calculates the total memory of all tensors stored in data."""
    result = traverse_input_data(
        data,
        action_fn=lambda data: sys.getsizeof(
            data.untyped_storage()
            if hasattr(data, "untyped_storage")
            else data.storage()
        ),
        aggregate_fn=(
            # We don't need the dictionary keys in this case.
            # If the data is not an int, assume action_fn was not applied
            # for some reason and fall back to sys.getsizeof.
            (
                lambda data: lambda d: (
                    sum(d.values()) if isinstance(d, Mapping) else sys.getsizeof(d)
                )
            )
            if isinstance(data, Mapping) or not isinstance(data, int)
            else sum
        ),
    )
    return cast(int, result)
```
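A less hacky direction might be to treat str/bytes as leaves during traversal and size them there, so the aggregator only ever sees ints. A standalone sketch (total_memory_used below is a simplified stand-in, not torchinfo's traverse_input_data):

```python
import sys
from collections.abc import Mapping, Sequence

def total_memory_used(data) -> int:
    """Recursively sum byte sizes, treating str/bytes as leaves."""
    if isinstance(data, Mapping):
        # Dictionary keys are ignored, matching the original intent.
        return sum(total_memory_used(v) for v in data.values())
    # str is itself a Sequence, so it must be excluded here
    # to fall through to the leaf branch below.
    if isinstance(data, Sequence) and not isinstance(data, (str, bytes)):
        return sum(total_memory_used(v) for v in data)
    # Leaf: size a tensor's storage if available, otherwise the object itself.
    if hasattr(data, "untyped_storage"):
        return sys.getsizeof(data.untyped_storage())
    if hasattr(data, "storage"):
        return sys.getsizeof(data.storage())
    return sys.getsizeof(data)

total_memory_used(["abc", "def"])  # no TypeError; returns an int
```

With this shape, `["abc", "def"]` is summed element by element via sys.getsizeof rather than crashing in the aggregate.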