JasonGross / guarantees-based-mechanistic-interpretability

model checkpointing doesn't work

JasonGross opened this issue

I wrote some code to enable model checkpointing:

# Set up model checkpointing
# TODO(Euan or Jason, low-ish priority): fix model checkpointing, it doesn't seem to work
if config.checkpoint_every is not None:
    if config.checkpoint_every[1] == "epochs":
        checkpoint_callback = ModelCheckpoint(
            dirpath=model_ckpt_dir_path,
            filename=run_name + "-{epoch}-{step}",
            every_n_epochs=1,
            save_top_k=-1,  # Set to -1 to save all checkpoints
        )
    elif config.checkpoint_every[1] == "steps":
        checkpoint_callback = ModelCheckpoint(
            dirpath=model_ckpt_dir_path,
            filename=run_name + "-{epoch}-{step}",
            every_n_train_steps=config.checkpoint_every[0],
            save_top_k=-1,  # Set to -1 to save all checkpoints
        )
else:
    checkpoint_callback = None
# Fit model
train_metric_callback = MetricsCallback()
callbacks = [train_metric_callback, RichProgressBar()]
if checkpoint_callback is not None:
    callbacks.append(checkpoint_callback)

But it doesn't seem to work.
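
One detail visible in the snippet: the "epochs" branch hard-codes every_n_epochs=1 instead of reading config.checkpoint_every[0]. Whether that is related to the failure is not established here. Below is a minimal sketch of one way to build the ModelCheckpoint arguments so that both branches honor the configured interval. It assumes checkpoint_every is an (interval, unit) tuple, as the indexing above suggests. The helper name checkpoint_kwargs is hypothetical and not part of the repository. It only builds a dict, so it can be checked without Lightning installed.

```python
from typing import Any, Dict, Optional, Tuple


def checkpoint_kwargs(
    checkpoint_every: Optional[Tuple[int, str]],
    dirpath: str,
    run_name: str,
) -> Optional[Dict[str, Any]]:
    """Build keyword arguments for ModelCheckpoint, or None to disable it.

    Hypothetical helper: `checkpoint_every` is assumed to be
    (interval, unit) with unit in {"epochs", "steps"}.
    """
    if checkpoint_every is None:
        return None

    interval, unit = checkpoint_every
    if interval <= 0:
        raise ValueError(f"checkpoint interval must be positive, got {interval}")

    kwargs: Dict[str, Any] = {
        "dirpath": dirpath,
        "filename": run_name + "-{epoch}-{step}",
        "save_top_k": -1,  # -1 keeps every checkpoint
    }
    if unit == "epochs":
        # Use the configured interval rather than a fixed 1
        kwargs["every_n_epochs"] = interval
    elif unit == "steps":
        kwargs["every_n_train_steps"] = interval
    else:
        # Fail loudly instead of silently skipping checkpointing
        raise ValueError(f"unknown checkpoint unit: {unit!r}")
    return kwargs
```

With Lightning available, the result could be passed as ModelCheckpoint(**kwargs) and appended to callbacks whenever it is not None.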