alxndrTL / mamba.py

A simple and efficient Mamba implementation in pure PyTorch and MLX.


MLX memory usage at inference

alxndrTL opened this issue · comments

The 1.4B model takes 10-11GB of RAM at inference. (own test, M2 Pro 16GB)
The 2.8B model takes around 50GB at inference. (https://twitter.com/awnihannun/status/1749515431336112275)

This is not due to loading the model from HF (same memory footprint if model initialized with random weights).
Nor is it due to the ssm_step.

However, turning off the convolution at inference reduces the memory footprint by about 3GB for the 1.4B model (from 10GB to around 7GB). It also greatly speeds up inference (but of course, the forward pass is then no longer correct).

Files concerned:

  • mamba_mlx.py (step functions)
  • misc.py

The depthwise conv implemented in misc.py seems to be part of the problem.
As noted in the file, the PyTorch version uses groups=channels (true depthwise), while the MLX depthwise conv in misc.py uses groups=1 with some weights set to 0 (the only workaround found).
This results in a filter of shape (d_model, 4, d_model), against (d_model, 4) for the "true" depthwise conv.
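
To put numbers on the two filter shapes, here is a rough back-of-the-envelope sketch in plain Python. It assumes float32 weights (4 bytes per parameter), a kernel size of 4, and d_model = 2560 as the channel count; the helper name `conv1d_weight_params` is made up for illustration and is not part of misc.py. It only counts the weight tensor of a single conv layer, not activations or any intermediate buffers.

```python
def conv1d_weight_params(channels, kernel_size, depthwise):
    """Number of weights in a 1D conv filter (hypothetical helper).

    depthwise=True  -> one kernel per channel: (channels, kernel_size)
    depthwise=False -> groups=1 filter: (channels, kernel_size, channels)
    """
    if depthwise:
        return channels * kernel_size
    return channels * kernel_size * channels


D_MODEL = 2560        # channel count used in this issue
KERNEL_SIZE = 4
BYTES_PER_PARAM = 4   # assumption: float32 weights

dense_params = conv1d_weight_params(D_MODEL, KERNEL_SIZE, depthwise=False)
depthwise_params = conv1d_weight_params(D_MODEL, KERNEL_SIZE, depthwise=True)

dense_mib = dense_params * BYTES_PER_PARAM / 2**20
depthwise_kib = depthwise_params * BYTES_PER_PARAM / 2**10
ratio = dense_params // depthwise_params

print(f"groups=1 filter:  {dense_params} params, {dense_mib} MiB")
print(f"depthwise filter: {depthwise_params} params, {depthwise_kib} KiB")
print(f"ratio: {ratio}x")
```

Under these assumptions the groups=1 filter is d_model times larger than the depthwise one, for each conv layer in the model.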

Either:
- wait for MLX to implement groups=channels for conv1d
- find another workaround. One possibility is to create d_model conv objects, each with 1 input and 1 output channel. This results in a big for loop which is around 45x slower than the current workaround, but of course the memory usage of the filters is greatly reduced (by a factor of d_model = 2560).
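
For reference, here is a tiny pure-Python sketch (no MLX, toy sizes) of why the groups=1 workaround gives the same output as a true depthwise conv while storing far more weights. All function names are hypothetical, the convolution is unpadded, and the dense filter layout (C_out, K, C_in) is chosen to mirror the (d_model, 4, d_model) shape mentioned above. It is only meant to illustrate the idea, not to reproduce misc.py.

```python
def depthwise_conv1d(x, w):
    """True depthwise conv. x: [L][C], w: [C][K] -> output [L-K+1][C]."""
    L, C, K = len(x), len(w), len(w[0])
    return [
        [sum(x[t + k][c] * w[c][k] for k in range(K)) for c in range(C)]
        for t in range(L - K + 1)
    ]


def dense_conv1d(x, w):
    """groups=1 conv. x: [L][C_in], w: [C_out][K][C_in] -> [L-K+1][C_out]."""
    L, C_out, K, C_in = len(x), len(w), len(w[0]), len(w[0][0])
    return [
        [
            sum(x[t + k][ci] * w[co][k][ci] for k in range(K) for ci in range(C_in))
            for co in range(C_out)
        ]
        for t in range(L - K + 1)
    ]


def expand_to_dense(w):
    """Embed a depthwise filter [C][K] in a dense one [C][K][C], zeros elsewhere."""
    C, K = len(w), len(w[0])
    return [
        [[w[co][k] if ci == co else 0 for ci in range(C)] for k in range(K)]
        for co in range(C)
    ]


C, K, L = 3, 4, 6  # toy sizes; the issue has C = d_model = 2560, K = 4
x = [[t + 10 * c for c in range(C)] for t in range(L)]
w_depthwise = [[c + k + 1 for k in range(K)] for c in range(C)]
w_dense = expand_to_dense(w_depthwise)

out_depthwise = depthwise_conv1d(x, w_depthwise)
out_dense = dense_conv1d(x, w_dense)

stored = C * K * C  # weights held by the dense filter
useful = sum(1 for co in w_dense for row in co for v in row if v != 0)
print(out_depthwise == out_dense, stored, useful)
```

The dense filter stores C times more weights than it actually uses, which is the overhead that a native groups=channels conv1d would remove.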