snimu / torch-nested

Easily manipulate torch.Tensors inside highly nested data-structures.

torch-nested

Python 3.7+ PyTorch

If you need PyTorch's built-in nested tensors, consider using torch.nested. If you are instead working with nested dicts, lists, tuples, etc. that contain torch.Tensors, this is the package for you.

Proper documentation is coming. Until then, see the basic example below, and consult the docstrings or tests of this package for more information.

Basic usage

Given a nested structure that contains torch.Tensors, this package makes it easy to access these Tensors and work with them:

import torch
from torch_nested import NestedTensors


INPUT_DATA = [
    (
        torch.ones(3), 
        torch.zeros(2)
    ),
    torch.ones((2, 2, 2)),
    {
        "foo": torch.ones(2), 
        "bar": [], 
        "har": "rar"
    },
    1
]

tensors = NestedTensors(INPUT_DATA)

# Original data preserved in .data-member
assert tensors.data == INPUT_DATA

# Simple accessing and setting
for i, tensor in enumerate(tensors):
    tensors[i] = tensor + i 

# Has basic dunders
assert len(tensors) == 4
assert torch.all(next(tensors) == torch.ones(3))

Calling print(tensors.shape()) would yield:

torch_nested.Size(
  [
    (
      torch.Size([3]),
      torch.Size([2])
    ),
    torch.Size([2, 2, 2]),
    {
      foo: torch.Size([2]),
      bar: None,
      har: None
    },
    None
  ]
)
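
To make the printout above easier to interpret, here is a minimal sketch in plain PyTorch of how such a shape structure could be derived by recursion. The helper name nested_shapes is hypothetical and is not part of torch_nested; the package's own implementation may work differently. Reporting empty containers as None is an assumption made only to mirror the output shown above.

```python
import torch


def nested_shapes(data):
    """Mirror the nesting of `data`, replacing each tensor with its torch.Size.

    Hypothetical helper for illustration; not the torch_nested implementation.
    """
    if torch.is_tensor(data):
        return data.shape
    if isinstance(data, (dict, list, tuple)) and len(data) == 0:
        # Assumption: empty containers are reported as None, as in the printout.
        return None
    if isinstance(data, dict):
        return {key: nested_shapes(value) for key, value in data.items()}
    if isinstance(data, (list, tuple)):
        # Rebuild the same container type: a list stays a list, a tuple a tuple.
        return type(data)(nested_shapes(item) for item in data)
    # Anything else (str, int, None, ...) carries no tensor shape.
    return None


example = [
    (torch.ones(3), torch.zeros(2)),
    torch.ones((2, 2, 2)),
    {"foo": torch.ones(2), "bar": [], "har": "rar"},
    1,
]
shapes = nested_shapes(example)
print(shapes)
```

The result keeps the layout of the input, so a shape can be looked up with the same keys and indices as the tensor it describes, for example shapes[2]["foo"].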

Supported data-structures

The following data-structures are supported so far:

  • torch.Tensor
  • dict
  • list
  • tuple
  • None
  • Any class with a .tensors-attribute
  • Any class with a .data-attribute, even if it isn't a torch.Tensor

For example:

class ObjWithTensors:
    tensors = [torch.ones(2), torch.zeros(2)]

class ObjWithData:
    data = [torch.ones(2), torch.zeros(2)]

tensors = NestedTensors([ObjWithTensors(), ObjWithData()])

Running print(tensors.size()) would result in the following output:

NestedSize(
  [
    ObjWithTensors(
      tensors: [
        torch.Size([2]),
        torch.Size([2])
      ]
    ),
    ObjWithData(
      data: [
        torch.Size([2]),
        torch.Size([2])
      ]
    )
  ]
)

More data-structures will be supported in the future. Any data that is of an unsupported type will not have its Tensors readable or writable, and NestedShape will show None there.
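
The effect of unsupported types can be illustrated with a small depth-first traversal written in plain PyTorch. This is only a sketch under the assumption that supported containers are visited in order and everything else is skipped; iter_tensors and the Unsupported class are hypothetical names, not part of torch_nested.

```python
import torch


def iter_tensors(data):
    """Yield every tensor reachable through supported containers, depth-first.

    Hypothetical helper for illustration; not the torch_nested implementation.
    """
    if torch.is_tensor(data):
        # Checked first, because a torch.Tensor has a .data attribute itself.
        yield data
    elif isinstance(data, dict):
        for value in data.values():
            yield from iter_tensors(value)
    elif isinstance(data, (list, tuple)):
        for item in data:
            yield from iter_tensors(item)
    elif hasattr(data, "tensors"):
        yield from iter_tensors(data.tensors)
    elif hasattr(data, "data"):
        yield from iter_tensors(data.data)
    # Unsupported types (str, int, None, other objects) yield nothing.


class ObjWithTensors:
    tensors = [torch.ones(2), torch.zeros(2)]


class Unsupported:
    # Neither .tensors nor .data, so this tensor is never reached.
    payload = [torch.ones(5)]


found = list(
    iter_tensors([ObjWithTensors(), Unsupported(), "text", {"a": torch.ones(3)}])
)
print([tuple(t.shape) for t in found])
```

In this sketch the tensor stored in Unsupported.payload is never visited, which corresponds to the behaviour described above: tensors inside unsupported types are neither readable nor writable.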

License: MIT License