[nanoChatGPT] Should we recompute attention_mask instead of storing?
tcbegley opened this issue
It can be recovered by doing something like

```python
attention_mask = (input_ids != pad_token_id).to(torch.int64)
```
so we could potentially reduce the memory footprint this way, at the cost of some extra computation each iteration. We should benchmark and see which is better.
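
A minimal benchmarking sketch along these lines (the shapes, `pad_token_id`, and padding scheme below are illustrative, not taken from the repo): it measures the memory of keeping the mask around versus the per-iteration cost of recomputing it.

```python
import time
import torch

# Illustrative values; actual nanoChatGPT batches may differ.
pad_token_id = 0
batch_size, seq_len = 64, 1024
input_ids = torch.randint(1, 50257, (batch_size, seq_len))

# Simulate right-padding: pad out a random suffix of each row.
lengths = torch.randint(1, seq_len + 1, (batch_size,))
for i, n in enumerate(lengths):
    input_ids[i, n:] = pad_token_id

# Cost of storing the mask explicitly (int64, as in the snippet above).
stored_mask = (input_ids != pad_token_id).to(torch.int64)
mask_mb = stored_mask.numel() * stored_mask.element_size() / 1e6
print(f"stored mask: {mask_mb:.2f} MB per batch")

# Cost of recomputing it each iteration instead.
start = time.perf_counter()
for _ in range(100):
    attention_mask = (input_ids != pad_token_id).to(torch.int64)
per_iter = (time.perf_counter() - start) / 100
print(f"recompute: {per_iter * 1e6:.1f} µs per iteration")
```

Worth noting that if we do keep storing it, a `torch.bool` mask would use 1/8 the memory of `int64`; that might be a cheaper win than recomputing, depending on what downstream code expects.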