ElisR / SE3Flux

Implementing SE(3)-equivariant neural networks with Flux.jl

SE3Flux.jl

Implementing $SE(3)$-equivariant Tensor Field Networks [1] (TFN) from scratch with Flux.jl.

This repository is an independent implementation and is not affiliated with the original paper. The official, fully featured libraries implementing this work are written in PyTorch and JAX: e3nn and e3nn-jax. Julia implementations of TFN and its derivatives are (as far as I know) non-existent. This is a first step before I rewrite these equivariant layers in the style of GeometricFlux.jl, where point clouds will be represented by bona fide graphs rather than arrays as they are here.

For those unfamiliar with TFN, I motivate this architecture at the end of this README.

Usage

A disadvantage of TFN is that one must keep track of which rotation representation any given feature vector belongs to. As such, the format for storing feature vectors is inevitably messier than a multidimensional array.

The first choice I have made is to separate positions and feature vectors into a tuple (rrs, Vss) since they are treated differently in the pipeline.

Raw Cartesian positions rs of points in the point cloud must first be transformed into pairwise separation vectors, then converted to spherical coordinates. This only needs to be done once rather than on every forward pass of the network, so it is more efficient to compute it outside the network. Because positions and features are stored as arrays, all point clouds in a batch must have the same number of points. For example, a valid array of positions could be generated by the following snippet:

num_points, batch_size = 4, 2 # example sizes, chosen arbitrarily
rs = rand(Float32, (num_points, 3, batch_size)) # random Cartesian positions
rrs_cart = rs |> pairwise_rs # size(rrs_cart) = (num_points, num_points, 3, batch_size)
rrs_sph = rrs_cart |> cart_to_sph # convert [x, y, z] -> [r, θ, ϕ]
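To make the two preprocessing steps concrete, here is a minimal plain-Julia sketch of what they might compute. The functions `pairwise_sketch` and `cart_to_sph_sketch` are hypothetical stand-ins written for illustration, not the package's own `pairwise_rs` and `cart_to_sph`. The sign of the separation vector and the angle convention (θ as the polar angle measured from the z-axis, ϕ as the azimuth) are assumptions; the conventions actually used in this repository may differ.

```julia
# Hypothetical stand-ins for illustration only; not the package's functions.

# Pairwise separations for rs of size (num_points, 3, batch_size).
# Assumed sign convention: out[a, b, :, n] = rs[a, :, n] - rs[b, :, n].
function pairwise_sketch(rs::AbstractArray{T,3}) where {T}
    N, D, B = size(rs)
    ra = reshape(rs, N, 1, D, B)
    rb = reshape(rs, 1, N, D, B)
    return ra .- rb   # size (N, N, 3, B) by broadcasting
end

# Convert [x, y, z] -> [r, θ, ϕ] along the third dimension.
# Assumed convention: θ is the polar angle from the z-axis, ϕ the azimuth.
function cart_to_sph_sketch(rrs::AbstractArray{T,4}) where {T}
    x, y, z = rrs[:, :, 1, :], rrs[:, :, 2, :], rrs[:, :, 3, :]
    r = sqrt.(x .^ 2 .+ y .^ 2 .+ z .^ 2)
    θ = atan.(sqrt.(x .^ 2 .+ y .^ 2), z)  # two-argument atan avoids dividing by r
    ϕ = atan.(y, x)
    add_axis(A) = reshape(A, size(A, 1), size(A, 2), 1, size(A, 3))
    return cat(add_axis(r), add_axis(θ), add_axis(ϕ); dims=3)
end

# Two points in one cloud: the origin and a point at height 2 on the z-axis.
demo_rs = zeros(Float32, 2, 3, 1)
demo_rs[2, 3, 1] = 2f0
demo_cart = pairwise_sketch(demo_rs)
demo_sph = cart_to_sph_sketch(demo_cart)
```

Using the two-argument `atan` means the diagonal entries, where a point is paired with itself and r = 0, come out as zeros rather than NaN.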

Features have a more complicated structure: they are stored in a tuple of vectors. There is one vector for each $\ell$, with each vector holding the channels of feature arrays V for that $\ell$. When a given $\ell < \ell_{\mathrm{max}}$ has no features, its entry is an empty vector, so that the position in the tuple still identifies $\ell$. Each V has size (num_points, batch_size, 2ℓ+1). Features compatible with the above positions could be generated by

V11 = ones(Float32, (num_points, batch_size, 3)) # Some ℓ = 1 features
V21 = ones(Float32, (num_points, batch_size, 5)) # Some ℓ = 2 features
Vss = (Vector{typeof(V11)}(undef, 0), [V11], [V21]) # Feature tuple, being careful to use type-stable empty-vector
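As a sanity check on this layout, the following sketch verifies that every feature array in a tuple has the size expected for its $\ell$. The helper `check_features` is hypothetical (it is not part of this package), and it assumes that the i-th entry of the tuple corresponds to $\ell = i - 1$, as in the example above.

```julia
# Hypothetical helper, not part of the package: check the feature layout.
# Assumes entry i of the tuple holds the channels for ℓ = i - 1.
function check_features(Vss::Tuple, num_points::Integer, batch_size::Integer)
    for (i, Vs) in enumerate(Vss)
        ℓ = i - 1
        for V in Vs
            size(V) == (num_points, batch_size, 2ℓ + 1) || return false
        end
    end
    return true
end

num_points, batch_size = 4, 2
V11 = ones(Float32, (num_points, batch_size, 3))  # one ℓ = 1 channel
V21 = ones(Float32, (num_points, batch_size, 5))  # one ℓ = 2 channel
Vss = (Vector{typeof(V11)}(undef, 0), [V11], [V21])

layout_ok = check_features(Vss, num_points, batch_size)
channels_per_ℓ = map(length, Vss)  # number of channels for ℓ = 0, 1, 2
```

Here the empty first entry simply records that there are no $\ell = 0$ channels.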

Some layers make no use of position information, so I have defined a separate gluing structure similar to Flux.jl's Parallel structure that holds multiple parallel layers but acts trivially on the first element of the (rrs, Vss) tuple.
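The idea of such a gluing structure can be sketched in a few lines of plain Julia. The type name `FeatureOnly` and its field are hypothetical and chosen only for illustration; the structure actually defined in this package may be organised differently. The sketch assumes one callable per $\ell$, applied to every channel of that $\ell$.

```julia
# Hypothetical sketch; `FeatureOnly` is not a name from this package.
struct FeatureOnly{T<:Tuple}
    layers::T  # assumed: one callable per ℓ
end

# Apply the ℓ-th layer to every channel of that ℓ; positions pass through untouched.
function (m::FeatureOnly)(input::Tuple)
    rrs, Vss = input
    new_Vss = map((layer, Vs) -> map(layer, Vs), m.layers, Vss)
    return (rrs, new_Vss)
end

# Toy usage: scale ℓ = 0 features by 2 and leave ℓ = 1 features unchanged.
rrs = zeros(Float32, 3, 3, 3, 1)
V0 = ones(Float32, 3, 1, 1)
V1 = ones(Float32, 3, 1, 3)
glue = FeatureOnly((V -> 2 .* V, identity))
out_rrs, out_Vss = glue((rrs, ([V0], [V1])))
```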

Example

Let's consider the architecture used for the shape classification example in the TFN paper, implemented in /Shape_Classification.ipynb. Here, the aim is to classify a bunch of Tetris-like blocks, of which there are 8 distinct types. The intrinsic rotational invariance of the network means that after being shown just one example of each block, the classifier can be equally confident in recognising the blocks even when they are arbitrarily orientated. Below is a cartoon of the invariance of the output of the entire pipeline with respect to rotation of the input, a special case of equivariance with a trivial identity representation.

A diagram of the network architecture is shown below. We keep track of the rotation representations of the feature vectors with $\ell$ (see README Appendix or TFN paper itself). Applying the convolution is equivalent to taking the tensor product with a filter that has its own rotation representation $\ell_f$. Augmenting a representation $\ell_i$ with a filter $\ell_f$ produces a sum of representations $\ell_o$ in the range $| \ell_i - \ell_f | \leq \ell_o \leq \ell_i + \ell_f$ (resembling a triangle inequality between vectors). As detailed in the paper, special care is taken to ensure that the tensor product transforms appropriately under rotation, by weighting different terms in this sum by so-called Clebsch-Gordan coefficients. In the network below, we choose to discard any $\ell > 1$ terms to save resources.
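The selection rule for output representations is easy to enumerate. The helper `allowed_outputs` below is a hypothetical illustration (not a function from this package) that lists every $\ell_o$ satisfying the inequality, optionally truncated at a maximum $\ell$ as done in the network here.

```julia
# Hypothetical helper: list ℓo with |ℓi - ℓf| ≤ ℓo ≤ ℓi + ℓf,
# optionally discarding anything above ℓmax.
function allowed_outputs(ℓi::Integer, ℓf::Integer; ℓmax::Integer=typemax(Int))
    return [ℓo for ℓo in abs(ℓi - ℓf):(ℓi + ℓf) if ℓo <= ℓmax]
end

full = allowed_outputs(1, 1)               # ℓ = 0, 1 and 2 are all allowed
truncated = allowed_outputs(1, 1; ℓmax=1)  # the ℓ = 2 term is discarded
scalar_in = allowed_outputs(0, 1)          # a scalar input takes on the filter's ℓ
```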

In this repository, this network is implemented with the code block below. It should hopefully be clear which component corresponds to which. The most important (and complicated) component is the $SE(3)$-equivariant convolution layer E3ConvLayer. The remaining components are non-linear and self-interaction layers (interfaced through NLWrapper and SIWrapper), used to scale feature vectors pointwise and mix channels in an equivariant way. The number of output and input channels for every layer must be correctly specified at the time of construction. (Above, the number of channels with a particular representation is given by $[n]$.) The final three steps are global pooling, dense and softmax layers used for classification of the eight shapes.

# Define centers of radial basis functions
centers = range(0f0, 3.5f0; length=4) |> collect

# `Chain` is the `Flux.jl` constructor for sequential layers
classifier = Chain(
                SIWrapper([1 => 4]),
                E3ConvLayer([4], [[(0, 0) => [0], (0, 1) => [1]]], centers),
                SIWrapper([4 => 4, 4 => 4]),
                NLWrapper([4, 4]),
                E3ConvLayer([4, 4], [[(0, 0) => [0], (0, 1) => [1]],
                                     [(1, 0) => [1], (1, 1) => [0, 1]]], centers),
                SIWrapper([8 => 4, 12 => 4]),
                NLWrapper([4, 4]),
                E3ConvLayer([4, 4], [[(0, 0) => [0]], [(1, 1) => [0]]], centers),
                SIWrapper([8 => 4]),
                NLWrapper([4]),
                PLayer(),
                Dense(4 => 8)
);
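The channel counts passed to each SIWrapper can be checked by hand. The sketch below assumes my reading of the code above: each entry `(ℓi, ℓf) => [ℓo, ...]` in an E3ConvLayer specification sends every input channel of degree ℓi to one output channel for each listed ℓo. The helper `output_channels` is hypothetical and only does the bookkeeping; it does not construct any layers.

```julia
# Hypothetical bookkeeping helper, not part of the package.
# in_channels[i] is the number of input channels with ℓ = i - 1;
# pairings[i] lists (ℓi, ℓf) => [ℓo, ...] for those channels.
function output_channels(in_channels, pairings)
    counts = Dict{Int,Int}()
    for (n, specs) in zip(in_channels, pairings)
        for (_, ℓos) in specs
            for ℓo in ℓos
                counts[ℓo] = get(counts, ℓo, 0) + n
            end
        end
    end
    return [get(counts, ℓ, 0) for ℓ in 0:maximum(keys(counts))]
end

first_conv = output_channels([4], [[(0, 0) => [0], (0, 1) => [1]]])
second_conv = output_channels([4, 4], [[(0, 0) => [0], (0, 1) => [1]],
                                       [(1, 0) => [1], (1, 1) => [0, 1]]])
third_conv = output_channels([4, 4], [[(0, 0) => [0]], [(1, 1) => [0]]])
```

Under that reading, the second convolution produces 8 channels with $\ell = 0$ and 12 with $\ell = 1$, which matches the `SIWrapper([8 => 4, 12 => 4])` that follows it.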

Appendix

The Need for Rotation Equivariant Networks

Rotational and translational symmetry are common in nature. It is therefore useful to have neural networks that can exploit this simplified structure of many physical problems, without needing to carry around redundant information. That is, it would be nice not to have to relearn the same thing over and over again in different coordinate systems. (For example, a naïve approach to approximating symmetric functions is training networks on augmented data, namely data that has been bulked up with transformed copies. However, in that case no promises can be made about equivariance outside the training dataset.) This is solved by making layers equivariant with respect to a symmetry group by construction (reviewed in detail in the Geometric Deep Learning textbook), meaning that the output transforms appropriately when the input is transformed. There are many ways to design such neural network architectures, often with a trade-off between expressivity and computational cost.
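For a map f on vectors, equivariance under rotation means f(R x) = R f(x) for every rotation R. As a small illustration (toy functions chosen for this sketch, not layers from this package), the check below compares a map that is equivariant by construction with an elementwise nonlinearity that is not.

```julia
using LinearAlgebra

# Rotation by angle α about the z-axis.
rotz(α) = [cos(α) -sin(α) 0; sin(α) cos(α) 0; 0 0 1]

# Toy maps for illustration only.
f(x) = norm(x) .* x  # scales x by a rotation-invariant factor: equivariant
g(x) = x .^ 2        # elementwise square: depends on the coordinate axes

R = rotz(0.7)
x = [1.0, 2.0, 3.0]

f_is_equivariant = f(R * x) ≈ R * f(x)
g_is_equivariant = g(R * x) ≈ R * g(x)
```

The first map commutes with the rotation because the norm is unchanged by R; the second does not, which is why naive pointwise nonlinearities cannot be applied directly to $\ell > 0$ features.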

The Tensor Field Network (TFN) is built from matrix representations of rotations and acts on point clouds of features. Translation symmetry is trivially upheld by only ever considering the relative displacement between points. The advantage of TFN is that features can be complex physical quantities (and not just scalars), but this expressivity comes with additional cost compared to some alternatives, especially because it essentially acts on an "all-to-all" graph. Some follow-up works such as SE(3)-Transformers are more performant.
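The claim about translation symmetry can be checked directly: shifting every point by the same vector leaves all relative displacements unchanged. The helper `displacements` is a hypothetical illustration operating on a single point cloud stored as a (num_points, 3) matrix, which is a simpler layout than the batched arrays used elsewhere in this README.

```julia
# Hypothetical helper: all pairwise displacements r_a - r_b for one point cloud.
displacements(r) = [r[a, :] .- r[b, :] for a in axes(r, 1), b in axes(r, 1)]

points = [0.0 0.0 0.0;
          1.0 0.0 0.0;
          0.0 2.0 0.0]
shift = [10.0, -3.0, 5.0]
shifted_points = points .+ shift'  # translate every point by the same vector

unchanged = displacements(points) == displacements(shifted_points)
```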

Basis of Tensor Field Networks

A neural network acts on feature vectors, which are sometimes physical quantities. These quantities can transform differently under rotation depending on their rotation representation, indexed by the non-negative integer $\ell$. This distinction is obvious if we compare a scalar (such as an object's mass) to a vector (such as its velocity, where "vector" now refers to a geometric quantity, and not just a one-dimensional array of numbers): one does not change at all under rotation, while the other changes direction. In representation theory (the study of maps from abstract symmetry groups to linear transformations, i.e., matrices), we say that the scalar transforms under the $\ell = 0$ irreducible representation (irrep), whereas the vector transforms under the $\ell = 1$ irrep. (In quantum physics, these matrices act on the wavefunction of a particle, in which case $\ell$ corresponds to its angular momentum.) The matrices in an irreducible representation have dimension $(2\ell + 1) \times (2\ell + 1)$, so the transformation of quantities with higher $\ell$ is more costly to compute.
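The scalar-versus-vector distinction can be illustrated numerically. In the sketch below the $\ell = 1$ representation is taken to be the ordinary 3×3 rotation matrix acting on Cartesian components; this is an assumption made for simplicity, and the basis and ordering used in this repository may differ. The names `irrep_dim`, `mass` and `velocity` are illustrative only.

```julia
using LinearAlgebra

# Dimension of the irrep labelled by ℓ.
irrep_dim(ℓ::Integer) = 2ℓ + 1

# Rotation by angle α about the z-axis, used here as the ℓ = 1 matrix
# in a Cartesian basis (an assumption of this sketch).
rotz(α) = [cos(α) -sin(α) 0; sin(α) cos(α) 0; 0 0 1]
R = rotz(π / 3)

mass = 2.5                  # ℓ = 0 quantity
velocity = [1.0, 0.0, 0.0]  # ℓ = 1 quantity

rotated_mass = 1 * mass          # the ℓ = 0 matrix is the 1×1 identity
rotated_velocity = R * velocity  # direction changes, length does not
```

Since each transformation is a matrix-vector product with a $(2\ell + 1) \times (2\ell + 1)$ matrix, its cost grows roughly quadratically with $\ell$.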

The README in the test folder outlines the conventions used in this repository for rotation representations of $SO(3)$ (the continuous group generated by composing infinitesimal rotations).

Footnotes

  1. Despite its name, this is completely unrelated to the "Tensor Networks" used in condensed matter, for which one would use ITensors.jl or a related package.


Languages

Julia 51.4%, Jupyter Notebook 45.0%, Dockerfile 3.6%