[BUG] VisionMambaBlock example
MelihDarcanxyz opened this issue · comments
Describe the bug
The VisionMambaBlock example doesn't work. I'm getting:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
test.ipynb Cell 2 line 3
1 block = VisionMambaBlock(dim=256, heads=8, dt_rank=32, dim_inner=512, d_state=256)
2 x = torch.randn(1, 32, 256)
----> 3 out = block(x)
4 out.shape
File .venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1510 else:
-> 1511 return self._call_impl(*args, **kwargs)
File .venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
1515 # If we don't have any hooks, we want to skip the rest of the logic in
1516 # this function, and just call forward.
1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1518 or _global_backward_pre_hooks or _global_backward_hooks
1519 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520 return forward_call(*args, **kwargs)
1522 try:
1523 result = None
File .venv/lib/python3.10/site-packages/zeta/nn/modules/vision_mamba.py:71, in VisionMambaBlock.forward(self, x)
69 forward_conv_output = self.forward_conv1d(x1_rearranged)
70 forward_conv_output = rearrange(forward_conv_output, "b d s -> b s d")
---> 71 x1_ssm = self.ssm(forward_conv_output)
73 # backward conv x2
74 x2_rearranged = rearrange(x1, "b s d -> b d s")
File .venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1510 else:
-> 1511 return self._call_impl(*args, **kwargs)
File .venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
1515 # If we don't have any hooks, we want to skip the rest of the logic in
1516 # this function, and just call forward.
1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1518 or _global_backward_pre_hooks or _global_backward_hooks
1519 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520 return forward_call(*args, **kwargs)
1522 try:
1523 result = None
File .venv/lib/python3.10/site-packages/zeta/nn/modules/ssm.py:147, in SSM.forward(self, x, pscan)
145 # Assuming selective_scan and selective_scan_seq are defined functions
146 if pscan:
--> 147 y = selective_scan(x, delta, A, B, C, D)
148 else:
149 y = selective_scan_seq(x, delta, A, B, C, D)
File .venv/lib/python3.10/site-packages/zeta/nn/modules/ssm.py:29, in selective_scan(x, delta, A, B, C, D)
26 deltaA = torch.exp(delta.unsqueeze(-1) * A) # (B, L, ED, N)
27 deltaB = delta.unsqueeze(-1) * B.unsqueeze(2) # (B, L, ED, N)
---> 29 BX = deltaB * x.unsqueeze(-1) # (B, L, ED, N)
31 hs = pscan(deltaA, BX)
33 y = (
34 hs @ C.unsqueeze(-1)
35 ).squeeze() # (B, L, ED, N) @ (B, L, N, 1) -> (B, L, ED, 1)
RuntimeError: The size of tensor a (512) must match the size of tensor b (256) at non-singleton dimension 2
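For reference, the failing line multiplies deltaB (whose dimension 2 appears to pick up dim_inner = 512) with x.unsqueeze(-1), where x still has dim = 256 channels. A minimal sketch of the same broadcast failure, with the shapes inferred from the traceback rather than taken from the zeta source:

import torch

batch, seq_len, dim, dim_inner, d_state = 1, 32, 256, 512, 256

# deltaB ends up with dim_inner channels at dimension 2 (assumption inferred from the traceback)
deltaB = torch.randn(batch, seq_len, dim_inner, d_state)
# ...but the conv output fed into the SSM still has `dim` channels
x = torch.randn(batch, seq_len, dim)

BX = deltaB * x.unsqueeze(-1)  # raises: size 512 must match size 256 at dimension 2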
To Reproduce
Steps to reproduce the behavior:
- Run the example:
import torch
from zeta.nn.modules.vision_mamba import VisionMambaBlock

block = VisionMambaBlock(dim=256, heads=8, dt_rank=32, dim_inner=512, d_state=256)
x = torch.randn(1, 32, 256)
out = block(x)
out.shape
Expected behavior
torch.Size([1, 32, 256])
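Until the implementation is updated, a possible workaround is to construct the block with dim_inner equal to dim so the SSM's internal projection and the conv output agree on channel width. This is only a sketch based on the shapes in the traceback, not a confirmed fix:

import torch
from zeta.nn.modules.vision_mamba import VisionMambaBlock

# Assumption: the mismatch is between dim_inner (512) inside the SSM and the
# conv output width (dim = 256); keeping them equal avoids the broadcast error.
block = VisionMambaBlock(dim=256, heads=8, dt_rank=32, dim_inner=256, d_state=256)
x = torch.randn(1, 32, 256)
out = block(x)
print(out.shape)  # torch.Size([1, 32, 256]) if the workaround holds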
@MelihDarcanxyz the updated model is here, and it is functional: https://github.com/kyegomez/VisionMamba
I need to update the implementation here with that new implementation.
Hi @kyegomez , I saw that, but there was a parameter named num_classes,
so I assumed it was only suitable for classification, while this implementation makes no such assumption. I'm looking for something more general. Is it a general-purpose block? Sorry, I just couldn't tell, so I'm asking.
EDIT: I tried it, but it didn't work either. I got the same problem, reported in kyegomez/VisionMamba#4