Minimize how often we load args.resume
achalddave opened this issue · comments
Currently, we load args.resume potentially up to 3 times. This can be pretty slow for big models, and we should avoid re-loading it in these spots:
Lines 110 to 156 in 97d0a4a
def load_model(args, model):
    """Load model weights from ``args.resume`` into ``model``.

    Handles two checkpoint layouts:
      * a full training checkpoint containing ``epoch`` and ``state_dict``
        keys, used when resuming a run; or
      * a bare state dict (model only), used for fine-tuning or evaluation.

    Args:
        args: namespace with ``resume`` — path to the checkpoint file.
        model: module whose parameters are loaded in place.

    Returns:
        int: the epoch to resume training from (0 for a bare checkpoint).
    """
    checkpoint = pt_load(args.resume, map_location="cpu")
    if "epoch" in checkpoint:
        # resuming a train checkpoint w/ epoch and optimizer state
        start_epoch = checkpoint["epoch"]
        sd = checkpoint["state_dict"]
        # Strip the "module." prefix left by DistributedDataParallel wrapping.
        if next(iter(sd.items()))[0].startswith("module"):
            sd = {k[len("module.") :]: v for k, v in sd.items()}
        model.load_state_dict(sd)
        logging.info(f"=> resuming checkpoint '{args.resume}' (epoch {start_epoch})")
    else:
        # loading a bare (model only) checkpoint for fine-tune or evaluation
        # BUG FIX: start_epoch was never assigned on this path, so the log
        # line below and the final return raised NameError. A bare checkpoint
        # starts from epoch 0.
        start_epoch = 0
        model.load_state_dict(checkpoint)
        logging.info(f"=> loaded checkpoint '{args.resume}' (epoch {start_epoch})")
    return start_epoch
def load_optimizer(args, model, optimizer, scaler):
    """Restore optimizer (and grad-scaler) state for a resumed run.

    Prefers a dedicated optimizer checkpoint (the ``optimizer_*`` sibling of
    the ``epoch_*`` file named by ``args.resume``); otherwise falls back to
    the combined checkpoint at ``args.resume``. When no optimizer state is
    found, logs a warning and leaves ``optimizer``/``scaler`` untouched.
    """
    # A separate optimizer checkpoint, if present, takes precedence over
    # the combined train checkpoint.
    optim_path = args.resume.replace("epoch_", "optimizer_")
    source = optim_path if check_exists(optim_path) else args.resume
    checkpoint = pt_load(source, map_location="cpu")

    if "optimizer" not in checkpoint:
        logging.info(f"=> WARNING: not resuming optimizer.")
        return

    if optimizer is not None:
        osd = checkpoint["optimizer"]
        if args.fsdp:
            # FSDP shards optimizer state across ranks; translate the saved
            # state dict into the form this rank's optimizer can load.
            osd = FSDP.optim_state_dict_to_load(
                model=model, optim=optimizer, optim_state_dict=osd
            )
        optimizer.load_state_dict(osd)
        logging.info(f"=> resuming optimizer")
    if scaler is not None and "scaler" in checkpoint:
        scaler.load_state_dict(checkpoint["scaler"])
def load_data_chunks(args):
    """Recover the dataloader position stored in the ``args.resume`` checkpoint.

    Returns:
        tuple[int, int]: ``(next_chunk, samples_seen)`` from the checkpoint,
        or ``(0, 0)`` (with a warning) when the checkpoint predates data
        chunk tracking.
    """
    checkpoint = pt_load(args.resume, map_location="cpu")
    has_chunk_info = "next_chunk" in checkpoint and "samples_seen" in checkpoint
    if not has_chunk_info:
        logging.info(
            f"=> WARNING: tried to resume a checkpoint without data chunk info. Assuming next_chunk = 0."
        )
        return 0, 0
    return checkpoint["next_chunk"], checkpoint["samples_seen"]
In the current iteration, I believe our main offender is: https://github.com/mlfoundations/open_lm/blob/main/open_lm/main.py#L142-L151.