llama speed tracking issue
chaosagent opened this issue
- master: 130 tok/s on green, 100 tok/s on red for 7B on 4 GPUs
- master: roughly 15 tok/s on green/red for 70B
memory layouts (2 allreduces for 4 GPUs, 2 allreduces + 2 "smears" for 6 GPUs): x(axis=None) -> w1/w3(axis=0) -> x(axis=-1) -> w2(axis=-1) -> allreduce -> x(axis=None) -> (wq, wk, wv)(axis=0) -> x(axis=kv head dimension) -> attention(axis=kv head dimension) -> x(axis=-1) -> wo(axis=-1) -> allreduce -> next layer
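a rough sketch of the feed_forward leg of that layout (assumes tinygrad's Tensor.shard/linear; the NV:n device names and sizes are placeholders, not the actual example code):
```python
from tinygrad import Tensor

GPUS = tuple(f"NV:{i}" for i in range(4))  # hypothetical device names

x  = Tensor.randn(1, 16, 4096).shard(GPUS, axis=None)  # activations replicated
w1 = Tensor.randn(11008, 4096).shard(GPUS, axis=0)     # column-parallel
w3 = Tensor.randn(11008, 4096).shard(GPUS, axis=0)
w2 = Tensor.randn(4096, 11008).shard(GPUS, axis=-1)    # row-parallel

h   = x.linear(w1.T).silu() * x.linear(w3.T)  # h is now sharded on axis=-1
out = h.linear(w2.T)  # contraction crosses the sharded axis -> allreduce, out is axis=None again
```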
llama 3 OOM: llama 3 has a very big embedding matrix (1GB+). currently the code loads each weight onto 1 GPU and then onto the rest, and the embeddings are loaded last, so it OOMs on that loading GPU. solve by loading large tensors first. alternatively, use MLB(weight shards, axis=weight shard axis) to concat the weights without realizing them on one GPU.
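a hedged sketch of the "load large tensors first" fix, ordering the load loop by size (illustrative only: `weights` is the checkpoint dict, and the assign/shard pattern here is an assumption, not the real load_state_dict):
```python
model_state = nn.state.get_state_dict(model)  # target tensors, already shard_()'d as below
# biggest first: tok_embeddings/output pass through the staging GPU while it is still empty
for k, w in sorted(weights.items(), key=lambda kv: -kv[1].numel()):
  dst = model_state[k]
  dst.assign(w.shard(dst.device, axis=dst.lazydata.axis) if isinstance(dst.device, tuple) else w.to(dst.device))
  dst.realize()
```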
llama 3 OOM with large context: solve by sharding the kv cache.
kv shard layout: we have 8 kv heads but 6 GPUs. we can shard them (2, 2, 1, 1, 1, 1), but we need a way to move to/from this layout from the evenly sharded weights. propose a "reshard" api: allow shard on a Tensor that is already an MLB, and def shard(devices, axis, splits).
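a minimal, hedged sketch of that kv shard under the proposed api (hypothetical device names and shapes; `splits` is the split-unit size, matching the PoC below):
```python
GPUS6 = tuple(f"NV:{i}" for i in range(6))  # hypothetical device names
n_kv_heads, head_dim, bsz, max_ctx = 8, 128, 1, 8192

# kv cache packed as (..., n_kv_heads * head_dim): splitting in head_dim units over
# 6 GPUs yields the uneven (2, 2, 1, 1, 1, 1) head layout described above
cache_k = Tensor.zeros(bsz, max_ctx, n_kv_heads * head_dim).shard(GPUS6, axis=-1, splits=head_dim)

# the "reshard" half of the proposal: .shard on an already-multi Tensor, to move
# activations from the evenly sharded weight layout into this kv-head layout
xk = xk.shard(GPUS6, axis=-1, splits=head_dim)
```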
slow kernels: round shard sizes to multiples of 32. hard to do generally because you don't want this for the batch size axis, but you do want it for llm tensor parallelism. or we could get the search to learn how to padto + local.
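a hedged sketch of how the rounding could pick shard sizes (standalone helper, not existing tinygrad code):
```python
def round_up(x: int, m: int = 32) -> int:
  return (x + m - 1) // m * m

# per-device split sizes rounded to multiples of 32; apply this only to the
# tensor-parallel weight axis, never to the batch-size axis
def split_sizes(n: int, ndev: int, align: int = 32) -> list[int]:
  per = round_up((n + ndev - 1) // ndev, align)
  sizes = [min(per, n - per * i) for i in range(ndev)]
  return [s for s in sizes if s > 0]

print(split_sizes(11008, 6))  # -> [1856, 1856, 1856, 1856, 1856, 1728]
```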
PoC + api proposal: master...chaosagent:tinygrad:llama3_70b. 19 tok/s on 70B, 140 tok/s on 7B on green; fits llama 3 70B with 8192 context on AMD, but not on NV (memory fragmentation?). don't know how to support mlb.pad yet.
jit tune -> 21 tok/s. cpu time spent mostly updating var vals and launch bounds.
currently the api looks like this (notice the splits argument):
```python
# shard
R = 64
n_kv_heads = params["args"].get("n_kv_heads", params["args"]["n_heads"])
if isinstance(device, tuple):
  for k,v in nn.state.get_state_dict(model).items():
    if 'scale' in k: v.shard_(device, axis=None)  # from quantized
    elif '.attention.wo' in k: v.shard_(device, axis=-1, splits=v.shape[-1] // n_kv_heads if "70" not in model_size else R)
    elif '.attention.' in k: v.shard_(device, axis=0, splits=v.shape[0] // n_kv_heads if "70" not in model_size else R)
    elif '.feed_forward.w1.' in k: v.shard_(device, axis=0, splits=R)
    elif '.feed_forward.w3.' in k: v.shard_(device, axis=0, splits=R)
    elif '.feed_forward.' in k: v.shard_(device, axis=-1, splits=R)
    elif 'tok_embeddings.weight' in k: v.shard_(device, axis=0, splits=R)
    elif 'output.weight' in k: v.shard_(device, axis=-1, splits=R)
    #elif k.endswith('.weight'): v.shard_(device, axis=-1)
    #elif 'norm.' in k: v.shard_(device, axis=-1)
    else: v.shard_(device, axis=None)
```
Attention.__call__:
```python
def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]) -> Tensor:
  xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
  xq = xq.shard(xq.device, -1, splits=self.n_rep*self.head_dim).reshape(xq.shape[0], xq.shape[1], self.n_heads, self.head_dim)
  xk = xk.shard(xq.device, -1, splits=self.head_dim).reshape(xk.shape[0], xk.shape[1], self.n_kv_heads, self.head_dim)
  xv = xv.shard(xq.device, -1, splits=self.head_dim).reshape(xv.shape[0], xv.shape[1], self.n_kv_heads, self.head_dim)
  xq, xk = apply_rotary_emb(xq, xk, freqs_cis.cast(xq.dtype))
  bsz, seqlen, _, _ = xq.shape
  ...
  attn = attn.reshape(bsz, seqlen, -1).shard(attn.device, -1, splits=self.wo.weight.lazydata.splits)
  return self.wo(attn)
```