google / gemma_pytorch

The official PyTorch implementation of Google's Gemma models

Home Page: https://ai.google.dev/gemma


Inconsistency between PyTorch and JAX implementation

aboros98 opened this issue · comments

Hello!

In the PyTorch implementation, the MLP uses the exact GeLU as its gating activation.

[Screenshots: PyTorch MLP implementation using exact GeLU]

In the JAX version, the approximate (tanh) GeLU is used.

[Screenshots: JAX MLP implementation using approximate GeLU]

Could you please clarify which version is the correct one?
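To illustrate the discrepancy, here is a minimal sketch comparing the two activations in PyTorch; it assumes the JAX side corresponds to PyTorch's `approximate="tanh"` variant (JAX's `jax.nn.gelu` defaults to the tanh approximation):

```python
import torch
import torch.nn.functional as F

x = torch.linspace(-4.0, 4.0, steps=1001)

# Exact GeLU: x * Phi(x), where Phi is the standard normal CDF
exact = F.gelu(x)

# Tanh approximation, as used by default in jax.nn.gelu
approx = F.gelu(x, approximate="tanh")

# The two differ by a small but nonzero amount
max_diff = (exact - approx).abs().max().item()
print(f"max |exact - tanh| on [-4, 4]: {max_diff:.2e}")
```

The difference is small in absolute terms, but it can matter when checking parity between implementations or comparing logits checkpoint-to-checkpoint.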

I see PR #37 is fixing this; I'll help land it soon.

The PR is checked in. Closing this.