mlfoundations/open_lm

A repository for research on medium-sized language models.

Minimize how often we load args.resume

achalddave opened this issue

Currently, we load `args.resume` potentially up to three times. This can be pretty slow for big models, and we should avoid re-loading it in these spots:

`open_lm/main.py`, lines 110 to 156 at `97d0a4a`:

```python
def load_model(args, model):
    checkpoint = pt_load(args.resume, map_location="cpu")
    if "epoch" in checkpoint:
        # resuming a train checkpoint w/ epoch and optimizer state
        start_epoch = checkpoint["epoch"]
        sd = checkpoint["state_dict"]
        if next(iter(sd.items()))[0].startswith("module"):
            sd = {k[len("module.") :]: v for k, v in sd.items()}
        model.load_state_dict(sd)
        logging.info(f"=> resuming checkpoint '{args.resume}' (epoch {start_epoch})")
    else:
        # loading a bare (model only) checkpoint for fine-tune or evaluation
        model.load_state_dict(checkpoint)
        logging.info(f"=> loaded checkpoint '{args.resume}' (epoch {start_epoch})")
    return start_epoch


def load_optimizer(args, model, optimizer, scaler):
    potential_checkpoint = args.resume.replace("epoch_", "optimizer_")
    if check_exists(potential_checkpoint):
        checkpoint = pt_load(potential_checkpoint, map_location="cpu")
    else:
        checkpoint = pt_load(args.resume, map_location="cpu")
    if "optimizer" in checkpoint:
        if optimizer is not None:
            osd = checkpoint["optimizer"]
            if args.fsdp:
                osd = FSDP.optim_state_dict_to_load(
                    model=model, optim=optimizer, optim_state_dict=osd
                )
            optimizer.load_state_dict(osd)
            logging.info(f"=> resuming optimizer")
        if scaler is not None and "scaler" in checkpoint:
            scaler.load_state_dict(checkpoint["scaler"])
    else:
        logging.info(f"=> WARNING: not resuming optimizer.")


def load_data_chunks(args):
    checkpoint = pt_load(args.resume, map_location="cpu")
    if "next_chunk" in checkpoint and "samples_seen" in checkpoint:
        return checkpoint["next_chunk"], checkpoint["samples_seen"]
    else:
        logging.info(
            f"=> WARNING: tried to resume a checkpoint without data chunk info. Assuming next_chunk = 0."
        )
        return 0, 0
```

In the current iteration, I believe our main offender is: https://github.com/mlfoundations/open_lm/blob/main/open_lm/main.py#L142-L151.
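
One possible shape for the fix (a sketch, not the repo's actual change: the `checkpoint` parameters and `_load_resume_checkpoint` helper below are hypothetical additions, and `torch.load` stands in for `pt_load`) is to read `args.resume` once in the caller and thread the resulting dict through the helpers, keeping the current behavior when no dict is passed:

```python
import logging

import torch


def _load_resume_checkpoint(args):
    # Hypothetical helper: a single disk read for args.resume that the
    # caller can share across all resume steps.
    return torch.load(args.resume, map_location="cpu")


def load_model(args, model, checkpoint=None):
    # `checkpoint=None` keeps existing call sites working: with no
    # pre-loaded dict, fall back to reading args.resume as before.
    if checkpoint is None:
        checkpoint = _load_resume_checkpoint(args)
    if "epoch" in checkpoint:
        start_epoch = checkpoint["epoch"]
        sd = checkpoint["state_dict"]
        if next(iter(sd.items()))[0].startswith("module"):
            sd = {k[len("module.") :]: v for k, v in sd.items()}
        model.load_state_dict(sd)
        logging.info(f"=> resuming checkpoint '{args.resume}' (epoch {start_epoch})")
    else:
        start_epoch = 0
        model.load_state_dict(checkpoint)
        logging.info(f"=> loaded checkpoint '{args.resume}'")
    return start_epoch


def load_data_chunks(args, checkpoint=None):
    if checkpoint is None:
        checkpoint = _load_resume_checkpoint(args)
    if "next_chunk" in checkpoint and "samples_seen" in checkpoint:
        return checkpoint["next_chunk"], checkpoint["samples_seen"]
    logging.info("=> WARNING: checkpoint has no data chunk info. Assuming next_chunk = 0.")
    return 0, 0


# Caller-side usage: one load, shared by every resume step.
#
#   ckpt = _load_resume_checkpoint(args)
#   start_epoch = load_model(args, model, checkpoint=ckpt)
#   next_chunk, samples_seen = load_data_chunks(args, checkpoint=ckpt)
```

`load_optimizer` is the awkward case: when a separate `optimizer_*` checkpoint exists it loads that file instead of `args.resume`, so only its fallback path (the re-read of `args.resume`) could reuse the shared dict.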