QUVA-Lab / e2cnn

E(2)-Equivariant CNNs Library for Pytorch

Home Page: https://quva-lab.github.io/e2cnn/


Feature Request: Slicing of GeometricTensors

drewm1980 opened this issue

I'm porting over this function to e2cnn:

    def center_crop(self, layer: torch.Tensor, target_size: List[int]):
        _, _, layer_height, layer_width = layer.size()
        diff_y = (layer_height - target_size[0]) // 2
        diff_x = (layer_width - target_size[1]) // 2
        return layer[
            :, :, diff_y : (diff_y + target_size[0]), diff_x : (diff_x + target_size[1])
        ]

When I pass a GeometricTensor into it, I hit "TypeError: 'GeometricTensor' object is not subscriptable" where the tensor is sliced.
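A minimal reproduction of what I'm seeing (the setup is just an example; the last line is what fails):

    import torch
    from e2cnn import gspaces, nn

    space = gspaces.Rot2dOnR2(4)
    field_type = nn.FieldType(space, [space.regular_repr])
    x = nn.GeometricTensor(torch.randn(10, field_type.size, 7, 7), field_type)

    # slicing the wrapper directly fails, since it does not define __getitem__
    x[:, :, 1:5, 1:5]  # TypeError: 'GeometricTensor' object is not subscriptable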

That's probably too much info. I only need to slice along the spatial dimensions... I will experiment with casting back to torch.Tensor, but please let me know if you have a better workaround or an update up your sleeve. Thanks!

This gets me past the error; will know later if it actually works.

    def center_crop(self, geometric_tensor: nn.GeometricTensor, target_size: List[int]) -> nn.GeometricTensor:

        # Unpack to a tensor so we can slice
        tensor: torch.Tensor = geometric_tensor.tensor
        field_type: nn.FieldType = geometric_tensor.type

        _, _, tensor_height, tensor_width = tensor.size()
        diff_y = (tensor_height - target_size[0]) // 2
        diff_x = (tensor_width - target_size[1]) // 2
        tensor_sliced = tensor[
            :, :, diff_y : (diff_y + target_size[0]), diff_x : (diff_x + target_size[1])
        ]
        
        # Repack into a GeometricTensor
        geometric_tensor_sliced = nn.GeometricTensor(tensor_sliced, field_type)

        return geometric_tensor_sliced

I'm still new to torch's memory model regarding slicing. Will passing a sliced tensor back into GeometricTensor's constructor cause problems?
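For context, this is the slicing behaviour I'm asking about: basic slicing in PyTorch returns a view sharing storage with the original tensor, not a copy (a quick plain-PyTorch check, nothing e2cnn-specific):

    import torch

    t = torch.randn(10, 4, 7, 7)
    crop = t[:, :, 2:5, 2:5]  # basic slicing -> a view, not a copy

    # the view shares storage with the original tensor ...
    print(crop.storage().data_ptr() == t.storage().data_ptr())  # True

    # ... so writes through the view are visible in the original
    crop[0, 0, 0, 0] = 42.0
    print(t[0, 0, 2, 2].item())  # 42.0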

Update: Casting back to a torch.Tensor as above is inelegant, but it seems to work. It's no longer urgent for me, but I'll leave the ticket open in case you have a more elegant solution for indexing. Was it a deliberate decision to disable indexing so that users don't mess up the channels dimension in a way that would break equivariance? If so, maybe there's a way to allow indexing as long as the caller only slices on the spatial dimensions...

Hi @drewm1980,

sorry for my late reply.

Yeah, GeometricTensor does not support indexing. A simple solution is to unwrap the underlying torch.Tensor, crop it, and wrap it again in a GeometricTensor. I usually do the same.
When instantiated, a GeometricTensor doesn't do much more than store a reference to the torch.Tensor and a reference to the FieldType, so I don't think you need to worry about memory management.
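For instance, something along these lines should hold (a small sanity check, assuming the constructor behaves as I just described and only stores references):

import torch
from e2cnn import gspaces, nn

space = gspaces.Rot2dOnR2(4)
field_type = nn.FieldType(space, [space.regular_repr])

raw = torch.randn(10, field_type.size, 7, 7)
geom = nn.GeometricTensor(raw, field_type)

# the GeometricTensor wraps the very same torch.Tensor, no copy is made
print(geom.tensor is raw)  # True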

Was it a deliberate decision to disable indexing so that users don't mess up the channels dimension in a way that would break equivariance? If so, maybe there's a way to allow indexing as long as the caller only slices on the spatial dimensions...

Yes, channels cannot be split freely, as doing so could break equivariance. If you only need to slice in the channel dimension, you can use GeometricTensor.split().
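For example, with 3 regular fields (4 channels each), splitting at field index 1 gives one tensor with the first field and one with the remaining two (sketched from memory, so double-check the exact semantics of the breaks argument in the documentation):

import torch
from e2cnn import gspaces, nn

space = gspaces.Rot2dOnR2(4)
field_type = nn.FieldType(space, [space.regular_repr] * 3)  # 3 * 4 = 12 channels
geom = nn.GeometricTensor(torch.randn(10, field_type.size, 7, 7), field_type)

# split at field index 1: fields [0] and fields [1, 2]
first, rest = geom.split([1])
print(first.shape)  # torch.Size([10, 4, 7, 7])
print(rest.shape)   # torch.Size([10, 8, 7, 7])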

For the spatial dimensions, it was not a common use case for me, so I did not implement any additional interface. I agree it deserves a better solution, though.
I think I can override the brackets operator (the one usually used for indexing) for GeometricTensor. I will try to implement it this way and let you know if it works.

Best,
Gabriele

Supporting full slicing of the underlying tensor is dangerous, as it can split channels which belong to the same field. The method GeometricTensor.split() can be used to split safely along the channel dimension.

I can implement slicing along the batch and the spatial dimensions by adding the following method to the GeometricTensor class:

def __getitem__(self, slices):
    # Slicing is not supported on the channel dimension.
    if isinstance(slices, tuple):
        if len(slices) > len(self.tensor.shape) - 1:
            raise TypeError("too many indices: the channel dimension can not be sliced")
    else:
        slices = (slices,)

    # This is equivalent to using [:] on the channel dimension
    idxs = (slices[0], slice(None, None, None), *slices[1:])
    sliced_tensor = self.tensor[idxs]
    return GeometricTensor(sliced_tensor, self.type)

This would allow slicing as you would usually do it in PyTorch or NumPy, but it would skip the channel dimension when multiple indices are passed.
This is an example:

space = Rot2dOnR2(4)
type = FieldType(space, [space.regular_repr])
geom_tensor = GeometricTensor(torch.randn(10, type.size, 7, 7), type)

geom_tensor.shape
>> torch.Size([10, 4, 7, 7])

geom_tensor[1:3, 2:5, 2:5].shape
>> torch.Size([2, 4, 3, 3])

Here, I've passed 3 indices which are then used for the first (batch), third and fourth (spatial) dimensions, skipping the second (channels) one.
Do you think this is a valid solution or would this behaviour be confusing?

I appreciate any feedback and if anyone has a better suggestion, feel free to write it here!

Thanks!
Gabriele

Thanks for the feedback @drewm1980, slicing is indeed something we should implement but have not done yet due to time constraints.

I would not go for the solution of skipping the channel axis since the result might be unexpected for inexperienced users. It seems better to make the channel dimension explicit while guaranteeing that the user can't break the equivariance (i.e. split within fields).
I see three options:

  1. We enforce that no splits are allowed in the channel dimension, i.e. that the slices are necessarily of the form [N, :, X, Y].
  2. Slices are of the form [N, C, X, Y] where C counts channels. The method throws an exception if C splits within a field.
  3. Slices are of the form [N, F, X, Y] where F counts fields. This would be similar to the behavior of GeometricTensor.split(). This seems most logical from the viewpoint of steerable CNNs. The downside is that users might intuitively expect the behavior of option 2 and get confused.

I personally really like option 3 as it seems a very clean solution and could also replace GeometricTensor.split().
To be more precise, it would produce this behaviour:

space = Rot2dOnR2(4)
type = FieldType(space, [space.regular_repr]*2 + [space.irrep(1)]*3)
geom_tensor = GeometricTensor(torch.randn(10, type.size, 7, 7), type)

geom_tensor.shape
>> torch.Size([10, 14, 7, 7])

geom_tensor[1:3, :, 2:5, 2:5].shape
>> torch.Size([2, 14, 3, 3])

# the first 2 fields are regular fields of size 4. In total, they contain 2*4 = 8 channels
geom_tensor[:, :2, :, :].shape
>> torch.Size([10, 8, 7, 7])

# the last 2 fields are vector fields of size 2. In total, they contain 2*2 = 4 channels
geom_tensor[:, -2:, :, :].shape
>> torch.Size([10, 4, 7, 7])

# the first 3 fields are 2 regular and 1 vector field. In total, they contain 2*4 + 2 = 10 channels
geom_tensor[:, :3, :, :].shape
>> torch.Size([10, 10, 7, 7])

This could also simplify the use of MultipleModule.
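To make the mechanics concrete, the field-wise slice on the second dimension would internally be translated into a channel slice, roughly like this (just a sketch of the idea, not the actual implementation; fields_to_channels is a made-up name):

def fields_to_channels(field_type, field_slice):
    # cumulative channel offsets at the boundaries between fields
    offsets = [0]
    for rep in field_type.representations:
        offsets.append(offsets[-1] + rep.size)

    # resolve the field-level slice into explicit field indices,
    # then map them to the corresponding channel range
    start, stop, step = field_slice.indices(len(field_type.representations))
    assert step == 1, "only contiguous field slices in this sketch"
    return slice(offsets[start], offsets[stop])

# with the FieldType above ([regular_repr]*2 + [irrep(1)]*3) the offsets are
# [0, 4, 8, 10, 12, 14], so for example:
#   fields_to_channels(type, slice(None, 3))  -> slice(0, 10)   # first 3 fields, 10 channels
#   fields_to_channels(type, slice(-2, None)) -> slice(10, 14)  # last 2 fields, 4 channels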

Option 1 would probably be the most user-friendly for new users.

Alright, let's go for option 3 then. It includes option 1, and if users really want to slice within fields, they can do so via the workaround proposed by @drewm1980.

We now support slicing on all axes and simple indexing (i.e. with a single index per dimension).
Slicing in the second dimension is done over fields instead of channels.
We do not support advanced indexing, though.

You can find some examples and additional details here.

I hope this can be helpful!

Best,
Gabriele

Thanks for the improvement! I'm trying it out. I needed this again today, and it again involved cropping in the spatial dimensions to make tensors compatible for concatenation (for some skip connections in my network). Also, I agree with the decision to keep indexing NumPy/PyTorch compatible as much as possible.
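Roughly, the pattern looks like this (just a sketch of what I mean; I'm assuming e2cnn.nn.tensor_directsum for the field-wise concatenation and that both tensors live on the same gspace):

    import torch
    from e2cnn import gspaces, nn

    space = gspaces.Rot2dOnR2(4)
    enc_type = nn.FieldType(space, [space.regular_repr] * 2)
    dec_type = nn.FieldType(space, [space.regular_repr] * 2)

    encoder_feat = nn.GeometricTensor(torch.randn(1, enc_type.size, 16, 16), enc_type)
    decoder_feat = nn.GeometricTensor(torch.randn(1, dec_type.size, 12, 12), dec_type)

    # center-crop the encoder features spatially with the new slicing support
    _, _, h, w = encoder_feat.shape
    _, _, th, tw = decoder_feat.shape
    dy, dx = (h - th) // 2, (w - tw) // 2
    cropped = encoder_feat[:, :, dy:dy + th, dx:dx + tw]

    # concatenate along the field dimension for the skip connection
    skip = nn.tensor_directsum([cropped, decoder_feat])
    print(skip.shape)  # torch.Size([1, 16, 12, 12])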

I trained a new "champion" network that was using this internally today for skip connections. I haven't uncovered any issues related to this yet. Thanks!

Hi @drewm1980

good to hear!

Thanks for the feedback! We really appreciate it!

Please, let us know if you encounter any issues!