`get_total_memory_used` fails to handle a list of `str`
minostauros opened this issue · comments
Minho Shim commented
torchinfo/torchinfo/torchinfo.py
Line 501 in 73ed568
>>> get_total_memory_used(["abc", "def"])
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/usr/local/lib/python3.9/dist-packages/torchinfo/torchinfo.py", line 503, in get_total_memory_used
result = traverse_input_data(
File "/usr/local/lib/python3.9/dist-packages/torchinfo/torchinfo.py", line 447, in traverse_input_data
result = aggregate(
TypeError: unsupported operand type(s) for +: 'int' and 'str'
>>>
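`action_fn` is not applied to `str` values, so `sys.getsizeof` is never used to measure strings. They reach the aggregator unchanged, and adding them to the integer tensor sizes raises the `TypeError` above.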
Minho Shim commented
This happens when a model's input is a list of strings, e.g., for language models.
Possible (dirty) workaround:
def get_total_memory_used(data: CORRECTED_INPUT_DATA_TYPE) -> int:
    """Calculates the total memory of all tensors stored in data.

    Tensors are measured by ``sys.getsizeof`` of their storage object. Any
    other leaf value that ``traverse_input_data`` hands back unchanged (e.g.
    the strings fed to a language model) is measured with ``sys.getsizeof``
    as well, instead of being added to the running total as-is. Adding it
    as-is raised ``TypeError: unsupported operand type(s) for +: 'int' and
    'str'``.

    Args:
        data: Arbitrarily nested model input (tensors, sequences, mappings
            and plain Python values).

    Returns:
        The summed size in bytes, as reported by ``sys.getsizeof``.
    """

    def leaf_size(item: object) -> int:
        # action_fn has already turned every tensor into a byte count (an
        # int), and each nested reducer below also yields an int, so ints pass
        # through. Anything else is a raw leaf (e.g. a str) to be measured.
        # NOTE(review): a plain int input that is not a tensor cannot be told
        # apart from a byte count here and is summed as its value.
        if isinstance(item, int):
            return item
        return sys.getsizeof(item)

    def reducer_for(collection: object):
        # Chosen per collection (not once for the top-level input), so nested
        # mappings and sequences each get the right reducer.
        if isinstance(collection, Mapping):
            # We don't need the dictionary keys in this case.
            return lambda sizes: sum(leaf_size(v) for v in sizes.values())
        return lambda sizes: sum(leaf_size(s) for s in sizes)

    result = traverse_input_data(
        data,
        action_fn=lambda tensor: sys.getsizeof(
            tensor.untyped_storage()
            if hasattr(tensor, "untyped_storage")
            else tensor.storage()
        ),
        aggregate_fn=reducer_for,
    )
    # A bare non-tensor, non-collection input (e.g. a single str) also comes
    # back unchanged; normalise it so the function always returns an int.
    return leaf_size(result)