huhlim / SE3Transformer

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

SE3Transformer

It is a fork of NVIDIA's SE(3)-Transformer implementation. I made some minor modifications, including

  • removal of torch.cuda.nvtx.nvtx_range
  • addition of the nonlinearity argument to NormSE3, SE3Transformer, and so on.
  • addition of some basic network implementations using SE(3)-Transformer.

Installation

for CPU only

pip install git+http://github.com/huhlim/SE3Transformer

for CUDA (GPU) usage

  1. Install DGL library with CUDA support
# This is an example with cudatoolkit=11.3.
# Set a proper cudatoolkit version that is compatible with your CUDA driver and DGL library.
conda install dgl -c dglteam/label/cu113
# or
pip install dgl -f https://data.dgl.ai/wheels/cu113/repo.html
  2. Install this package
pip install git+http://github.com/huhlim/SE3Transformer

Code Snippets

  • se3_transformer.LinearModule: LinearSE3 and NormSE3
    class LinearModule(nn.Module):
        """
        Stack of LinearSE3 layers, each optionally preceded by a NormSE3.

        Operates only within a node, so it basically applies nn.Linear to every node.
        """

        def __init__(
            self,
            fiber_in: Fiber,
            fiber_hidden: Fiber,
            fiber_out: Fiber,
            n_layer: Optional[int] = 2,
            use_norm: Optional[bool] = True,
            nonlinearity: Optional[nn.Module] = nn.ReLU(),
            **kwargs,
        ):
            """
            arguments:
            - fiber_in: Fiber, numbers of input features
            - fiber_hidden: Fiber, numbers of intermediate features
              (unused when n_layer < 2, where a single fiber_in -> fiber_out layer is built)
            - fiber_out: Fiber, numbers of output features
            - n_layer: int, the number of linear layers; values below 2 build one layer
            - use_norm: bool, if True, NormSE3 will be inserted before each LinearSE3 layer
            - nonlinearity: activation function for NormSE3 (the same instance is
              handed to every NormSE3 in the stack)
            - **kwargs: accepted and ignored, so a larger config dict can be splatted in
            """
            super().__init__()
            # Feature layout at each layer boundary:
            #   n_layer >= 2 : in -> hidden -> ... -> hidden -> out
            #   n_layer <  2 : in -> out
            # The single-layer case previously referenced an undefined name
            # (`fiber_init`) and raised NameError; it now uses `fiber_in`.
            n_hidden = max(n_layer, 1) - 1
            fibers = [fiber_in] + [fiber_hidden] * n_hidden + [fiber_out]
            #
            linear_module = []
            for fiber_from, fiber_to in zip(fibers[:-1], fibers[1:]):
                if use_norm:
                    linear_module.append(
                        NormSE3(Fiber(fiber_from), nonlinearity=nonlinearity)
                    )
                linear_module.append(LinearSE3(Fiber(fiber_from), Fiber(fiber_to)))
            #
            self.linear_module = nn.Sequential(*linear_module)

        def forward(self, x):
            """Run x through the NormSE3/LinearSE3 stack and return the result."""
            return self.linear_module(x)
  • se3_transformer.InteractionModule: A wrapper of SE3Transformer
    class InteractionModule(nn.Module):
        """
        Thin wrapper that builds an SE3Transformer block and delegates to it.
        """

        def __init__(
            self,
            fiber_in: Fiber,
            fiber_hidden: Fiber,
            fiber_out: Fiber,
            fiber_edge: Optional[Fiber] = Fiber({}),
            n_layer: Optional[int] = 2,
            n_head: Optional[int] = 2,
            use_norm: Optional[bool] = True,
            use_layer_norm: Optional[bool] = True,
            nonlinearity: Optional[nn.Module] = nn.ReLU(),
            low_memory: Optional[bool] = True,
            **kwargs,
        ):
            """
            arguments:
            - fiber_in: Fiber, numbers of input features
            - fiber_hidden: Fiber, numbers of intermediate features
            - fiber_out: Fiber, numbers of output features
            - fiber_edge: Fiber, numbers of edge features
            - n_layer: int, forwarded to SE3Transformer as num_layers
            - n_head: int, the number of attention heads (forwarded as num_heads)
            - use_norm: bool, forwarded to SE3Transformer as norm
            - use_layer_norm: bool, if True, LayerNorm will be used between MLP (radial)
            - nonlinearity: activation function for NormSE3
            - low_memory: bool, if True, gradient checkpoint will be activated for ConvSE3
            - **kwargs: accepted and ignored

            channels_div is always passed as 2.
            """
            super().__init__()
            # Translate this wrapper's argument names to the SE3Transformer ones.
            transformer_options = dict(
                num_layers=n_layer,
                fiber_in=fiber_in,
                fiber_hidden=fiber_hidden,
                fiber_out=fiber_out,
                num_heads=n_head,
                channels_div=2,
                fiber_edge=fiber_edge,
                norm=use_norm,
                use_layer_norm=use_layer_norm,
                nonlinearity=nonlinearity,
                low_memory=low_memory,
            )
            self.graph_module = SE3Transformer(**transformer_options)

        def forward(self, batch: dgl.DGLGraph, node_feats: torch.Tensor, edge_feats: torch.Tensor):
            """Apply the wrapped SE3Transformer to the graph and its node/edge features."""
            return self.graph_module(batch, node_feats=node_feats, edge_feats=edge_feats)

