pytorch / torchtitan

A native PyTorch Library for large model training

reload existing llama checkpoints

tianyu-l opened this issue

Is this issue related to loading pretrained Llama2/Llama3 weights and using them as a checkpoint?

I was going to start a separate issue asking for some docs that explain how to convert pretrained weights from HF to torchtitan in order to do continued pretraining. Is that already possible or on the roadmap?

DCP has format utils to help with the conversion. However, HF conversion should not live in the PyTorch code base.
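
For reference, the format utilities live under torch.distributed.checkpoint.format_utils in recent PyTorch releases. A minimal sketch of turning an already-renamed torch.save state dict into a DCP checkpoint (paths here are hypothetical) would look something like:

from torch.distributed.checkpoint.format_utils import torch_save_to_dcp

# Convert a plain torch.save file into a DCP checkpoint directory.
# Note: this only changes the serialization format; renaming HF parameter
# names to the names torchtitan expects is a separate step.
torch_save_to_dcp("llama3_8b_renamed.pt", "outputs/checkpoint/step-0")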

@lessw2020 will connect with HF to see if they can support weight conversion from HF to PyTorch. After that, we may import it into the code base or update the tutorial.

I have a straightforward script for converting from HF to a DCP checkpoint, if that helps. Most of it already exists in gpt-fast.

@rlrs Thanks, please feel free to share it here!

As far as we know, HF is also working on a script to convert from HF to DCP. As discussed in #335, we should include a script to convert raw Llama weights into DCP (similar to the one here), and it should probably still sit in pytorch/pytorch.

Alright, so this is the script I'm using for HF -> DCP. It uses the safetensors weights (but can easily be adapted to load a torch.save instead), which exist only in the root of https://huggingface.co/meta-llama/Meta-Llama-3-8B/tree/main and not under original/. So, as we discussed in #335, some of the weights are permuted compared to the original.
I've been using it to create a step-0 checkpoint that torchtitan is already set up to start from.
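
(Aside: the permutation discussed in #335 is, as far as I can tell, the one HF's convert_llama_weights_to_hf.py applies to the q and k projection weights when going from the original interleaved RoPE layout to HF's half-split layout. A rough sketch of that permutation and its inverse, with head counts and dimensions left as parameters:)

import torch

# Sketch of the q/k permutation used by HF's conversion script
# (original interleaved layout -> HF layout) and its inverse (HF -> original).
def permute(w: torch.Tensor, n_heads: int, dim1: int, dim2: int) -> torch.Tensor:
    return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)

def unpermute(w: torch.Tensor, n_heads: int, dim1: int, dim2: int) -> torch.Tensor:
    return w.view(n_heads, 2, dim1 // n_heads // 2, dim2).transpose(1, 2).reshape(dim1, dim2)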

import json
import re
from pathlib import Path

import torch
import torch.distributed.checkpoint as DCP
from safetensors import safe_open

@torch.inference_mode()
def convert_hf_checkpoint(
    *,
    checkpoint_dir: Path,
    output_dir: Path,
) -> None:
    # Load the json file containing weight mapping
    model_map_json = checkpoint_dir / "model.safetensors.index.json"

    assert model_map_json.is_file()

    with open(model_map_json, 'r') as json_map:
        bin_index = json.load(json_map)

    # Mapping from HF parameter names to torchtitan/Llama names;
    # entries mapped to None are skipped.
    weight_map = {
        "model.embed_tokens.weight": "tok_embeddings.weight",
        "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
        "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
        "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
        "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
        "model.layers.{}.self_attn.rotary_emb.inv_freq": None,
        "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight",
        "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
        "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
        "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
        "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
        "model.norm.weight": "norm.weight",
        "lm_head.weight": "output.weight",
    }
    bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()}

    # Load every safetensors shard and merge into a single CPU state dict.
    merged_result = {}
    for file in sorted(bin_files):
        with safe_open(file, framework="pt", device="cpu") as f:
            for k in f.keys():
                merged_result[k] = f.get_tensor(k)
    final_result = {}
    
    # Rename keys: for per-layer weights, templatize the layer index,
    # look up the new name, then substitute the index back in.
    for key, value in merged_result.items():
        if "layers" in key:
            abstract_key = re.sub(r'(\d+)', '{}', key)
            layer_num = re.search(r'\d+', key).group(0)
            new_key = weight_map[abstract_key]
            if new_key is None:
                continue
            new_key = new_key.format(layer_num)
        else:
            new_key = weight_map[key]

        final_result[new_key] = value

    # Save as a DCP checkpoint under the "model" key, matching the layout
    # torchtitan loads from when resuming.
    output_dir.mkdir(parents=True, exist_ok=True)
    storage_writer = DCP.filesystem.FileSystemWriter(output_dir)
    DCP.save({"model": final_result}, 
             storage_writer=storage_writer)

if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(description='Convert HuggingFace checkpoint.')
    parser.add_argument('--checkpoint', type=Path, required=True)
    parser.add_argument('--output', type=Path, required=True)

    args = parser.parse_args()
    convert_hf_checkpoint(
        checkpoint_dir=args.checkpoint,
        output_dir=args.output,
    )
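
As a quick sanity check after running it (the path below is whatever was passed to --output), the DCP metadata can be inspected without loading any tensors:

import torch.distributed.checkpoint as DCP

# List a few of the saved tensor names from the checkpoint's metadata.
reader = DCP.FileSystemReader("outputs/step-0")
metadata = reader.read_metadata()
print(sorted(metadata.state_dict_metadata.keys())[:5])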

Thanks for sharing.

Is there a conversion in the other direction, i.e., converting a DCP checkpoint to an HF model? I found a util, dcp_to_torch_save, but am not sure how to go from there to an HF model.
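
(For context, dcp_to_torch_save lives in torch.distributed.checkpoint.format_utils and only collapses the DCP checkpoint into a single torch.save file; the resulting state dict still uses torchtitan's key names, so mapping back to HF names, i.e. inverting the weight_map from the script above and handling the q/k permutation, would be a separate step. A rough sketch of the call, with hypothetical paths:)

from torch.distributed.checkpoint.format_utils import dcp_to_torch_save

# Collapse a sharded DCP checkpoint into one torch.save file on disk.
dcp_to_torch_save("outputs/checkpoint/step-1000", "/tmp/titan_llama3.pt")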

@bkchang On the HF website, there's a script to convert Llama weights to HF format.

@tianyu-l Thanks for the comment. Unfortunately, that script is for converting a Llama model in the format originally released by the Llama team. It thus requires input files like params.json and tokenizer.model, which torchtitan doesn't generate. What I would like to know is how to go from torchtitan output weights to an HF model. Thank you.