Phi-3 128K Context Variants' `su` RoPE Scaling
JosefAlbers opened this issue
Hi, I've noticed that the current implementation of the Phi-3 model in the mlx-lm repository seems to support only linear RoPE scaling. I think this limitation prevents the longer-context variants of the Phi-3 model (including the Phi-3-vision model) from working correctly.
In my work porting the Phi-3-vision model to MLX, I've written code to implement "su"-scaled RoPE. I'd be happy to create a pull request (PR) to add this functionality to the phi3.py file, allowing the longer context Phi-3 models to work as intended.
```python
import math

import mlx.core as mx
import mlx.nn as nn


class Phi3SuScaledRotaryEmbedding(nn.Module):
    def __init__(self, dim, config):
        super().__init__()
        self.dim = dim
        self.base = config.rope_theta
        self.short_factor = config.rope_scaling["short_factor"]
        self.long_factor = config.rope_scaling["long_factor"]
        self.original_max_position_embeddings = config.original_max_position_embeddings
        # Factor that rescales cos/sin (and hence q.k) for the extended context
        self.scaling_factor = math.sqrt(
            1
            + math.log(config.max_position_embeddings / config.original_max_position_embeddings)
            / math.log(config.original_max_position_embeddings)
        )
        self.inv_freq = None

    def __call__(self, position_ids):
        # position_ids: (batch, seq_len)
        seq_len = position_ids.max().item() + 1
        # Use the long factors once positions exceed the original context window
        if seq_len > self.original_max_position_embeddings:
            ext_factors = mx.array(self.long_factor, dtype=mx.float32)
        else:
            ext_factors = mx.array(self.short_factor, dtype=mx.float32)
        inv_freq_shape = mx.arange(0, self.dim, 2, dtype=mx.float32) / self.dim
        self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape)
        # (batch, dim/2, 1) @ (batch, 1, seq_len) -> (batch, dim/2, seq_len)
        inv_freq_expanded = mx.repeat(self.inv_freq[None, :, None], position_ids.shape[0], axis=0)
        position_ids_expanded = position_ids.astype(mx.float32)[:, None, :]
        freqs = mx.matmul(inv_freq_expanded, position_ids_expanded).transpose(0, 2, 1)
        emb = mx.concatenate([freqs, freqs], axis=-1)
        cos = mx.cos(emb) * self.scaling_factor
        sin = mx.sin(emb) * self.scaling_factor
        return cos, sin
```
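To make the arithmetic in the class above concrete, here is a minimal pure-Python sketch of the same three steps: pick `long_factor` or `short_factor` from the sequence length, build the per-pair inverse frequencies, and scale cos/sin by `sqrt(1 + ln(max / original) / ln(original))`. The helper names (`su_scaling_factor`, `su_inv_freq`, `su_cos_sin`) are hypothetical, and the example values 131072 and 4096 are assumed to be the `max_position_embeddings` and `original_max_position_embeddings` of a 128K Phi-3 checkpoint; check them against the actual config.

```python
import math
from typing import List, Sequence, Tuple


def su_scaling_factor(max_pos: int, original_max_pos: int) -> float:
    # sqrt(1 + ln(max_pos / original_max_pos) / ln(original_max_pos))
    return math.sqrt(1 + math.log(max_pos / original_max_pos) / math.log(original_max_pos))


def su_inv_freq(dim: int, base: float, short_factor: Sequence[float],
                long_factor: Sequence[float], seq_len: int,
                original_max_pos: int) -> List[float]:
    # Long factors are used only when the sequence exceeds the original window
    factors = long_factor if seq_len > original_max_pos else short_factor
    if len(factors) != dim // 2:
        raise ValueError("expected one factor per rotary pair (dim // 2)")
    # Same as 1 / (factor * base ** (arange(0, dim, 2) / dim))
    return [1.0 / (factors[i] * base ** (2 * i / dim)) for i in range(dim // 2)]


def su_cos_sin(position: float, inv_freq: Sequence[float],
               scaling_factor: float) -> Tuple[List[float], List[float]]:
    freqs = [position * f for f in inv_freq]
    emb = freqs + freqs  # mirrors concatenate([freqs, freqs], axis=-1)
    cos = [math.cos(a) * scaling_factor for a in emb]
    sin = [math.sin(a) * scaling_factor for a in emb]
    return cos, sin


print(round(su_scaling_factor(131072, 4096), 4))  # prints 1.1902
```

With those assumed values the factor is sqrt(1 + ln(32) / ln(4096)) = sqrt(1 + 5/12) ≈ 1.190, so every cos/sin entry is enlarged slightly and each rotated q·k logit grows by its square.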
Benefits of Adding Su-scaled RoPE:
- Enable Long Context Variants: This change would immediately make the longer context versions of the Phi-3 model usable within MLX.
- Improved Performance: Su-scaled RoPE has been shown to generally improve performance over linear scaling, especially for longer sequences.
- Alignment with Original Model: Ensures the MLX implementation accurately reflects the design of the original Phi-3 models.
Please let me know if you'd be interested in a PR to add this functionality. I'm happy to discuss this further and provide any necessary details.
Can you please help me figure out where this fits in the library code? I have seen your code, but I want to use the original library and fit this code into it. Can you please help? Thanks...
Sorry for the delayed response. I definitely think we should add this to the model files as they are presumably incorrect now for long context. Would you mind sending a PR @JosefAlbers ?
At the moment, running very long contexts might hit memory limitations, though I'm hopeful our forthcoming fused attention will help there.
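A rough back-of-the-envelope sketch of why very long contexts strain memory, assuming Phi-3-mini-like dimensions (32 layers, 32 query and 32 key/value heads, head dimension 96) and 2-byte fp16 storage; these dimensions are assumptions, not taken from this thread. It compares the KV cache with the score matrix an unfused attention would materialize if all 131072 tokens were processed in one step. The function names are hypothetical.

```python
def kv_cache_bytes(num_layers: int, num_kv_heads: int, head_dim: int,
                   seq_len: int, bytes_per_elem: int = 2) -> int:
    # Keys and values (factor 2), one head_dim vector per kv head, token, and layer
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * bytes_per_elem


def naive_scores_bytes(num_heads: int, seq_len: int, bytes_per_elem: int = 2) -> int:
    # An unfused attention materializes a seq_len x seq_len score matrix per head
    return num_heads * seq_len * seq_len * bytes_per_elem


GIB = 2**30
# Assumed Phi-3-mini-like dimensions (not taken from this thread)
print(kv_cache_bytes(32, 32, 96, 131072) / GIB)  # 48.0
print(naive_scores_bytes(32, 131072) / GIB)      # 1024.0
```

Under those assumptions the KV cache alone is 48 GiB at 131072 tokens, while the unfused score matrix for a single layer would be 1 TiB; avoiding that full matrix is one of the main reasons a fused attention kernel can help.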
> In my work porting the Phi-3-vision model to MLX

PS that is a very cool project! Is it functional? Do you mind if I share it more broadly?
> Can you please help me figure out where this fits in the library code? I have seen your code, but I want to use the original library and fit this code into it. Can you please help? Thanks...

@mustangs0786, I'm currently working on integrating su-RoPE scaling directly into the Phi-3 model and plan to submit a pull request (PR) soon. In the meantime, you can try this temporary workaround in the `Attention` module's `__init__` method:
```python
class Attention(nn.Module):
    def __init__(self, args):
        # ...
        if args.rope_scaling is not None:
            if args.rope_scaling["type"] == "linear":
                rope_scale = 1 / args.rope_scaling["factor"]
                self.rope = nn.RoPE(
                    args.head_dim,
                    traditional=args.rope_traditional,
                    base=args.rope_theta,
                    scale=rope_scale,
                )
            elif args.rope_scaling["type"] == "su":
                # Note: this class returns (cos, sin) rather than rotated q/k,
                # so the attention call site must apply them itself
                self.rope = Phi3SuScaledRotaryEmbedding(args.head_dim, args)
```
> Sorry for the delayed response. I definitely think we should add this to the model files as they are presumably incorrect now for long context. Would you mind sending a PR @JosefAlbers ?

@awni My pleasure, I'll begin working on the PR shortly.

> At the moment, running very long contexts might hit memory limitations, though I'm hopeful our forthcoming fused attention will help there.

That would be fantastic.

> PS that is a very cool project!

Wow, thank you!

> Is it functional?

At this point the project is functional for several key tasks, including image captioning, batched generation, LoRA training, and model/cache quantization. You can find more details in my README.md.

> Do you mind if I share it more broadly?

That would be very kind of you, thank you so much!
Oh, and the su-RoPE implementation is a bit different from the version I originally posted last week. It now looks like this:
```python
import math

import mlx.core as mx


class Phi3SuScaledRotaryEmbedding:
    def __init__(self, dim, config, **kwargs):
        # Exponents 0, 2/dim, 4/dim, ... shared by both frequency sets
        exponents = mx.arange(0, dim, 2, dtype=mx.float32) / dim
        self.inv_freq_short = 1.0 / (mx.array(config.rope_scaling["short_factor"], dtype=mx.float32) * config.rope_theta**exponents)
        self.inv_freq_long = 1.0 / (mx.array(config.rope_scaling["long_factor"], dtype=mx.float32) * config.rope_theta**exponents)
        self.original_max_position_embeddings = config.original_max_position_embeddings
        self.scaling_factor = math.sqrt(1 + math.log(config.max_position_embeddings / config.original_max_position_embeddings) / math.log(config.original_max_position_embeddings))

    def _get_cos_sin(self, offset, L, pids):
        def _get_pids(offset, L, pids):
            # offset == 0: use the given position ids unchanged
            if offset < 1:
                return pids
            # offset >= 1: shift positions relative to the last entry of pids
            return pids[:, -1][:, None] + offset - pids.shape[1] + 2 + mx.arange(L)[None, :]

        if pids is None:
            position_ids = mx.arange(offset, offset + L, dtype=mx.float32)[None]
        else:
            position_ids = _get_pids(offset, L, pids).astype(mx.float32)
        # Switch to the long factors once positions exceed the original window
        if position_ids.max().item() + 1 > self.original_max_position_embeddings:
            inv_freq = self.inv_freq_long
        else:
            inv_freq = self.inv_freq_short
        inv_freq_expanded = mx.repeat(inv_freq[None, :, None], position_ids.shape[0], axis=0)
        position_ids_expanded = position_ids[:, None, :]
        freqs = (inv_freq_expanded @ position_ids_expanded).transpose(0, 2, 1)
        emb = mx.concatenate([freqs, freqs], axis=-1)
        cos = mx.cos(emb) * self.scaling_factor
        sin = mx.sin(emb) * self.scaling_factor
        # Add a head axis: (batch, 1, L, dim) broadcasts over (batch, heads, L, dim)
        return mx.expand_dims(cos, axis=1), mx.expand_dims(sin, axis=1)

    def __call__(self, q, k=None, offset=0, pids=None):
        def _rotate_half(x):
            midpoint = x.shape[-1] // 2
            x1, x2 = x[..., :midpoint], x[..., midpoint:]
            return mx.concatenate([-x2, x1], axis=-1)

        cos, sin = self._get_cos_sin(offset, q.shape[2], pids)
        q_embed = (q * cos) + (_rotate_half(q) * sin)
        if k is None:
            return q_embed
        k_embed = (k * cos) + (_rotate_half(k) * sin)
        return q_embed, k_embed
```
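The `__call__` above applies the rotate-half form of RoPE: split each head vector at its midpoint, map `[x1, x2]` to `[-x2, x1]`, and combine the result with `cos`/`sin`. A minimal pure-Python sketch of that step for a single vector (names hypothetical, no MLX) shows the two properties that make it a rotation: with `scaling_factor=1.0` the vector's norm is unchanged, and the dot product of a rotated query and key depends only on their position difference. With the su `scaling_factor`, both vectors are scaled by it, so their dot product grows by its square.

```python
import math
from typing import List, Sequence


def rotate_half(x: List[float]) -> List[float]:
    # [x1, x2] -> [-x2, x1], splitting at the midpoint like _rotate_half
    mid = len(x) // 2
    return [-v for v in x[mid:]] + x[:mid]


def apply_rope(x: List[float], position: float, inv_freq: Sequence[float],
               scaling_factor: float = 1.0) -> List[float]:
    # Angles for each rotary pair, duplicated like concatenate([freqs, freqs])
    freqs = [position * f for f in inv_freq]
    emb = freqs + freqs
    cos = [math.cos(a) * scaling_factor for a in emb]
    sin = [math.sin(a) * scaling_factor for a in emb]
    rot = rotate_half(x)
    # x * cos + rotate_half(x) * sin, element by element
    return [xi * c + ri * s for xi, c, ri, s in zip(x, cos, rot, sin)]


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(ai * bi for ai, bi in zip(a, b))


q = [0.3, -1.2, 0.5, 2.0]
k = [1.0, 0.4, -0.7, 0.9]
inv_freq = [1.0, 0.01]
print(dot(apply_rope(q, 5, inv_freq), apply_rope(k, 2, inv_freq)))
print(dot(apply_rope(q, 10, inv_freq), apply_rope(k, 7, inv_freq)))  # same value up to rounding
```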
@JosefAlbers Hi, I tried implementing it like this:
```python
class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        dim = args.hidden_size
        self.n_heads = n_heads = args.num_attention_heads
        self.n_kv_heads = n_kv_heads = args.num_key_value_heads
        head_dim = args.hidden_size // n_heads
        self.scale = head_dim**-0.5
        op_size = n_heads * head_dim + 2 * (n_kv_heads * head_dim)
        self.qkv_proj = nn.Linear(dim, op_size, bias=False)
        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
        if args.rope_scaling is not None and args.rope_scaling["type"] == "linear":
            rope_scale = args.rope_scaling["factor"]
            self.rope = nn.RoPE(
                head_dim,
                traditional=args.rope_traditional,
                base=args.rope_theta,
                scale=rope_scale,
            )
        else:
            # Also reached when args.rope_scaling is None
            print("test")
            self.rope = Phi3SuScaledRotaryEmbedding(head_dim, args)
```
For `Phi3SuScaledRotaryEmbedding` I used your code above, and got this error:
```
File ~/virtual_env_all/mlx_env/lib/python3.9/site-packages/mlx_lm/models/phi3.py:176, in <listcomp>(.0)
173 assert self.vocab_size > 0
174 self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
175 self.layers = [
--> 176 TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
177 ]
178 self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

File ~/virtual_env_all/mlx_env/lib/python3.9/site-packages/mlx_lm/models/phi3.py:146, in TransformerBlock.__init__(self, args)
144 self.num_attention_heads = args.num_attention_heads
145 self.hidden_size = args.hidden_size
--> 146 self.self_attn = Attention(args)
147 self.mlp = MLP(args.hidden_size, args.intermediate_size)
148 self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

File ~/virtual_env_all/mlx_env/lib/python3.9/site-packages/mlx_lm/models/phi3.py:85, in Attention.__init__(self, args)
83 else:
84 print("deepak")
---> 85 self.rope = Phi3SuScaledRotaryEmbedding(head_dim, args)

File ~/virtual_env_all/mlx_env/lib/python3.9/site-packages/mlx_lm/models/phi3.py:43, in Phi3SuScaledRotaryEmbedding.__init__(self, dim, config)
41 self.dim = dim
42 self.base = config.rope_theta
---> 43 self.short_factor = config.rope_scaling["short_factor"]
44 self.long_factor = config.rope_scaling["long_factor"]
45 self.original_max_position_embeddings = config.original_max_position_embeddings

TypeError: 'NoneType' object is not subscriptable
```
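The `TypeError` above is raised at `config.rope_scaling["short_factor"]`, which means `rope_scaling` was `None` for the config being loaded, so the `else:` branch routed a model without rope scaling into the su class. A minimal sketch of a guarded dispatch, using a hypothetical helper `choose_rope` that only decides which RoPE variant to build: `None` falls back to plain RoPE, `"linear"` uses `1 / factor` as in the earlier workaround, and `"su"` requires both factor lists. The `"default"` label and the error messages are illustrative assumptions.

```python
from typing import Optional, Tuple


def choose_rope(rope_scaling: Optional[dict]) -> Tuple[str, float]:
    """Hypothetical helper: return (rope_kind, linear_scale)."""
    if rope_scaling is None:
        # No scaling in the config: plain RoPE with scale 1
        return ("default", 1.0)
    kind = rope_scaling.get("type")
    if kind == "linear":
        # Same convention as the workaround above: scale = 1 / factor
        return ("linear", 1.0 / rope_scaling["factor"])
    if kind == "su":
        missing = [key for key in ("short_factor", "long_factor") if key not in rope_scaling]
        if missing:
            raise ValueError(f"su rope_scaling is missing {missing}")
        return ("su", 1.0)
    raise ValueError(f"unsupported rope_scaling type: {kind!r}")


print(choose_rope(None))                               # ('default', 1.0)
print(choose_rope({"type": "linear", "factor": 4.0}))  # ('linear', 0.25)
```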
@mustangs0786, it turns out that incorporating su-RoPE into mlx-lm required a bit more work than initially expected. I've just submitted a pull request with a modified implementation that seems to work well for phi-3-mini-128k: #813
> though I'm hopeful our forthcoming fused attention will help there.

I was also thinking along the lines of having various attention implementations, such as fused attention. If this is already in the works, can you link me to it? Or suggest anything specific in this direction, if it's needed? Thanks. cc @awni
We are already working on fused attention. What other variations did you have in mind?
Feel free to teach me here; I'm not an expert at all.
- I might have written in the wrong repo. I only see MHA (`MultiHeadAttention`) in the ml-explore/mlx repo and thought we should have `MultiQueryAttention` and `GroupedQueryAttention` as well. DeepSeek-V2 also introduced `MultiLatentHeadAttention`, I suppose.
- As far as I can tell, fused attention is probably this, and you're working on a CUDA/Triton implementation?
- Also, off-topic and maybe a dumb question: I see we can train a decoder model in MLX like here, so with some architecture changes, could we train a Llama-style model natively in MLX? It would be a great addition to the examples, because currently they mostly cover inference and LoRA.
- Our fused attention supports MQA and GQA as well.
- > you're working on cuda/triton implementation

  There is no CUDA backend for MLX; everything is for Apple silicon.
- We have a transformer LM training example in this repo.
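To illustrate the head sharing behind MQA and GQA (not MLX's fused kernel, whose internals are not shown in this thread), here is a naive pure-Python sketch: query head `h` reads key/value head `h // (n_heads // n_kv_heads)`, so `n_kv_heads == n_heads` is ordinary multi-head attention and `n_kv_heads == 1` is multi-query attention. The function names are hypothetical, and the sketch omits causal masking and batching.

```python
import math
from typing import List

Matrix = List[List[float]]  # [seq_len][head_dim]


def kv_head_for(q_head: int, n_heads: int, n_kv_heads: int) -> int:
    # Each group of n_heads // n_kv_heads query heads shares one kv head
    if n_heads % n_kv_heads != 0:
        raise ValueError("n_heads must be a multiple of n_kv_heads")
    return q_head // (n_heads // n_kv_heads)


def gqa_attention(q: List[Matrix], k: List[Matrix], v: List[Matrix]) -> List[Matrix]:
    """q: [n_heads][L][d]; k, v: [n_kv_heads][S][d]. No causal mask."""
    n_heads, n_kv_heads = len(q), len(k)
    d = len(q[0][0])
    scale = 1.0 / math.sqrt(d)
    out = []
    for h in range(n_heads):
        kh = kv_head_for(h, n_heads, n_kv_heads)
        rows = []
        for qi in q[h]:
            scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k[kh]]
            m = max(scores)  # subtract the max for a numerically stable softmax
            weights = [math.exp(s - m) for s in scores]
            z = sum(weights)
            rows.append([sum(w * vj[c] for w, vj in zip(weights, v[kh])) / z for c in range(d)])
        out.append(rows)
    return out


print([kv_head_for(h, 8, 2) for h in range(8)])  # [0, 0, 0, 0, 1, 1, 1, 1]
```

Sharing key/value heads this way shrinks the KV cache by a factor of `n_heads // n_kv_heads` without changing the number of query heads.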
Can you please point me to the attention implementations in your fused attention work? I'm interested in diving in and potentially helping. cc @awni