dgcnz / relaxed-equivariance-dynamics

Code for "Effect of equivariance on training dynamics"

Setup Relaxed Group Convolutional Network

dgcnz opened this issue

Implementation change: Replace rot_img with torchvision.transforms.functional.rotate.

Relevant code:

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor


def rot_img(x: Tensor, theta: float) -> Tensor:
    """Rotate a batch of images by `theta` radians.

    :param x: batch of images with shape [N, C, H, W]
    :param theta: rotation angle in radians
    :returns: rotated images
    """
    # 2x3 affine matrix of a rotation by theta, repeated for every batch element.
    rot_mat = torch.tensor(
        [
            [np.cos(theta), -np.sin(theta), 0.0],
            [np.sin(theta), np.cos(theta), 0.0],
        ],
        dtype=torch.float32,
    ).repeat(x.shape[0], 1, 1)
    grid = F.affine_grid(rot_mat.to(x.device), x.size(), align_corners=False)
    return F.grid_sample(x, grid, align_corners=False)
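
For the replacement, a minimal sketch, assuming `theta` stays in radians as above (`rot_img_ttf` is a hypothetical name; TTF.rotate expects degrees, and the sign of the angle should be checked against rot_img's convention):

import math

import torchvision.transforms.functional as TTF
from torch import Tensor
from torchvision.transforms.functional import InterpolationMode


def rot_img_ttf(x: Tensor, theta: float) -> Tensor:
    # TTF.rotate accepts batched [N, C, H, W] tensors directly; bilinear
    # interpolation matches grid_sample's default mode.
    return TTF.rotate(x, math.degrees(theta), interpolation=InterpolationMode.BILINEAR)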

Example:

[images: example outputs of the two rotation implementations]

Qualitative difference mask:

[image: difference mask between rot_img and TTF.rotate outputs]

Benchmarks:

[image: rotation benchmark results]

Conclusion:

  • I will use TTF.rotate because the qualitative differences are negligible and the performance gain is consistent.

Second optimization (for the lifting layer):
Replacing:

torch.einsum("na, noa... -> oa...", relaxed_weights, filter_bank)

with this:

(relaxed_weights.view(num_filter_banks, 1, group_order, 1, 1, 1) * filter_bank).sum(0)

This makes it considerably faster. The benchmark code is in tests/models/components/gcnn/lifting/test_relaxed_rotation.py.

[image: lifting-layer benchmark results]
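
As a sanity check, a minimal sketch (with hypothetical sizes, not the actual test configuration) showing that the two expressions agree:

import torch

num_filter_banks, group_order = 2, 4
out_channels, in_channels, k = 16, 3, 5
relaxed_weights = torch.randn(num_filter_banks, group_order)
filter_bank = torch.randn(num_filter_banks, out_channels, group_order, in_channels, k, k)

# einsum version vs. broadcast multiply-and-sum version.
a = torch.einsum("na, noa... -> oa...", relaxed_weights, filter_bank)
b = (relaxed_weights.view(num_filter_banks, 1, group_order, 1, 1, 1) * filter_bank).sum(0)
assert torch.allclose(a, b, atol=1e-6)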

One odd thing about the implementation of the weighted combination for the relaxed group convolution (not the lifting layer): relaxed_weights now has shape (group_order, num_filter_banks), which is the transpose of the lifting layer's weights. I'm not sure whether this is a deliberate modeling choice or just incidental, but it requires an extra transpose. It doesn't seem to affect performance, though.

    # Broadcast multiply-and-sum; the weights are stored as
    # (group_order, num_filter_banks), so a transpose is needed first.
    def fast():
        return torch.sum(
            relaxed_weights.transpose(0, 1).view(
                num_filter_banks, 1, group_order, 1, 1, 1, 1
            )
            * filter_bank,
            dim=0,
        )

    # Same combination with the weights stored as (num_filter_banks, group_order),
    # matching the lifting layer's layout, so no transpose is needed.
    def fast_group_last():
        return torch.sum(
            relaxed_weights.view(num_filter_banks, 1, group_order, 1, 1, 1, 1)
            * filter_bank,
            dim=0,
        )

    # einsum equivalent of fast(); note that here n indexes group_order and
    # a indexes num_filter_banks.
    def einsum():
        return torch.einsum("na, aon... -> on...", relaxed_weights, filter_bank)
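
For reference, a minimal sketch (hypothetical sizes) checking that fast() and einsum() compute the same tensor when the weights are stored as (group_order, num_filter_banks):

import torch

num_filter_banks, group_order = 2, 4
out_channels, in_channels, k = 8, 3, 5
relaxed_weights = torch.randn(group_order, num_filter_banks)
filter_bank = torch.randn(
    num_filter_banks, out_channels, group_order, in_channels, group_order, k, k
)

fast_out = torch.sum(
    relaxed_weights.transpose(0, 1).view(num_filter_banks, 1, group_order, 1, 1, 1, 1)
    * filter_bank,
    dim=0,
)
einsum_out = torch.einsum("na, aon... -> on...", relaxed_weights, filter_bank)
assert torch.allclose(fast_out, einsum_out, atol=1e-6)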

The results for the group convolution are less striking on cpu, but the speedup is still noticeable at roughly 3x.

[image: group-convolution benchmark results]

Observations:

  • mps is now faster than cpu, the opposite of what we saw for the lifting layer. Maybe this is because the group convolution has an extra dimension and therefore benefits more from parallelism? This is loosely supported by the fact that the mean times of the two tests differ considerably (7 vs 20).
  • The (group_order, num_filter_banks) and (num_filter_banks, group_order) layouts perform similarly.
  • einsum is by far the slowest.

It would be nice if the sizes in these tests were more attuned to real architectures; maybe that's a task for the future.