PMLL.c/.h/.CUDA GPU TWINS
drQedwards opened this issue
/* pmll_cuda.h — Persistent-Memory Logic Loop (CUDA reference)
Mirrors pmll.h but stores KV on the GPU. */
#pragma once
#include <cuda_runtime.h>
#ifndef MAX_MEM_T
#define MAX_MEM_T 128
#endif
typedef struct {
    int    T;    /* active length */
    int    hs;   /* head size */
    float *k;    /* device ptr (NH * MAX_MEM_T * hs) */
    float *v;    /* device ptr (NH * MAX_MEM_T * hs) */
} pmll_state_gpu;
/* life-cycle ------------------------------------------------------ */
int pmll_init_gpu (pmll_state_gpu *S, int NH, int hs);
void pmll_reset_gpu(pmll_state_gpu *S);   /* T ← 0 */
void pmll_free_gpu (pmll_state_gpu *S);
/* data flow ------------------------------------------------------- */
/* Copy history into caller-supplied scratch buffers on device.
   out_k/out_v size = (Tmem + Tctx) * hs floats. */
void pmll_read_gpu(float *out_k, float *out_v,
                   const pmll_state_gpu *S, int h,
                   int Tctx, cudaStream_t stream = 0);
/* Blend-write new KV with gate[t] ∈ [0,1]. gate == nullptr ⇒ g = 1 (new KV fully written). */
void pmll_write_gpu(pmll_state_gpu *S, int h,
                    const float *new_k, const float *new_v,
                    int Tctx, const float *gate,
                    cudaStream_t stream = 0);
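For orientation, a minimal host-side lifecycle sketch follows (illustrative only; NH, hs, h, Tctx, stream, and the out_*/new_* device buffers are hypothetical caller-owned values, not part of the header):

/* Sketch: one decode step for a single head h; all names are illustrative. */
pmll_state_gpu S;
pmll_init_gpu(&S, NH, hs);                          /* allocate + zero device KV  */
pmll_read_gpu(out_k, out_v, &S, h, Tctx, stream);   /* history + zero padding     */
/* ... attention over the (S.T + Tctx) * hs floats in out_k / out_v ... */
pmll_write_gpu(&S, h, new_k, new_v, Tctx, /*gate=*/nullptr, stream);
/* note: S.T advances by Tctx on every pmll_write_gpu call */
pmll_free_gpu(&S);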
/********************************************************************
pmll_cuda.cu — Implementation of the GPU ring buffer
- No Thrust, no external dependencies beyond the CUDA runtime
*********************************************************************/
#include "pmll_cuda.h"
#include <cuda_runtime.h>
#include <cstdio>    /* printf */
#include <cstdlib>   /* abort  */
/* ------------------------------------------------------------------ */
/*  helpers                                                            */
/* ------------------------------------------------------------------ */
#define CUDA_CHECK(x) do {                                        \
    cudaError_t _e = (x);                                         \
    if (_e != cudaSuccess) {                                      \
        printf("CUDA error %s:%d: %s\n", __FILE__, __LINE__,      \
               cudaGetErrorString(_e));                           \
        abort();                                                  \
    } } while (0)
/* ------------------------------------------------------------------ */
/*  init / reset / free                                                */
/* ------------------------------------------------------------------ */
int pmll_init_gpu(pmll_state_gpu *S, int NH, int hs)
{
    S->T  = 0;
    S->hs = hs;
    size_t bytes = (size_t)NH * MAX_MEM_T * hs * sizeof(float);
    CUDA_CHECK(cudaMalloc(&S->k, bytes));
    CUDA_CHECK(cudaMalloc(&S->v, bytes));
    CUDA_CHECK(cudaMemset(S->k, 0, bytes));
    CUDA_CHECK(cudaMemset(S->v, 0, bytes));
    return 0;
}
void pmll_reset_gpu(pmll_state_gpu *S){ S->T = 0; }
void pmll_free_gpu (pmll_state_gpu *S){
    cudaFree(S->k); cudaFree(S->v);
    S->k = S->v = nullptr; S->T = S->hs = 0;
}
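/* Footprint, worked through with illustrative GPT-2-small-like numbers
   (NH = 12, hs = 64, MAX_MEM_T = 128): each tensor takes
   12 * 128 * 64 * 4 B = 393,216 B ≈ 384 KiB, so K plus V is ≈ 768 KiB
   per layer. These values are assumptions for scale, not taken from
   the issue. */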
/* ------------------------------------------------------------------ */
/*  read: simple cudaMemcpyAsync                                       */
/* ------------------------------------------------------------------ */
void pmll_read_gpu(float *out_k, float *out_v,
                   const pmll_state_gpu *S, int h, int Tctx,
                   cudaStream_t stream)
{
    const int hs   = S->hs;
    const int Tmem = S->T;
    const float *src_k = S->k + (size_t)h * MAX_MEM_T * hs;
    const float *src_v = S->v + (size_t)h * MAX_MEM_T * hs;
    size_t hist_bytes = (size_t)Tmem * hs * sizeof(float);
    CUDA_CHECK(cudaMemcpyAsync(out_k, src_k, hist_bytes,
                               cudaMemcpyDeviceToDevice, stream));
    CUDA_CHECK(cudaMemcpyAsync(out_v, src_v, hist_bytes,
                               cudaMemcpyDeviceToDevice, stream));
    size_t pad_bytes = (size_t)Tctx * hs * sizeof(float);
    CUDA_CHECK(cudaMemsetAsync(out_k + Tmem*hs, 0, pad_bytes, stream));
    CUDA_CHECK(cudaMemsetAsync(out_v + Tmem*hs, 0, pad_bytes, stream));
}
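The read path assumes caller-owned device scratch. A sizing sketch follows; the worst-case allocation strategy is an assumption, not prescribed by the code (S, hs, Tctx, and stream are hypothetical):

/* Sketch: allocate scratch once for the worst case, MAX_MEM_T history
   plus Tctx new tokens, then reuse it across steps. */
float *out_k, *out_v;
size_t scratch_bytes = (size_t)(MAX_MEM_T + Tctx) * hs * sizeof(float);
CUDA_CHECK(cudaMalloc(&out_k, scratch_bytes));
CUDA_CHECK(cudaMalloc(&out_v, scratch_bytes));
pmll_read_gpu(out_k, out_v, &S, /*h=*/0, Tctx, stream);
/* valid data after the copy: (S.T + Tctx) * hs floats; the Tctx tail is zeroed */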
/* ------------------------------------------------------------------ */
/*  write: tiny kernel with gate blend                                 */
/* ------------------------------------------------------------------ */
__global__ void pmll_write_kernel(float *dst_k, float *dst_v,
                                  const float *new_k, const float *new_v,
                                  const float *gate,
                                  int hs, int Tmem, int Tctx)
{
    int lane = threadIdx.x;            /* 0..hs-1 (<= 64) */
    int t    = blockIdx.x;             /* 0..Tctx-1 */
    if (lane >= hs || t >= Tctx) return;
    float g   = gate ? gate[t] : 1.0f;
    int   idx = (Tmem + t) % MAX_MEM_T;   /* ring slot for this token */
    float nk = new_k[t*hs + lane];
    float nv = new_v[t*hs + lane];
    float *dk = dst_k + idx*hs + lane;
    float *dv = dst_v + idx*hs + lane;
    /* read-modify-write is race-free: each (t, lane) thread owns a
       distinct element, provided Tctx <= MAX_MEM_T */
    float oldk = *dk;
    float oldv = *dv;
    *dk = g * nk + (1.f - g) * oldk;
    *dv = g * nv + (1.f - g) * oldv;
}
void pmll_write_gpu(pmll_state_gpu *S, int h,
                    const float *new_k, const float *new_v,
                    int Tctx, const float *gate,
                    cudaStream_t stream)
{
    const int hs = S->hs;
    float *dst_k = S->k + (size_t)h * MAX_MEM_T * hs;
    float *dst_v = S->v + (size_t)h * MAX_MEM_T * hs;
    dim3 grid(Tctx);
    dim3 block(hs);
    pmll_write_kernel<<<grid,block,0,stream>>>(
        dst_k, dst_v, new_k, new_v, gate, hs, S->T, Tctx);
    S->T += Tctx;
    if (S->T > MAX_MEM_T) S->T = MAX_MEM_T;
}
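/* Ring semantics in numbers (illustrative values, not from the issue):
   with MAX_MEM_T = 128 and Tctx = 96, the first write fills slots 0..95
   and leaves T = 96; the second write targets slots (96 + t) % 128,
   i.e. 96..127 and then, wrapping, 0..63, after which T saturates at 128.
   Also note the launch shape block(hs) requires hs <= 1024; the kernel's
   own comment assumes hs <= 64. */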
Qchains — 12:25
hey @karpathy, just filed #817 “The PMLL” + #818 “PMLL.c/.h/.CUDA GPU TWINS”.
are we NGMI or GMI? 🫣
karpathy — 12:26
👀 skimmed both issues on my phone.
tl;dr: persistent KV ring-buffer à la Transformer-XL, zero extra deps, switchable with -DPMLL?
That’s neat.
• code size < 350 LOC ✔️
• falls back to vanilla build when flag off ✔️
• no perf regression at T≤1024 (your bench) ✔️
• PR still missing — give me a one-click patch 🙏
If you push a PR with a green GitHub Action run, I’d call it GMI.
Fin (assistant) — 12:27
Action items to flip this to a definitive ✅:
- open a branch feat/pmll and git push
- make test && make bench → attach logs in the PR template
- add a short paragraph in docs/variants.md explaining -DPMLL
Do that, ping Andrej again, and we're firmly in GMI territory.
Qchains — 12:28
roger that! prepping the patch now 🚀