ml-explore / mlx-examples

Examples in the MLX framework

Phi-3 128K Context Variants' `su` RoPE Scaling

JosefAlbers opened this issue · comments

Hi, I've noticed that the current implementation of the Phi-3 model in the mlx-lm repository seems to only support linear RoPE scaling. I think this limitation prevents the longer-context variants of the Phi-3 model (including the Phi-3-vision model) from functioning correctly.

In my work porting the Phi-3-vision model to MLX, I've written code to implement "su"-scaled RoPE. I'd be happy to create a pull request (PR) to add this functionality to the phi3.py file, allowing the longer context Phi-3 models to work as intended.

import math

import mlx.core as mx
import mlx.nn as nn


class Phi3SuScaledRotaryEmbedding(nn.Module):
    def __init__(self, dim, config):
        super().__init__()
        self.dim = dim
        self.base = config.rope_theta
        self.short_factor = config.rope_scaling["short_factor"]
        self.long_factor = config.rope_scaling["long_factor"]
        self.original_max_position_embeddings = config.original_max_position_embeddings
        # Output rescaling used by the Phi-3 "su" RoPE formulation.
        self.scaling_factor = math.sqrt(
            1
            + math.log(config.max_position_embeddings / config.original_max_position_embeddings)
            / math.log(config.original_max_position_embeddings)
        )
        self.inv_freq = None

    def __call__(self, position_ids):
        seq_len = position_ids.max() + 1
        # Use the long-context rescale factors past the original window, else the short ones.
        ext_factors = (
            mx.array(self.long_factor, dtype=mx.float32)
            if seq_len > self.original_max_position_embeddings
            else mx.array(self.short_factor, dtype=mx.float32)
        )
        inv_freq_shape = mx.arange(0, self.dim, 2, dtype=mx.float32) / self.dim
        self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape)
        inv_freq_expanded = mx.repeat(self.inv_freq[None, :, None], position_ids.shape[0], axis=0)
        position_ids_expanded = mx.array(position_ids, dtype=mx.float32)[:, None, :]
        freqs = mx.matmul(inv_freq_expanded, position_ids_expanded).transpose(0, 2, 1)
        emb = mx.concatenate([freqs, freqs], axis=-1)
        cos = mx.cos(emb) * self.scaling_factor
        sin = mx.sin(emb) * self.scaling_factor
        return cos, sin
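For context, the returned cos/sin would then be applied to the query and key projections with the usual rotate-half pattern, something like the sketch below (apply_rotary_pos_emb is an illustrative helper name, not part of mlx-lm):

import mlx.core as mx

def _rotate_half(x):
    # Split the last dimension in half and rotate: (x1, x2) -> (-x2, x1).
    mid = x.shape[-1] // 2
    return mx.concatenate([-x[..., mid:], x[..., :mid]], axis=-1)

def apply_rotary_pos_emb(q, k, cos, sin):
    # q, k: (B, n_heads, L, head_dim); cos, sin: (B, L, head_dim) from the class above.
    cos, sin = mx.expand_dims(cos, 1), mx.expand_dims(sin, 1)  # broadcast over heads
    return (q * cos) + (_rotate_half(q) * sin), (k * cos) + (_rotate_half(k) * sin)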

Benefits of Adding Su-scaled RoPE:

  • Enable Long Context Variants: This change would immediately make the longer context versions of the Phi-3 model usable within MLX.
  • Improved Performance: Su-scaled RoPE has been shown to generally improve performance over linear scaling, especially for longer sequences.
  • Alignment with Original Model: Ensures the MLX implementation accurately reflects the design of the original Phi-3 models.
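For a concrete sense of the output rescaling in the code above, here is the scaling factor worked out with what I believe are the Phi-3-mini-128k config values (an assumption on my part; check the model's config.json):

import math

# Assumed Phi-3-mini-128k values: 131072 max positions, 4096 original context.
max_pos, orig_max_pos = 131072, 4096
scaling_factor = math.sqrt(1 + math.log(max_pos / orig_max_pos) / math.log(orig_max_pos))
print(scaling_factor)  # ~1.19: cos/sin are scaled up by roughly 19%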

Please let me know if you'd be interested in a PR to add this functionality. I'm happy to discuss this further and provide any necessary details.

Can you please help me figure out where this fits into the library code? I've looked at your code, but I want to use the original library and fit this code into it. Can you please help? Thanks...

Sorry for the delayed response. I definitely think we should add this to the model files as they are presumably incorrect now for long context. Would you mind sending a PR @JosefAlbers ?

At the moment running very long context might hit memory limitations, though I'm hopeful our forthcoming fused attention will help there.

In my work porting the Phi-3-vision model to MLX

PS that is a very cool project! Is it functional? Do you mind if I share it more broadly?

Can you please help me figure out where this fits into the library code? I've looked at your code, but I want to use the original library and fit this code into it. Can you please help? Thanks...

@mustangs0786, I'm currently working on integrating su-RoPE scaling directly into the Phi-3 model and plan to submit a pull request (PR) soon. In the meantime, you can try this temporary workaround in the Attention module's __init__ method:

class Attention(nn.Module):
    def __init__(self, args):
        # ...
        if args.rope_scaling is not None:
            if args.rope_scaling["type"] == "linear":
                # Existing linear-scaling path.
                rope_scale = 1 / args.rope_scaling["factor"]
                self.rope = nn.RoPE(
                    args.head_dim,
                    traditional=args.rope_traditional,
                    base=args.rope_theta,
                    scale=rope_scale,
                )
            elif args.rope_scaling["type"] == "su":
                # New: "su"-scaled RoPE from the class above.
                self.rope = Phi3SuScaledRotaryEmbedding(args.head_dim, args)
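One caveat: unlike nn.RoPE, the class above returns cos/sin instead of rotating its input, so Attention.__call__ also needs a matching change. A rough sketch (assuming queries/keys shaped (B, n_heads, L, head_dim) and ignoring some KV-cache details), using a rotate-half helper like the one sketched earlier:

# Sketch of the matching change inside Attention.__call__ (simplified):
if isinstance(self.rope, Phi3SuScaledRotaryEmbedding):
    L = queries.shape[2]
    position_ids = mx.arange(offset, offset + L)[None]  # (1, L)
    cos, sin = self.rope(position_ids)                   # each (1, L, head_dim)
    cos, sin = cos[:, None], sin[:, None]                # broadcast over heads
    queries = (queries * cos) + (_rotate_half(queries) * sin)
    keys = (keys * cos) + (_rotate_half(keys) * sin)
else:
    queries = self.rope(queries, offset=offset)
    keys = self.rope(keys, offset=offset)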

Sorry for the delayed response. I definitely think we should add this to the model files as they are presumably incorrect now for long context. Would you mind sending a PR @JosefAlbers ?

@awni My pleasure, I'll begin working on the PR shortly.

At the moment running very long context might hit memory limitations, though I'm hopeful our forthcoming fused attention will help there.

That would be fantastic.

PS that is a very cool project!

Wow, thank you!

Is it functional?

The project is functional at this point for several key tasks, including image captioning, batched generation, LoRA training, and model/cache quantization. You can find more details in my README.md.

Do you mind if I share it more broadly?

That would be very kind of you, thank you so much!

Oh, and the su-RoPE is a bit different from how it was when I originally posted it last week. It's now as follows:

class Phi3SuScaledRotaryEmbedding:
    def __init__(self, dim, config, **kwargs):
        self.inv_freq_short = 1.0 / (mx.array(config.rope_scaling["short_factor"], dtype=mx.float32) * config.rope_theta**(mx.arange(0, dim, 2, dtype=mx.float32) / dim))
        self.inv_freq_long = 1.0 / (mx.array(config.rope_scaling["long_factor"], dtype=mx.float32) * config.rope_theta**(mx.arange(0, dim, 2, dtype=mx.float32) / dim))
        self.original_max_position_embeddings = config.original_max_position_embeddings
        self.scaling_factor = math.sqrt(1 + math.log(config.max_position_embeddings / config.original_max_position_embeddings) / math.log(config.original_max_position_embeddings))

    def _get_cos_sin(self, offset, L, pids):
        def _get_pids(offset, L, pids):
            # Continue the position ids past the cached prompt when generating.
            if offset < 1:
                return pids
            return pids[:, -1][:, None] + offset - pids.shape[1] + 2 + mx.arange(L)[None, :]
        position_ids = mx.arange(offset, offset + L, dtype=mx.float32)[None] if pids is None else _get_pids(offset, L, pids)
        # Switch to the long-context rescale factors beyond the original window.
        inv_freq = self.inv_freq_long if position_ids.max() + 1 > self.original_max_position_embeddings else self.inv_freq_short
        inv_freq_expanded = mx.repeat(inv_freq[None, :, None], position_ids.shape[0], axis=0)
        position_ids_expanded = position_ids[:, None, :]
        freqs = (inv_freq_expanded @ position_ids_expanded).transpose(0, 2, 1)
        emb = mx.concatenate([freqs, freqs], axis=-1)
        cos = mx.cos(emb) * self.scaling_factor
        sin = mx.sin(emb) * self.scaling_factor
        # Add a heads axis so cos/sin broadcast against (B, n_heads, L, head_dim).
        return mx.expand_dims(cos, axis=1), mx.expand_dims(sin, axis=1)

    def __call__(self, q, k=None, offset=0, pids=None):
        def _rotate_half(x):
            midpoint = x.shape[-1] // 2
            x1, x2 = x[..., :midpoint], x[..., midpoint:]
            return mx.concatenate([-x2, x1], axis=-1)
        cos, sin = self._get_cos_sin(offset, q.shape[2], pids)
        q = (q * cos) + (_rotate_half(q) * sin)
        if k is None:
            return q
        return q, (k * cos) + (_rotate_half(k) * sin)
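For reference, this version rotates the inputs directly, so the call site can look roughly like this (a sketch; offset is assumed to be the current KV-cache length):

# rope was built as Phi3SuScaledRotaryEmbedding(head_dim, config)
# q, k: (B, n_heads, L, head_dim)
q, k = rope(q, k, offset=offset)   # rotate queries and keys together
q = rope(q, offset=offset)         # or rotate a single tensor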

@JosefAlbers Hi, I tried implementing this:

class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        dim = args.hidden_size
        self.n_heads = n_heads = args.num_attention_heads
        self.n_kv_heads = n_kv_heads = args.num_key_value_heads

        head_dim = args.hidden_size // n_heads
        self.scale = head_dim**-0.5

        op_size = n_heads * head_dim + 2 * (n_kv_heads * head_dim)
        self.qkv_proj = nn.Linear(dim, op_size, bias=False)
        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
        if args.rope_scaling is not None and args.rope_scaling["type"] == "linear":
            rope_scale = args.rope_scaling["factor"]
            self.rope = nn.RoPE(
                head_dim,
                traditional=args.rope_traditional,
                base=args.rope_theta,
                scale=rope_scale,
            )
        else:
            print("test")
            self.rope = Phi3SuScaledRotaryEmbedding(head_dim, args)

(Phi3SuScaledRotaryEmbedding: using your code above.)

File ~/virtual_env_all/mlx_env/lib/python3.9/site-packages/mlx_lm/models/phi3.py:176, in <listcomp>(.0)
173 assert self.vocab_size > 0
174 self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
175 self.layers = [
--> 176 TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
177 ]
178 self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

File ~/virtual_env_all/mlx_env/lib/python3.9/site-packages/mlx_lm/models/phi3.py:146, in TransformerBlock.init(self, args)
144 self.num_attention_heads = args.num_attention_heads
145 self.hidden_size = args.hidden_size
--> 146 self.self_attn = Attention(args)
147 self.mlp = MLP(args.hidden_size, args.intermediate_size)
148 self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

File ~/virtual_env_all/mlx_env/lib/python3.9/site-packages/mlx_lm/models/phi3.py:85, in Attention.init(self, args)
83 else:
84 print("deepak")
---> 85 self.rope = Phi3SuScaledRotaryEmbedding(head_dim, args)

File ~/virtual_env_all/mlx_env/lib/python3.9/site-packages/mlx_lm/models/phi3.py:43, in Phi3SuScaledRotaryEmbedding.init(self, dim, config)
41 self.dim = dim
42 self.base = config.rope_theta
---> 43 self.short_factor = config.rope_scaling["short_factor"]
44 self.long_factor = config.rope_scaling["long_factor"]
45 self.original_max_position_embeddings = config.original_max_position_embeddings

TypeError: 'NoneType' object is not subscriptable

@mustangs0786, it turns out that incorporating su-RoPE into mlx-lm required a bit more work than initially expected. I've just submitted a Pull Request with a modified implementation that seems to work well for phi-3-mini-128k: #813
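For reference, the TypeError in the traceback above is what you get when the loaded config has rope_scaling set to None (i.e. a non-long-context variant), so the fallback branch needs an explicit check, roughly like this sketch (the type string may be "su" or "longrope" depending on the config version):

# Sketch: guard against configs without rope_scaling, which is what raises
# the "'NoneType' object is not subscriptable" error above.
if args.rope_scaling is None:
    self.rope = nn.RoPE(head_dim, traditional=args.rope_traditional, base=args.rope_theta)
elif args.rope_scaling["type"] == "linear":
    self.rope = nn.RoPE(
        head_dim,
        traditional=args.rope_traditional,
        base=args.rope_theta,
        scale=1 / args.rope_scaling["factor"],
    )
elif args.rope_scaling["type"] in ("su", "longrope"):
    self.rope = Phi3SuScaledRotaryEmbedding(head_dim, args)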

though I'm hopeful our forthcoming fused attention will help there.

I was also thinking along the lines of having various attention implementations, like fused attention, etc. If this is already in the works, can you link me to it, or suggest anything specific in this direction if it's needed? Thanks. cc @awni

We are already working on fused attention. What other variations did you have in mind?

Feel free to teach me here, not an expert at all.

  • I might have written in the wrong repo. I only see MultiHeadAttention (MHA) in the ml-explore/mlx repo and thought we should have MultiQueryAttention and GroupedQueryAttention as well. DeepSeek-V2 also introduced Multi-head Latent Attention (MLA), I believe.

  • As far as I can tell, fused attention is probably this, and you're working on a CUDA/Triton implementation?

  • Also, off-topic and maybe a dumb question, but I see we can train a decoder model in MLX like here, so with some architecture changes we could train Llama-style models natively on MLX, right? It would be a great addition to the examples, because currently we mostly have inference and LoRA.

Our fused attention supports MQA and GQA as well.
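For anyone following the thread, those variants only differ in how many key/value heads they use; here is a minimal sketch of grouped-query attention in plain mx ops (illustrative only, not the fused kernel):

import mlx.core as mx

def gqa(q, k, v, scale):
    # q: (B, n_heads, L, D); k, v: (B, n_kv_heads, L, D) with n_heads % n_kv_heads == 0.
    # MQA is the special case n_kv_heads == 1; MHA is n_kv_heads == n_heads.
    n_rep = q.shape[1] // k.shape[1]
    if n_rep > 1:
        # Repeat each KV head so a group of query heads shares it.
        k = mx.repeat(k, n_rep, axis=1)
        v = mx.repeat(v, n_rep, axis=1)
    scores = (q * scale) @ k.transpose(0, 1, 3, 2)
    return mx.softmax(scores, axis=-1) @ v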

you're working on a CUDA/Triton implementation

Can you point me to the attention implementations in your work on fused attention, please? I'm interested in diving in and potentially helping. cc @awni