model checkpointing doesn't work
JasonGross opened this issue
Jason Gross commented
I wrote some code to enable model checkpointing in guarantees-based-mechanistic-interpretability/gbmi/model.py (lines 369 to 393 at commit c415134):
# Set up model checkpointing
# TODO(Euan or Jason, low-ish priority): fix model checkpointing, it doesn't seem to work
if config.checkpoint_every is not None:
    if config.checkpoint_every[1] == "epochs":
        checkpoint_callback = ModelCheckpoint(
            dirpath=model_ckpt_dir_path,
            filename=run_name + "-{epoch}-{step}",
            every_n_epochs=1,
            save_top_k=-1,  # Set to -1 to save all checkpoints
        )
    elif config.checkpoint_every[1] == "steps":
        checkpoint_callback = ModelCheckpoint(
            dirpath=model_ckpt_dir_path,
            filename=run_name + "-{epoch}-{step}",
            every_n_train_steps=config.checkpoint_every[0],
            save_top_k=-1,  # Set to -1 to save all checkpoints
        )
else:
    checkpoint_callback = None
# Fit model
train_metric_callback = MetricsCallback()
callbacks = [train_metric_callback, RichProgressBar()]
if checkpoint_callback is not None:
    callbacks.append(checkpoint_callback)
But it doesn't seem to work.
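Two things stand out in the snippet as quoted, though neither is confirmed to be the cause of the failure. First, the "epochs" branch hardcodes `every_n_epochs=1`, so `config.checkpoint_every[0]` is ignored for epoch-based checkpointing. Second, if `config.checkpoint_every[1]` is neither "epochs" nor "steps", `checkpoint_callback` is never assigned and the later `if checkpoint_callback is not None` raises a `NameError` (an `UnboundLocalError` inside a function). Below is a minimal sketch, assuming `checkpoint_every` is a `(count, unit)` tuple as the indexing suggests, that factors the branching into a hypothetical helper `checkpoint_kwargs`. It returns keyword arguments for Lightning's `ModelCheckpoint` and is kept free of Lightning imports so it can be checked in isolation.

```python
from typing import Any, Dict, Literal, Optional, Tuple

# Assumed shape of config.checkpoint_every: (count, unit), or None to disable.
CheckpointEvery = Optional[Tuple[int, Literal["epochs", "steps"]]]


def checkpoint_kwargs(
    checkpoint_every: CheckpointEvery,
    dirpath: str,
    run_name: str,
) -> Optional[Dict[str, Any]]:
    """Hypothetical helper: build ModelCheckpoint keyword arguments, or None.

    Mirrors the branching in model.py, but uses the count for the "epochs"
    case too and raises on an unknown unit instead of leaving the callback
    variable unassigned.
    """
    if checkpoint_every is None:
        return None  # checkpointing disabled
    count, unit = checkpoint_every
    kwargs: Dict[str, Any] = {
        "dirpath": dirpath,
        "filename": run_name + "-{epoch}-{step}",
        "save_top_k": -1,  # -1 keeps every checkpoint, not just the best k
    }
    if unit == "epochs":
        kwargs["every_n_epochs"] = count
    elif unit == "steps":
        kwargs["every_n_train_steps"] = count
    else:
        raise ValueError(
            f"unknown checkpoint unit {unit!r}; expected 'epochs' or 'steps'"
        )
    return kwargs
```

When the helper returns a dict, it would be used as `ModelCheckpoint(**kwargs)` and appended to `callbacks` exactly as in the original code; the import path for `ModelCheckpoint` depends on which Lightning package the project already uses.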