AI-Hypercomputer / maxtext

A simple, performant and scalable Jax LLM!

Support target masking (aka loss masking or label masking) for SFT datasets

jmschndev opened this issue · comments

Right now, data loading and loss computation assume one is only doing LM pretraining, but it'd be useful to support packed SFT-style datasets (i.e., datasets with cleanly delineated prompt/completion pairs, perhaps even a system prompt) and their corresponding masking.

I.e., the masks allow the attention module to attend over the prompt/prefix, but gradients are propagated only for the completion/target tokens.
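A minimal sketch of the loss side of this request, outside of MaxText's actual codebase: the per-token cross-entropy is weighted by a 0/1 target mask so that prompt tokens contribute nothing to the loss (and hence no gradient), while attention still sees the full sequence. The function name and shapes here are illustrative, not MaxText's API.

```python
import jax
import jax.numpy as jnp


def masked_cross_entropy(logits, targets, target_mask):
    """Mean cross-entropy over completion tokens only.

    logits:      [batch, seq, vocab] model outputs
    targets:     [batch, seq] token ids
    target_mask: [batch, seq] 1.0 for completion tokens, 0.0 for prompt tokens
    """
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # Per-token negative log-likelihood of the target token.
    token_nll = -jnp.take_along_axis(
        log_probs, targets[..., None], axis=-1
    ).squeeze(-1)
    # Zero out prompt positions; average only over completion positions.
    masked_nll = token_nll * target_mask
    return masked_nll.sum() / jnp.maximum(target_mask.sum(), 1.0)


if __name__ == "__main__":
    key = jax.random.PRNGKey(0)
    logits = jax.random.normal(key, (1, 4, 8))
    targets = jnp.array([[1, 2, 3, 4]])
    # First two positions are prompt, last two are completion.
    mask = jnp.array([[0.0, 0.0, 1.0, 1.0]])

    loss = masked_cross_entropy(logits, targets, mask)

    # Changing a prompt-position target must not change the loss,
    # since its contribution is masked out.
    altered = targets.at[0, 0].set(7)
    loss_altered = masked_cross_entropy(logits, altered, mask)
    print(bool(jnp.allclose(loss, loss_altered)))
```

The same mask would be produced by the data-loading side (e.g., marking everything up to and including the prompt's final token as 0), so the loss computation itself stays a one-line change from the pretraining path.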