pytorch / torchtitan

A native PyTorch Library for large model training

`freqs_cis` in llama model should be a non-persistent buffer

tianyu-l opened this issue

Currently it is registered as a persistent buffer, for two reasons. They are quoted here from the comment at https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama/model.py#L355:

# TODO persistent should be set to false, since this buffer can be recomputed.
# however, we set it to true for 2 reasons.  (1) due to pytorch/pytorch#123411,
# compile or pipeline-tracer will not correctly handle non-persistent buffers,
# so we need to fix that.  (2) if we initialize pipeline-parallel models from
# a seed checkpoint rather than calling init_weights, we need freqs_cis to be
# initialized by the checkpoint, or we need to add a separate initializer for
# just the non-persistent buffers that is called after loading checkpoints.

This issue tracks progress on both points. If (1) is fixed and the handling described in (2) seems the best solution, we can close this issue.
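
For illustration, here is a minimal sketch of what the non-persistent variant with a separate buffer initializer (the second option in reason (2)) might look like. `TinyModel`, `init_buffers`, and the simplified `precompute_freqs_cis` below are hypothetical stand-ins written for this example, not the torchtitan implementation; only `register_buffer(..., persistent=False)`, `state_dict`, and `load_state_dict` are standard PyTorch APIs.

```python
import torch
import torch.nn as nn


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    """Simplified rotary-embedding table: complex tensor of shape (end, dim // 2)."""
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim))
    t = torch.arange(end, dtype=torch.float32)
    angles = torch.outer(t, freqs)
    # Unit-magnitude complex numbers exp(i * angle).
    return torch.polar(torch.ones_like(angles), angles)


class TinyModel(nn.Module):
    """Hypothetical stand-in for the llama model, reduced to what matters here."""

    def __init__(self, head_dim: int = 8, max_seq_len: int = 16):
        super().__init__()
        self.head_dim = head_dim
        self.max_seq_len = max_seq_len
        self.proj = nn.Linear(head_dim, head_dim)
        # persistent=False keeps freqs_cis out of state_dict(), so it is
        # neither saved to nor loaded from checkpoints.
        self.register_buffer(
            "freqs_cis",
            precompute_freqs_cis(head_dim, max_seq_len),
            persistent=False,
        )

    def init_buffers(self) -> None:
        # Hypothetical separate initializer: recompute only the
        # non-persistent buffers, to be called after loading a checkpoint.
        self.freqs_cis = precompute_freqs_cis(self.head_dim, self.max_seq_len).to(
            self.freqs_cis.device
        )


src = TinyModel()
ckpt = src.state_dict()  # contains proj.weight and proj.bias only

dst = TinyModel()
dst.freqs_cis.zero_()      # simulate a buffer that was never initialized
dst.load_state_dict(ckpt)  # does not touch freqs_cis
dst.init_buffers()         # recompute it explicitly
```

With this layout the checkpoint stays smaller and the buffer is always derived from the model configuration, at the cost of requiring every checkpoint-loading path (including seed checkpoints for pipeline parallelism) to call the initializer afterwards.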