facebookresearch / torchdim

Named tensors with first-class dimensions for PyTorch

Partitioning groups of tensors based on dimensions

ThomasHeap opened this issue · comments

First off thanks for working on this, it's really neat.

Previously I was using named tensors to keep track of various 'types' of dimension. For instance, I'd have dimensions whose names all started with 'K_', which allowed for checks like:

has_K = [n for n in tensor.names if 'K_' in n]
no_K = [n for n in tensor.names if 'K_' not in n]

I have two problems/questions:

Vmap doesn't allow the use of tensors in data-dependent control flow

Even if I keep track of all the 'K' type dims in a list I run into some issues:

I have some code like:

[n for n in tensor.dims if n in list_of_Kdims]

From which I get an error like:

RuntimeError: vmap: It looks like you're attempting to use a Tensor in some data-dependent control flow. We don't support that yet, please shout over at https://github.com/pytorch/functorch/issues/257 .

Is there a way of doing this sort of thing?

More generally: Can dims be identified or created programmatically?

I'd like to be able to do something like:

K = dim(1)
dim_tensor = tensor[K,:]
tensor_dict = {'key':dim_tensor}

# then some function that does for each key:
K_ + 'key' = dim(1)
tensor_dict['key'].index(K,K_key)

# Or some function that does for each key:
for dim in tensor_dict['key'].dims:
    if dim has name like K_'key':
        return True

Now obviously this will not work, but I'd like to be able to create dims that can later be identified as being of type 'K' without having to keep a dictionary or list of all such dims.

Is something like this possible?

The first issue happens because the in operator uses equality to test whether one dimension is another. Since dimension equality treats the dimensions as tensors, i == j produces a new tensor with dimensions i and j and True along its diagonal. This then confuses the in operator. A workaround is to test for identity instead:

def hasdim(x, lst):
    # Compare by identity, because == on dims builds a tensor rather than a bool.
    for element in lst:
        if x is element:
            return True
    return False

The error message could be improved in this case. It would be nice for the standard in checks to work, but it is hard to have those checks behave as expected while equality of dimensions still behaves like the other comparison operators.
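To make the failure mode concrete, here is a minimal sketch in plain Python. It does not use torchdim. FakeDim and FakeComparison are hypothetical stand-ins that only mimic the one relevant behaviour, namely that == returns an object which cannot be converted to a bool. The error text is also made up for illustration.

```python
class FakeComparison:
    """Stand-in for the tensor that comparing two dims produces."""

    def __bool__(self):
        # Mimics a tensor-like result refusing to act as a plain bool.
        raise RuntimeError("comparison result cannot be used as a bool")


class FakeDim:
    """Hypothetical stand-in for a first-class dim (not the torchdim class)."""

    def __init__(self, name):
        self.name = name

    def __eq__(self, other):
        # Equality yields a tensor-like object, not True/False.
        return FakeComparison()

    # Defining __eq__ disables hashing by default, so restore identity hashing.
    __hash__ = object.__hash__


def hasdim(x, lst):
    # Identity test: never calls __eq__, so it always yields a plain bool.
    return any(x is element for element in lst)


i, j = FakeDim("i"), FakeDim("j")
k_dims = [i]

try:
    j in k_dims  # falls back to j == i, whose result cannot be made a bool
    in_operator_failed = False
except RuntimeError:
    in_operator_failed = True

identity_results = (hasdim(i, k_dims), hasdim(j, k_dims))
```

Python's in first checks identity and only then falls back to ==, so the failure appears only when the dim is compared against a different dim in the list.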

For the second issue: dim objects have a name field for debugging. Would letting that name field be set programmatically help with the desired behavior?

Ah ok, that's useful to know, thank you.

I've managed to mostly overcome the second issue by using setattr to give the dims boolean attributes that distinguish the different "types" of dimensions I want. If the name field were used, would this be reflected when the tensor is printed?

If I had a list of strings and wanted to create dims with these strings as names, would this be possible?

We can add a way to set the name field of a dimension. It would appear as the name when the dimension is printed. I think adding attributes to the Dim as an additional way to store information is fine too.
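The attribute-tagging approach discussed above could look roughly like the following sketch. It uses a hypothetical StandInDim class rather than a real torchdim Dim, since the thread does not settle how a Dim would be created from a given name. The attribute name kind and the helpers tag and dims_of_kind are likewise assumptions made for illustration.

```python
class StandInDim:
    """Hypothetical placeholder for a dim object that carries a debug name."""

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return self.name


def tag(dim, kind):
    # Attach a "type" marker to the dim; "kind" is an assumed attribute name.
    setattr(dim, "kind", kind)
    return dim


def dims_of_kind(dims, kind):
    # Untagged dims have no "kind" attribute, so default to None for them.
    return [d for d in dims if getattr(d, "kind", None) == kind]


# Create one tagged dim per name in a list of strings.
names = ["K_a", "K_b"]
k_dims = [tag(StandInDim(n), "K") for n in names]
batch = StandInDim("batch")  # left untagged

selected = dims_of_kind(k_dims + [batch], "K")
```

Because the filter reads an attribute instead of comparing dims with == or in, it avoids the equality issue described earlier in the thread.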