microsoft / protein-frame-flow

Fast protein backbone generation with SE(3) flow matching.

Default setting for model checkpointer seems to break training

chaitjo opened this issue · comments

After applying the fix in #17 and getting the training script running with the default configs, I'm hitting the following error: training crashes after completing one training epoch and one validation epoch. The error message indicates that the metric valid/non_coil_percent, which the checkpoint callback monitors, was never logged/computed.

Please let me know if there's a quick workaround.

[2024-02-08 19:30:39,372][__main__][INFO] - Checkpoints saved to ckpt/se3-fm/baseline/2024-02-08_19-30-26
[2024-02-08 19:30:39,419][__main__][INFO] - Using devices: [0]
[2024-02-08 19:30:39,947][torch.distributed.distributed_c10d][INFO] - Added key: store_based_barrier_key:1 to store for rank: 0
[2024-02-08 19:30:39,948][torch.distributed.distributed_c10d][INFO] - Rank 0: Completed store-based barrier for key:store_based_barrier_key:1 with 1 nodes.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/1
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 1 processes
----------------------------------------------------------------------------------------------------
[2024-02-08 19:30:41,168][data.pdb_dataloader][INFO] - Training: 3938 examples
[2024-02-08 19:30:41,207][data.pdb_dataloader][INFO] - Validation: 40 examples with lengths [ 20  38  53  68  83  98 113 128]
Sanity Checking DataLoader 0:   0%|                                                                                        | 0/2 [00:00<?, ?it/s]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  | Name  | Type      | Params
------------------------------------
0 | model | FlowModel | 16.7 M
------------------------------------
16.7 M    Trainable params
0         Non-trainable params
16.7 M    Total params


[2024-02-08 19:30:53,758][data.pdb_dataloader][INFO] - Created dataloader rank 1 out of 1
Epoch 0:   0%|                                                                                                          | 0/3938 [00:00<?, ?it/s][2024-02-08 19:30:56,139][data.so3_utils][INFO] - Data loaded from .cache/cache_igso3_s0.100-1.500-1000_l1000_o1000-3.npz
Epoch 0:   0%|                                 | 2/3938 [00:02<1:19:48,  1.22s/it, v_num=2dvk, train/examples_per_second=149.0, train/loss=13.70][2024-02-08 19:30:57,066][torch.nn.parallel.distributed][INFO] - Reducer buckets have been rebuilt in this iteration.

Epoch 1: 100%|████████████████████████████████| 3938/3938 [21:38<00:00,  3.03it/s, v_num=2dvk, train/examples_per_second=245.0, train/loss=2.520]

Validation DataLoader 0: 100%|███████████████████████████████████████████████████████████████████████████████████| 40/40 [03:20<00:00,  5.01s/it]
Error executing job with overrides: []
Traceback (most recent call last):
  File "/home/ckj24/protein-frame-flow/experiments/train_se3_flows.py", line 97, in main
    exp.train()
  File "/home/ckj24/protein-frame-flow/experiments/train_se3_flows.py", line 72, in train
    trainer.fit(
  File "/home/ckj24/miniforge-pypy3/envs/fm/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 532, in fit
    call._call_and_handle_interrupt(
  File "/home/ckj24/miniforge-pypy3/envs/fm/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 42, in _call_and_handle_interrupt
    return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
  File "/home/ckj24/miniforge-pypy3/envs/fm/lib/python3.10/site-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 93, in launch
    return function(*args, **kwargs)
  File "/home/ckj24/miniforge-pypy3/envs/fm/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 571, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/ckj24/miniforge-pypy3/envs/fm/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 980, in _run
    results = self._run_stage()
  File "/home/ckj24/miniforge-pypy3/envs/fm/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1023, in _run_stage
    self.fit_loop.run()
  File "/home/ckj24/miniforge-pypy3/envs/fm/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 202, in run
    self.advance()
  File "/home/ckj24/miniforge-pypy3/envs/fm/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 355, in advance
    self.epoch_loop.run(self._data_fetcher)
  File "/home/ckj24/miniforge-pypy3/envs/fm/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 134, in run
    self.on_advance_end()
  File "/home/ckj24/miniforge-pypy3/envs/fm/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 249, in on_advance_end
    self.val_loop.run()
  File "/home/ckj24/miniforge-pypy3/envs/fm/lib/python3.10/site-packages/pytorch_lightning/loops/utilities.py", line 181, in _decorator
    return loop_run(self, *args, **kwargs)
  File "/home/ckj24/miniforge-pypy3/envs/fm/lib/python3.10/site-packages/pytorch_lightning/loops/evaluation_loop.py", line 122, in run
    return self.on_run_end()
  File "/home/ckj24/miniforge-pypy3/envs/fm/lib/python3.10/site-packages/pytorch_lightning/loops/evaluation_loop.py", line 258, in on_run_end
    self._on_evaluation_end()
  File "/home/ckj24/miniforge-pypy3/envs/fm/lib/python3.10/site-packages/pytorch_lightning/loops/evaluation_loop.py", line 303, in _on_evaluation_end
    call._call_callback_hooks(trainer, hook_name, *args, **kwargs)
  File "/home/ckj24/miniforge-pypy3/envs/fm/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 194, in _call_callback_hooks
    fn(trainer, trainer.lightning_module, *args, **kwargs)
  File "/home/ckj24/miniforge-pypy3/envs/fm/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py", line 311, in on_validation_end
    self._save_topk_checkpoint(trainer, monitor_candidates)
  File "/home/ckj24/miniforge-pypy3/envs/fm/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py", line 358, in _save_topk_checkpoint
    raise MisconfigurationException(m)
lightning_fabric.utilities.exceptions.MisconfigurationException: `ModelCheckpoint(monitor='valid/non_coil_percent')` could not find the monitored key in the returned metrics: ['train/bb_atom_loss', 'train/trans_loss', 'train/dist_mat_loss', 'train/auxiliary_loss', 'train/rots_vf_loss', 'train/se3_vf_loss', 'train/t', 'train/bb_atom_loss t=[0.25,0.50)', 'train/trans_loss t=[0.25,0.50)', 'train/dist_mat_loss t=[0.25,0.50)', 'train/auxiliary_loss t=[0.25,0.50)', 'train/rots_vf_loss t=[0.25,0.50)', 'train/se3_vf_loss t=[0.25,0.50)', 'train/length', 'train/batch_size', 'train/examples_per_second', 'train/loss', 'train/bb_atom_loss t=[0.00,0.25)', 'train/bb_atom_loss t=[0.50,0.75)', 'train/bb_atom_loss t=[0.75,1.00)', 'train/trans_loss t=[0.00,0.25)', 'train/trans_loss t=[0.50,0.75)', 'train/trans_loss t=[0.75,1.00)', 'train/dist_mat_loss t=[0.00,0.25)', 'train/dist_mat_loss t=[0.50,0.75)', 'train/dist_mat_loss t=[0.75,1.00)', 'train/auxiliary_loss t=[0.00,0.25)', 'train/auxiliary_loss t=[0.50,0.75)', 'train/auxiliary_loss t=[0.75,1.00)', 'train/rots_vf_loss t=[0.00,0.25)', 'train/rots_vf_loss t=[0.50,0.75)', 'train/rots_vf_loss t=[0.75,1.00)', 'train/se3_vf_loss t=[0.00,0.25)', 'train/se3_vf_loss t=[0.50,0.75)', 'train/se3_vf_loss t=[0.75,1.00)', 'train/epoch_time_minutes', 'epoch', 'step']. HINT: Did you call `log('valid/non_coil_percent', value)` in the `LightningModule`?
Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.

I found that the error was indeed caused by the validation metrics not being computed: a bug raises an exception that is silently swallowed by the try-except logic here: https://github.com/microsoft/protein-frame-flow/blob/main/models/flow_module.py#L177

In particular, the metrics module has no attribute named CA_IDX, so line 179 raises an AttributeError. (https://github.com/microsoft/protein-frame-flow/blob/main/models/flow_module.py#L179)
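
To illustrate the failure mode, here is a minimal self-contained sketch (hypothetical names, not the repo's actual code) of how a broad try-except can swallow the AttributeError, so the validation metric is silently never logged and ModelCheckpoint later fails to find it:

```python
# Minimal sketch of the silent-failure pattern. `metrics`, `log`, and
# `validation_step` here are illustrative stand-ins, not the repo's code.
logged_metrics = {}

def log(name, value):
    logged_metrics[name] = value

class metrics:
    # Deliberately no CA_IDX attribute, mirroring the bug.
    @staticmethod
    def calc_ca_ca_metrics(ca_pos):
        return {"non_coil_percent": 0.5}

def validation_step(final_pos):
    try:
        # metrics.CA_IDX raises AttributeError before the call runs
        ca_ca = metrics.calc_ca_ca_metrics(final_pos[:, metrics.CA_IDX])
        for k, v in ca_ca.items():
            log(f"valid/{k}", v)
    except Exception:
        pass  # error swallowed: valid/non_coil_percent is never logged

validation_step([[0.0, 1.0]])
print("valid/non_coil_percent" in logged_metrics)  # False
```

Because the exception never surfaces, the first visible symptom is the MisconfigurationException from ModelCheckpoint at the end of validation, far from the actual bug.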

My workaround was to import residue_constants and use it to look up the CA index, following how it's done in the rest of the code. In the import statements, add:

from data import residue_constants

And change line 179 to:

ca_idx = residue_constants.atom_order['CA']
ca_ca_metrics = metrics.calc_ca_ca_metrics(final_pos[:, ca_idx])
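
For context, the atom_order lookup is just a name-to-index mapping over the atom-type table. A truncated, illustrative sketch (the real data/residue_constants.py lists all 37 heavy-atom types; only the first few are shown here):

```python
# Illustrative, truncated version of the atom-type table; the real
# residue_constants module defines the full 37-atom ordering.
atom_types = ["N", "CA", "C", "CB", "O"]  # truncated for illustration
atom_order = {name: i for i, name in enumerate(atom_types)}

ca_idx = atom_order["CA"]
print(ca_idx)  # 1 in this truncated table
```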

This seems to work, and I've opened a PR in case you think this is the appropriate way to fix the issue.

Thanks for the catch and fix! This looks good to me.