kyegomez / zeta

Build high-performance AI models with modular building blocks

Home Page: https://zeta.apac.ai


[BUG] VisionMambaBlock example

MelihDarcanxyz opened this issue

Describe the bug
The VisionMambaBlock example doesn't work. I'm getting:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
test.ipynb Cell 2 line 3
      1 block = VisionMambaBlock(dim=256, heads=8, dt_rank=32, dim_inner=512, d_state=256)
      2 x = torch.randn(1, 32, 256)
----> 3 out = block(x)
      4 out.shape

File .venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File .venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File .venv/lib/python3.10/site-packages/zeta/nn/modules/vision_mamba.py:71, in VisionMambaBlock.forward(self, x)
     69 forward_conv_output = self.forward_conv1d(x1_rearranged)
     70 forward_conv_output = rearrange(forward_conv_output, "b d s -> b s d")
---> 71 x1_ssm = self.ssm(forward_conv_output)
     73 # backward conv x2
     74 x2_rearranged = rearrange(x1, "b s d -> b d s")

File .venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File .venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File .venv/lib/python3.10/site-packages/zeta/nn/modules/ssm.py:147, in SSM.forward(self, x, pscan)
    145 # Assuming selective_scan and selective_scan_seq are defined functions
    146 if pscan:
--> 147     y = selective_scan(x, delta, A, B, C, D)
    148 else:
    149     y = selective_scan_seq(x, delta, A, B, C, D)

File .venv/lib/python3.10/site-packages/zeta/nn/modules/ssm.py:29, in selective_scan(x, delta, A, B, C, D)
     26 deltaA = torch.exp(delta.unsqueeze(-1) * A)  # (B, L, ED, N)
     27 deltaB = delta.unsqueeze(-1) * B.unsqueeze(2)  # (B, L, ED, N)
---> 29 BX = deltaB * x.unsqueeze(-1)  # (B, L, ED, N)
     31 hs = pscan(deltaA, BX)
     33 y = (
     34     hs @ C.unsqueeze(-1)
     35 ).squeeze()  # (B, L, ED, N) @ (B, L, N, 1) -> (B, L, ED, 1)

RuntimeError: The size of tensor a (512) must match the size of tensor b (256) at non-singleton dimension 2
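The mismatch can be reproduced in plain PyTorch. Reading the shape comments in ssm.py, `deltaB` ends up with `dim_inner=512` channels at dimension 2 (presumably ED), while the conv output fed into the SSM still has `dim=256` channels, so the broadcast fails. The shapes below are a guess read off the traceback, not zeta's actual code:

```python
import torch

# Assumed shapes: ED = dim_inner = 512, N = d_state = 256 (from the example's
# hyperparameters and the (B, L, ED, N) comments in ssm.py).
Bsz, L, ED, N = 1, 32, 512, 256
delta = torch.randn(Bsz, L, ED)
A = torch.randn(ED, N)
B = torch.randn(Bsz, L, N)

deltaA = torch.exp(delta.unsqueeze(-1) * A)    # (1, 32, 512, 256)
deltaB = delta.unsqueeze(-1) * B.unsqueeze(2)  # (1, 32, 512, 256)

x = torch.randn(Bsz, L, 256)  # conv output still has dim=256 channels
err = None
try:
    BX = deltaB * x.unsqueeze(-1)  # (1, 32, 512, 256) * (1, 32, 256, 1) -> fails
except RuntimeError as e:
    err = e
print(err)  # size of tensor a (512) must match ... at non-singleton dimension 2
```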

To Reproduce
Steps to reproduce the behavior:

  1. Run the example:
import torch
from zeta.nn.modules.vision_mamba import VisionMambaBlock  # module path per the traceback

block = VisionMambaBlock(dim=256, heads=8, dt_rank=32, dim_inner=512, d_state=256)
x = torch.randn(1, 32, 256)
out = block(x)
out.shape

Expected behavior

torch.Size([1, 32, 256])
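For reference, the expected output shape only comes out when the SSM's inner channel dimension ED matches the channel dimension of its input. A minimal sequential selective-scan sketch in plain PyTorch illustrates the consistent shapes; the function and argument layout follow the shape comments in ssm.py but are illustrative, not zeta's actual implementation:

```python
import torch

def selective_scan_seq(x, delta, A, B, C, D):
    # x, delta: (B, L, ED); A: (ED, N); B, C: (B, L, N); D: (ED,)
    Bsz, L, ED = x.shape
    N = A.shape[1]
    deltaA = torch.exp(delta.unsqueeze(-1) * A)    # (B, L, ED, N)
    deltaB = delta.unsqueeze(-1) * B.unsqueeze(2)  # (B, L, ED, N)
    BX = deltaB * x.unsqueeze(-1)                  # broadcasts only if ED matches
    h = torch.zeros(Bsz, ED, N)
    ys = []
    for t in range(L):
        h = deltaA[:, t] * h + BX[:, t]                  # (B, ED, N) state update
        ys.append((h @ C[:, t].unsqueeze(-1)).squeeze(-1))  # (B, ED) readout
    y = torch.stack(ys, dim=1)                     # (B, L, ED)
    return y + x * D                               # skip connection

ED, N = 256, 256  # ED equal to the input's channel dim, unlike the failing example
x = torch.randn(1, 32, ED)
delta = torch.rand(1, 32, ED)
A = -torch.rand(ED, N)
B = torch.randn(1, 32, N)
C = torch.randn(1, 32, N)
D = torch.randn(ED)
out = selective_scan_seq(x, delta, A, B, C, D)
print(out.shape)  # torch.Size([1, 32, 256])
```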


Hello there, thank you for opening an issue! 🙏🏻 The team has been notified and will get back to you as soon as possible.

@MelihDarcanxyz the updated model is here and is functional: https://github.com/kyegomez/VisionMamba

I need to update the implementation here with the new one.

Hi @kyegomez, I saw that, but it has a num_classes parameter, so I assumed it was only suitable for classification, whereas this implementation makes no such assumption. I'm looking for something more general. Is it a general-purpose block? Sorry, I just couldn't tell and am trying to understand by asking.

EDIT: Tried it; it didn't work either. I got the same problem reported in kyegomez/VisionMamba#4.