Issue with torch.compile
scopello opened this issue
Hi @lucidrains,
I am trying to use `torch.compile()` with a model that wraps two x-transformers Encoders. When I run the following minimal example:
```python
import torch
import torch.nn as nn

from x_transformers import Encoder


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder_1 = Encoder(dim=32, depth=2, heads=2)
        self.encoder_2 = Encoder(dim=32, depth=2, heads=2)

    def forward(self, x_1, x_2):
        out_1 = self.encoder_1(x_1)
        out_2 = self.encoder_2(x_2)
        return torch.cat([out_1, out_2], 1)


model = MyModel().cuda()

seq_len_1 = 8
seq_len_2 = 16
x_1 = torch.randn([1, seq_len_1, 32]).cuda()
x_2 = torch.randn([1, seq_len_2, 32]).cuda()

# Compile the model.
model = torch.compile(model)
out = model(x_1, x_2)
```
I get the error:

```
TorchRuntimeError: Failed running call_function (*(FakeTensor(..., device='cuda:0', size=(1, s0, 128), grad_fn=), 'b n (h d) -> b h n d'), **{'h': 2}): unhashable type: non-singleton SymInt
```
Which comes from:
https://github.com/lucidrains/x-transformers/blob/2a0ec67fbdad18d2bd5f8bf3d9bc20e705a58a6b/x_transformers/x_transformers.py#L801
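If it is useful, here is a minimal sketch of what I assume is the failing pattern, isolated from x-transformers: the same `rearrange` call compiled while the sequence length is treated as symbolic. The `mark_dynamic` call is just my guess at how to force the symbolic-shape path outside the full model:

```python
import torch
from einops import rearrange

# Sketch only: the head-splitting rearrange used inside the attention layer.
def split_heads(x):
    return rearrange(x, 'b n (h d) -> b h n d', h=2)

compiled_split_heads = torch.compile(split_heads)

x = torch.randn(1, 16, 32).cuda()
torch._dynamo.mark_dynamic(x, 1)   # treat the sequence dim as dynamic (assumption)
out = compiled_split_heads(x)      # presumably hits the same SymInt path
```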
Surprisingly, the model compiles successfully if I set `seq_len_2 = seq_len_1`, but I don't know why.
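One thing I have been meaning to try (just a sketch, not verified) is disabling dynamic shapes so the compiler specializes on the concrete lengths instead of symbolic ones:

```python
# Assumption: with dynamic shapes off, torch.compile specializes on the
# concrete seq_len_1 / seq_len_2 and never sees a SymInt in the rearrange.
model = torch.compile(model, dynamic=False)
out = model(x_1, x_2)
```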
I am using einops 0.7.0rc1 and PyTorch 2.1.0.
Thanks!
Seems to work fine on A100, but not H100.
ah nice, yea that seems like an einops / pytorch specific error, but not entirely sure
what is your use-case btw? that's a really interesting network
oh, are you doing two towers architecture?
Thanks! This is for a model that requires encoders for 2 different modalities. Btw, would you expect any significant speedup from using torch.compile if flash attention is enabled?
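For reference, this is the rough timing harness I plan to use to check; the iteration counts are placeholders and it reuses `model`, `x_1`, `x_2` from the example above (with `model` not yet compiled):

```python
import time
import torch

def benchmark(fn, *args, warmup=10, iters=50):
    # crude CUDA timing helper; assumes the inputs already live on the GPU
    for _ in range(warmup):
        fn(*args)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters

eager_ms = benchmark(model, x_1, x_2) * 1e3
compiled_ms = benchmark(torch.compile(model), x_1, x_2) * 1e3
print(f"eager: {eager_ms:.2f} ms  compiled: {compiled_ms:.2f} ms")
```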