MLX memory usage at inference
alxndrTL opened this issue · comments
The 1.4B model takes 10-11GB of RAM at inference. (own test, M2 Pro 16GB)
The 2.8B model takes around 50GB at inference. (https://twitter.com/awnihannun/status/1749515431336112275)
This is not due to loading the model from HF (the memory footprint is the same if the model is initialized with random weights). Nor is it due to the ssm_step.
However, turning off the convolution at inference reduces the memory footprint (by ~3GB for the 1.4B model: from 10GB to around 7GB) and also greatly speeds up inference (but of course, the forward pass is then incorrect).
Files concerned:
- mamba_mlx.py (step functions)
- misc.py
The depthwise conv implemented in misc.py seems to be part of the problem. As noted in that file, the PyTorch version uses groups=channels (a true depthwise conv), while the MLX conv in misc.py uses groups=1 with the cross-channel weights set to 0 (the only workaround found). This results in a filter of size (d_model, 4, d_model), versus (d_model, 4) for the true depthwise conv.
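To make the workaround concrete, here is a small sketch (in NumPy, not MLX, just to illustrate the idea) showing that a groups=1 conv with a (d_model, k, d_model) filter whose off-diagonal channel weights are zeroed computes the same thing as a true depthwise conv with a (d_model, k) filter. The function names and tiny sizes are illustrative, not from the repo:

```python
import numpy as np

d_model, k, L = 4, 3, 10  # tiny sizes for illustration (real model: d_model=2560, k=4)

x = np.random.randn(L, d_model)            # sequence of length L with d_model channels
w_dw = np.random.randn(d_model, k)         # true depthwise filter: (d_model, k)

def depthwise_conv(x, w):
    """True depthwise conv (groups=channels): one 1-D filter per channel."""
    L, d = x.shape
    out = np.zeros((L - k + 1, d))
    for c in range(d):
        # reversed kernel turns np.convolve into cross-correlation
        out[:, c] = np.convolve(x[:, c], w[c, ::-1], mode="valid")
    return out

# Workaround: a groups=1 filter (out_channels, k, in_channels) that is zero
# everywhere except on the channel diagonal.
w_big = np.zeros((d_model, k, d_model))
for c in range(d_model):
    w_big[c, :, c] = w_dw[c]

def full_conv(x, w):
    """Ordinary (groups=1) conv mixing all input channels."""
    L, _ = x.shape
    out = np.zeros((L - k + 1, w.shape[0]))
    for t in range(L - k + 1):
        out[t] = np.einsum("okc,kc->o", w, x[t:t + k])
    return out

# The masked full conv matches the true depthwise conv exactly.
assert np.allclose(depthwise_conv(x, w_dw), full_conv(x, w_big))
```

The equivalence holds, but the masked filter stores d_model times more weights (and the conv does d_model times more multiply-adds, almost all of them by zero), which matches the observed memory and speed overhead.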
Either:
-wait for MLX to implement groups=channels for conv1d
-find another workaround. One possibility is to create d_model conv objects, each with 1 input and 1 output channel, but this results in a big for loop which is around 45x slower than the current workaround. Of course, the filter memory usage is then greatly reduced (by a factor of d_model = 2560).
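Back-of-the-envelope parameter counts for the two approaches (simple arithmetic, assuming d_model=2560 and kernel size 4 as in the 2.8B config):

```python
d_model, k = 2560, 4

true_depthwise = d_model * k             # (d_model, 4): one length-4 filter per channel
masked_workaround = d_model * k * d_model  # (d_model, 4, d_model): groups=1 with zeros

# The masked filter is exactly d_model times larger.
print(true_depthwise)                    # 10240 weights
print(masked_workaround)                 # 26214400 weights
print(masked_workaround // true_depthwise)  # 2560 (= d_model)
```

This is the d_model=2560 reduction factor mentioned above: the per-channel-conv-objects approach (like the true depthwise conv) stores only the (d_model, 4) weights, at the cost of a slow Python-level loop.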