google / gemma_pytorch

The official PyTorch implementation of Google's Gemma models

Home Page: https://ai.google.dev/gemma

MPS (Apple Silicon) Support

dsanmart opened this issue

Will there be MPS support for the Gemma models? It would enable access to a larger community.

I took a look and found a few things.

MPS support on Linux still looks to be in progress (pytorch/pytorch#81224), so running in a container isn't possible yet.

MPS currently has some limitations around complex tensors. Gemma uses RoPE, which relies on complex tensors, so it errors out if you run it locally on MPS.

https://github.com/pytorch/pytorch/pull/116764/files#diff-fe061f10677283971d77576718d3a04a00b2225d72c043fd59222a882b92c64bR654

freqs_cis = self.freqs_cis.index_select(0, input_positions)
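For context on why complex tensors show up at all: RoPE rotates each pair of feature values by a position-dependent angle, and multiplying by a complex number of unit length is a compact way to write that rotation. The same rotation can be written with real cos/sin arithmetic only, which is the usual way to sidestep a missing complex dtype. Below is a minimal pure-Python sketch of that equivalence. It is not the code from this repo: the function names, the way values are paired, and the base of 10000 are all assumptions made for illustration.

```python
import cmath
import math


def rope_complex(pairs, position, theta=10000.0):
    """Rotate each (a, b) pair via complex multiplication.

    Hypothetical helper: mirrors the idea of multiplying by a
    precomputed complex frequency, not the repo's actual function.
    """
    n = len(pairs)
    out = []
    for i, (a, b) in enumerate(pairs):
        # Angle for pair i at this position (assumed standard RoPE schedule).
        angle = position * theta ** (-i / n)
        z = complex(a, b) * cmath.exp(1j * angle)
        out.append((z.real, z.imag))
    return out


def rope_real(pairs, position, theta=10000.0):
    """Same rotation using only real cos/sin arithmetic (no complex dtype)."""
    n = len(pairs)
    out = []
    for i, (a, b) in enumerate(pairs):
        angle = position * theta ** (-i / n)
        c, s = math.cos(angle), math.sin(angle)
        # (a + ib) * (c + is) = (a*c - b*s) + i(a*s + b*c)
        out.append((a * c - b * s, a * s + b * c))
    return out


if __name__ == "__main__":
    sample = [(1.0, 2.0), (-0.5, 0.25), (3.0, -1.0)]
    print(rope_complex(sample, position=7))
    print(rope_real(sample, position=7))
```

The two functions agree up to floating-point rounding. In tensor code the same identity applies elementwise, with cos and sin tables taking the place of the complex frequency table.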

I was running locally with: python scripts/run.py --ckpt gemma-2b-it.ckpt --variant 2b --device mps

@lamroger PyTorch 2.3 has bf16 and complex tensor support, and Dockerised containers now work.