Issues with my wrapper code
StellaAthena opened this issue · comments
Stella Biderman commented
I wrote some wrapper code to turn this layer into a full transformer and I can't seem to figure out what is going wrong. The following works:
import torch
from torch import nn, einsum
import x_transformers
from point_transformer_pytorch import PointTransformerLayer
layer = PointTransformerLayer(
dim = 7,
pos_mlp_hidden_dim = 64,
attn_mlp_hidden_mult = 4,
num_neighbors = 16 # only the 16 nearest neighbors would be attended to for each point
)
feats = torch.randn(1, 5, 7)
pos = torch.randn(1, 5, 3)
mask = torch.ones(1, 5).bool()
y = layer(feats, pos, mask = mask)
However this doesn't work
import torch
from torch import nn, einsum
import x_transformers
from point_transformer_pytorch import PointTransformerLayer
class PointTransformer(nn.Module):
def __init__(self, feats, mask, neighbors = 16, layers=5, dimension=5):
super().__init__()
self.feats = feats
self.mask = mask
self.neighbors = neighbors
self.layers = []
for _ in range(layers):
self.layers.append(PointTransformerLayer(
dim = dimension,
pos_mlp_hidden_dim = 64,
attn_mlp_hidden_mult = 4,
num_neighbors = self.neighbors
))
def forward(self, pos):
curr_pos = pos
for layer in self.layers:
print(curr_pos)
curr_pos = layer(self.feats, pos, self.mask)
print("----")
return curr_pos
model = PointTransformer(feats, mask)
model(pos)
The error I'm getting is mat1 and mat2 shapes cannot be multiplied (5x7 and 5x15)
Stella Biderman commented
NVM I figured it out