Usage

  • LinearModule + InteractionModule
    #!/usr/bin/env python
    import torch
    import torch.nn as nn
    import dgl
    import sys
    from se3_transformer import Fiber, LinearModule, InteractionModule
    from se3_transformer.utils import degree_to_dim
    class Model(nn.Module):
        """Example network: a LinearModule followed by an InteractionModule."""

        def __init__(self, config):
            """Build both sub-modules from config["linear"] and config["interact"]."""
            super().__init__()
            self.linear = LinearModule(**config["linear"])
            self.interact = InteractionModule(**config["interact"])

        def forward(self, batch: dgl.DGLGraph):
            """Read degree-0 and degree-1 node features from the graph and run both modules."""
            # Features are stored on the graph as "node_feat_<degree>" and passed
            # on keyed by the degree as a string.
            node_feats = {}
            for degree in (0, 1):
                node_feats[str(degree)] = batch.ndata[f"node_feat_{degree}"]
            hidden = self.linear(node_feats)
            # This example uses no edge features, so an empty dict is passed.
            return self.interact(batch, node_feats=hidden, edge_feats={})
    def create_random_example(n_point, fiber_in):
        """
        Build a fully connected DGL graph over n_point random points.

        Every ordered pair (i, j), including i == j, becomes an edge. The graph
        carries random positions in ndata["pos"], one random feature tensor per
        entry of fiber_in in ndata["node_feat_<degree>"], and the per-edge
        displacement (destination minus source) in edata["rel_pos"].
        """
        # Source index varies slowest, destination index fastest.
        src_ids = [i for i in range(n_point) for _ in range(n_point)]
        dst_ids = [j for _ in range(n_point) for j in range(n_point)]
        graph = dgl.graph((torch.as_tensor(src_ids), torch.as_tensor(dst_ids)))

        coords = torch.randn((n_point, 3))
        graph.ndata["pos"] = coords[:, None, :]

        for fiber in fiber_in:
            n_component = degree_to_dim(fiber.degree)
            feat_shape = (n_point, fiber.channels, n_component)
            graph.ndata[f"node_feat_{fiber.degree}"] = torch.randn(feat_shape)

        edge_src, edge_dst = graph.edges()
        graph.edata["rel_pos"] = coords[edge_dst] - coords[edge_src]
        return graph
    def main():
        """Build the example model, run it on one random graph, and print the outputs."""
        # LinearModule settings: 8 scalars + 4 vectors in, 16 scalars + 8 vectors out.
        linear_config = {
            "fiber_in": Fiber([(0, 8), (1, 4)]),
            "fiber_hidden": Fiber([(0, 16), (1, 8)]),
            "fiber_out": Fiber([(0, 16), (1, 8)]),
            "n_layer": 2,
            "use_norm": True,
            "nonlinearity": nn.ReLU(),
        }
        # InteractionModule settings: takes the LinearModule output, no edge features.
        interact_config = {
            "fiber_in": Fiber([(0, 16), (1, 8)]),
            "fiber_hidden": Fiber([(0, 16), (1, 8)]),
            "fiber_out": Fiber([(0, 2), (1, 1)]),
            "fiber_edge": Fiber({}),
            "n_layer": 2,
            "n_head": 2,
            "use_norm": True,
            "use_layer_norm": True,
            "nonlinearity": nn.ReLU(),
            "low_memory": True,
        }
        config = {"linear": linear_config, "interact": interact_config}

        model = Model(config)

        batch = create_random_example(n_point=10, fiber_in=linear_config["fiber_in"])
        out = model(batch)
        print(out)
        print(out["0"].size())  # = (n_point, 2, 1)
        print(out["1"].size())  # = (n_point, 1, 3)


    if __name__ == "__main__":
        main()
    In this example,
    • A fully connected graph is created with random coordinates
    • Input features: 8 scalars and 4 vectors
    • Output features: 2 scalars and 1 vector
    • LinearModule: two LinearSE3 with NormSE3, returns 16 scalars and 8 vectors.
    • InteractionModule: two layers of attention blocks with two heads, takes the output of the LinearModule as node_feats and no edge_feats.

About

License:MIT License


Languages

Language:Python 100.0%