Default setting for model checkpointer seems to break training
chaitjo opened this issue
After making the fix in #17 and getting the training script to run with the default configs, I'm encountering the following error: the code breaks after completing one training epoch and one validation epoch. The error message suggests that the metric valid/non_coil_percent,
which is used for saving checkpoints, was never logged/saved/computed.
Please let me know if there's a quick workaround.
[2024-02-08 19:30:39,372][__main__][INFO] - Checkpoints saved to ckpt/se3-fm/baseline/2024-02-08_19-30-26
[2024-02-08 19:30:39,419][__main__][INFO] - Using devices: [0]
[2024-02-08 19:30:39,947][torch.distributed.distributed_c10d][INFO] - Added key: store_based_barrier_key:1 to store for rank: 0
[2024-02-08 19:30:39,948][torch.distributed.distributed_c10d][INFO] - Rank 0: Completed store-based barrier for key:store_based_barrier_key:1 with 1 nodes.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/1
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 1 processes
----------------------------------------------------------------------------------------------------
[2024-02-08 19:30:41,168][data.pdb_dataloader][INFO] - Training: 3938 examples
[2024-02-08 19:30:41,207][data.pdb_dataloader][INFO] - Validation: 40 examples with lengths [ 20 38 53 68 83 98 113 128]
Sanity Checking DataLoader 0: 0%| | 0/2 [00:00<?, ?it/s]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
| Name | Type | Params
------------------------------------
0 | model | FlowModel | 16.7 M
------------------------------------
16.7 M Trainable params
0 Non-trainable params
16.7 M Total params
[2024-02-08 19:30:53,758][data.pdb_dataloader][INFO] - Created dataloader rank 1 out of 1
Epoch 0: 0%| | 0/3938 [00:00<?, ?it/s][2024-02-08 19:30:56,139][data.so3_utils][INFO] - Data loaded from .cache/cache_igso3_s0.100-1.500-1000_l1000_o1000-3.npz
Epoch 0: 0%| | 2/3938 [00:02<1:19:48, 1.22s/it, v_num=2dvk, train/examples_per_second=149.0, train/loss=13.70][2024-02-08 19:30:57,066][torch.nn.parallel.distributed][INFO] - Reducer buckets have been rebuilt in this iteration.
Epoch 1: 100%|████████████████████████████████| 3938/3938 [21:38<00:00, 3.03it/s, v_num=2dvk, train/examples_per_second=245.0, train/loss=2.520]
Validation DataLoader 0: 100%|███████████████████████████████████████████████████████████████████████████████████| 40/40 [03:20<00:00, 5.01s/it]
Error executing job with overrides: []
Traceback (most recent call last):
File "/home/ckj24/protein-frame-flow/experiments/train_se3_flows.py", line 97, in main
exp.train()
File "/home/ckj24/protein-frame-flow/experiments/train_se3_flows.py", line 72, in train
trainer.fit(
File "/home/ckj24/miniforge-pypy3/envs/fm/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 532, in fit
call._call_and_handle_interrupt(
File "/home/ckj24/miniforge-pypy3/envs/fm/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 42, in _call_and_handle_interrupt
return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
File "/home/ckj24/miniforge-pypy3/envs/fm/lib/python3.10/site-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 93, in launch
return function(*args, **kwargs)
File "/home/ckj24/miniforge-pypy3/envs/fm/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 571, in _fit_impl
self._run(model, ckpt_path=ckpt_path)
File "/home/ckj24/miniforge-pypy3/envs/fm/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 980, in _run
results = self._run_stage()
File "/home/ckj24/miniforge-pypy3/envs/fm/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1023, in _run_stage
self.fit_loop.run()
File "/home/ckj24/miniforge-pypy3/envs/fm/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 202, in run
self.advance()
File "/home/ckj24/miniforge-pypy3/envs/fm/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 355, in advance
self.epoch_loop.run(self._data_fetcher)
File "/home/ckj24/miniforge-pypy3/envs/fm/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 134, in run
self.on_advance_end()
File "/home/ckj24/miniforge-pypy3/envs/fm/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 249, in on_advance_end
self.val_loop.run()
File "/home/ckj24/miniforge-pypy3/envs/fm/lib/python3.10/site-packages/pytorch_lightning/loops/utilities.py", line 181, in _decorator
return loop_run(self, *args, **kwargs)
File "/home/ckj24/miniforge-pypy3/envs/fm/lib/python3.10/site-packages/pytorch_lightning/loops/evaluation_loop.py", line 122, in run
return self.on_run_end()
File "/home/ckj24/miniforge-pypy3/envs/fm/lib/python3.10/site-packages/pytorch_lightning/loops/evaluation_loop.py", line 258, in on_run_end
self._on_evaluation_end()
File "/home/ckj24/miniforge-pypy3/envs/fm/lib/python3.10/site-packages/pytorch_lightning/loops/evaluation_loop.py", line 303, in _on_evaluation_end
call._call_callback_hooks(trainer, hook_name, *args, **kwargs)
File "/home/ckj24/miniforge-pypy3/envs/fm/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 194, in _call_callback_hooks
fn(trainer, trainer.lightning_module, *args, **kwargs)
File "/home/ckj24/miniforge-pypy3/envs/fm/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py", line 311, in on_validation_end
self._save_topk_checkpoint(trainer, monitor_candidates)
File "/home/ckj24/miniforge-pypy3/envs/fm/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py", line 358, in _save_topk_checkpoint
raise MisconfigurationException(m)
lightning_fabric.utilities.exceptions.MisconfigurationException: `ModelCheckpoint(monitor='valid/non_coil_percent')` could not find the monitored key in the returned metrics: ['train/bb_atom_loss', 'train/trans_loss', 'train/dist_mat_loss', 'train/auxiliary_loss', 'train/rots_vf_loss', 'train/se3_vf_loss', 'train/t', 'train/bb_atom_loss t=[0.25,0.50)', 'train/trans_loss t=[0.25,0.50)', 'train/dist_mat_loss t=[0.25,0.50)', 'train/auxiliary_loss t=[0.25,0.50)', 'train/rots_vf_loss t=[0.25,0.50)', 'train/se3_vf_loss t=[0.25,0.50)', 'train/length', 'train/batch_size', 'train/examples_per_second', 'train/loss', 'train/bb_atom_loss t=[0.00,0.25)', 'train/bb_atom_loss t=[0.50,0.75)', 'train/bb_atom_loss t=[0.75,1.00)', 'train/trans_loss t=[0.00,0.25)', 'train/trans_loss t=[0.50,0.75)', 'train/trans_loss t=[0.75,1.00)', 'train/dist_mat_loss t=[0.00,0.25)', 'train/dist_mat_loss t=[0.50,0.75)', 'train/dist_mat_loss t=[0.75,1.00)', 'train/auxiliary_loss t=[0.00,0.25)', 'train/auxiliary_loss t=[0.50,0.75)', 'train/auxiliary_loss t=[0.75,1.00)', 'train/rots_vf_loss t=[0.00,0.25)', 'train/rots_vf_loss t=[0.50,0.75)', 'train/rots_vf_loss t=[0.75,1.00)', 'train/se3_vf_loss t=[0.00,0.25)', 'train/se3_vf_loss t=[0.50,0.75)', 'train/se3_vf_loss t=[0.75,1.00)', 'train/epoch_time_minutes', 'epoch', 'step']. HINT: Did you call `log('valid/non_coil_percent', value)` in the `LightningModule`?
Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.
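For anyone hitting the same trace: the root of the exception is the monitor-key check that `ModelCheckpoint` runs at the end of validation. A rough, simplified sketch of that check (not Lightning's actual implementation; `check_monitor` and the exception class here are illustrative stand-ins):

```python
# Simplified sketch of the check ModelCheckpoint performs in
# _save_topk_checkpoint: if the monitored key is absent from the
# logged metrics, it raises MisconfigurationException.

class MisconfigurationException(Exception):
    """Stand-in for lightning_fabric.utilities.exceptions.MisconfigurationException."""

def check_monitor(monitor: str, logged_metrics: dict) -> None:
    # Illustrative helper, not a real Lightning API.
    if monitor not in logged_metrics:
        raise MisconfigurationException(
            f"`ModelCheckpoint(monitor={monitor!r})` could not find the "
            f"monitored key in the returned metrics: {sorted(logged_metrics)}"
        )

# Only train/* metrics were logged, so monitoring a valid/* key fails:
logged = {"train/loss": 2.52, "train/se3_vf_loss": 1.1}
try:
    check_monitor("valid/non_coil_percent", logged)
except MisconfigurationException as e:
    print("raised:", type(e).__name__)
```

So the checkpointer config itself is fine; the problem is that the valid/* metrics never made it into the logged dict.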
I found that the error was indeed due to the validation metrics not being computed because of a bug, which got silently swallowed by the try-except logic here: https://github.com/microsoft/protein-frame-flow/blob/main/models/flow_module.py#L177
In particular, there is no variable named CA_IDX in metrics (https://github.com/microsoft/protein-frame-flow/blob/main/models/flow_module.py#L179), so the metric computation fails before anything gets logged.
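The failure mode is easy to reproduce in isolation: a broad try/except around the metric computation catches the lookup error and discards it, so the valid/* metric is simply never logged. A schematic sketch (the function and variable names below are illustrative, not the repo's exact code):

```python
import types

# Stand-in for the repo's metrics module, which has no CA_IDX attribute.
metrics = types.SimpleNamespace()

def compute_valid_metrics(final_pos):
    logged = {}
    try:
        # AttributeError: metrics.CA_IDX does not exist.
        logged["valid/non_coil_percent"] = final_pos[metrics.CA_IDX]
    except Exception:
        pass  # the error vanishes here; checkpointing only fails much later
    return logged

print(compute_valid_metrics([1.0, 2.0, 3.0]))  # {} -- metric silently missing
```

Because the except clause swallows everything, the traceback you eventually see points at ModelCheckpoint rather than at the real bug.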
My workaround was to import residue_constants
and use that to get ca_idx,
following how it's done in the rest of the code. In the import statements, add:
from data import residue_constants
And change line 179 to:
ca_idx = residue_constants.atom_order['CA']
ca_ca_metrics = metrics.calc_ca_ca_metrics(final_pos[:, ca_idx])
This seems to work, and I've opened a PR in case you agree this is the appropriate way to fix the issue.
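For context on what the lookup does: in AlphaFold-style residue_constants, atom_order maps atom names to their index in the per-residue atom array, and selecting that index keeps only the C-alpha coordinate of each residue. A trimmed illustration with plain lists (the dict below is a small excerpt, not the full atom table, and final_pos here is toy data):

```python
# Trimmed excerpt of the residue_constants lookup used in the fix;
# the real data.residue_constants.atom_order covers all 37 atom types.
atom_order = {"N": 0, "CA": 1, "C": 2, "CB": 3, "O": 4}

ca_idx = atom_order["CA"]  # C-alpha sits at index 1 in the atom ordering

# final_pos is conceptually [num_residues, num_atoms, 3]; picking index
# ca_idx from each residue's atom list keeps only the C-alpha coordinates.
final_pos = [
    [[0.0, 0.0, 0.0], [1.5, 0.0, 0.0], [2.4, 1.0, 0.0]],  # residue 0
    [[3.8, 0.0, 0.0], [5.3, 0.0, 0.0], [6.2, 1.0, 0.0]],  # residue 1
]
ca_positions = [res[ca_idx] for res in final_pos]
print(ca_positions)  # [[1.5, 0.0, 0.0], [5.3, 0.0, 0.0]]
```

In the repo, final_pos is a tensor, so the same selection is written as final_pos[:, ca_idx] before being passed to metrics.calc_ca_ca_metrics.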
Thanks for the catch and fix! This looks good to me.