ACEsuit / mace

MACE - Fast and accurate machine learning interatomic potentials with higher order equivariant message passing.

Fine-tuning up-to-date MACE-MP-0 model

bfocassio opened this issue

Describe the bug
Dear developers,

I've been trying to fine-tune a large mace-mp-0 model, but I'm running into some problems.

I'm using the foundational branch.

First, what works?

The following training command works:

python /home/bruno.focassio/codes/mace-foundations/scripts/run_train.py \
    --name="mace_fine_tunning_100" \
    --foundation_model="large" \
    --train_file="training_data.xyz" \
    --test_file="test_data.xyz" \
    --valid_fraction=0.05 \
    --energy_weight=1 \
    --forces_weight=10 \
    --compute_stress=True \
    --stress_weight=100 \
    --stress_key='stress' \
    --eval_interval=1 \
    --error_table='PerAtomMAE' \
    --E0s="average" \
    --interaction_first="RealAgnosticResidualInteractionBlock" \
    --interaction="RealAgnosticResidualInteractionBlock" \
    --scaling='rms_forces_scaling' \
    --lr=0.005 \
    --weight_decay=1e-8 \
    --ema \
    --ema_decay=0.995 \
    --scheduler_patience=5 \
    --batch_size=2 \
    --valid_batch_size=4 \
    --max_num_epochs=100 \
    --patience=20 \
    --amsgrad \
    --device="cuda" \
    --seed=1 \
    --clip_grad=100 \
    --keep_checkpoints \
    --restart_latest \
    --save_cpu

Here is the log:

2024-02-15 17:03:47.667 INFO: MACE version: 0.3.4
2024-02-15 17:03:47.667 INFO: Configuration: Namespace(name='mace_fine_tunning_100', seed=1, log_dir='logs', 
model_dir='.', checkpoints_dir='checkpoints', results_dir='results', downloads_dir='downloads', device='cuda', default_dtype='float64', log_level='INFO', error_table='PerAtomMAE', model='MACE', r_max=5.0, radial_type='bessel', num_radial_basis=8, num_cutoff_basis=5, pair_repulsion=False, distance_transform=False, interaction='RealAgnosticResidualInteractionBlock', interaction_first='RealAgnosticResidualInteractionBlock', max_ell=3, correlation=3, num_interactions=2, MLP_irreps='16x0e', radial_MLP='[64, 64, 64]', hidden_irreps='128x0e + 128x1o', num_channels=None, max_L=None, gate='silu', scaling='rms_forces_scaling', avg_num_neighbors=1, compute_avg_num_neighbors=True, compute_stress=True, compute_forces=True, train_file='/home/bruno.focassio/mace_large_model/uip/train_universal_foundation/jpca_full_training_data.xyz', valid_file=None, valid_fraction=0.05, test_file='/home/bruno.focassio/mace_large_model/uip/train_universal_foundation/jpca_full_test_data.xyz', E0s='average', energy_key='energy', forces_key='forces', virials_key='virials', stress_key='stress', dipole_key='dipole', charges_key='charges', loss='weighted', forces_weight=10.0, swa_forces_weight=100.0, energy_weight=1.0, swa_energy_weight=1000.0, virials_weight=1.0, swa_virials_weight=10.0, stress_weight=100.0, swa_stress_weight=10.0, dipole_weight=1.0, swa_dipole_weight=1.0, config_type_weights='{"Default":1.0}', huber_delta=0.01, optimizer='adam', batch_size=2, valid_batch_size=4, lr=0.005, swa_lr=0.001, weight_decay=1e-08, amsgrad=True, scheduler='ReduceLROnPlateau', lr_factor=0.8, scheduler_patience=5, lr_scheduler_gamma=0.9993, swa=False, start_swa=None, ema=True, ema_decay=0.995, max_num_epochs=100, patience=20, foundation_model='large', foundation_model_readout=True, eval_interval=1, keep_checkpoints=True, restart_latest=True, save_cpu=True, clip_grad=100.0, wandb=False, wandb_project='', wandb_entity='', wandb_name='', wandb_log_hypers=['num_channels', 'max_L', 'correlation', 'lr', 'swa_lr', 
'weight_decay', 'batch_size', 'max_num_epochs', 'start_swa', 'energy_weight', 'forces_weight'])
2024-02-15 17:03:47.710 INFO: CUDA version: 11.3, CUDA device: 0
2024-02-15 17:03:48.783 INFO: Loaded 1402 training configurations from '/home/bruno.focassio/mace_large_model/uip/train_universal_foundation/jpca_full_training_data.xyz'
2024-02-15 17:03:48.784 INFO: Using random 5.0% of training set for validation
2024-02-15 17:03:48.880 INFO: Loaded 164 test configurations from '/home/bruno.focassio/mace_large_model/uip/train_universal_foundation/jpca_full_test_data.xyz'
2024-02-15 17:03:48.881 INFO: Total number of configurations: train=1332, valid=70, tests=[Default: 164]
2024-02-15 17:03:48.897 INFO: AtomicNumberTable: (3, 14, 28, 29, 32, 42)
2024-02-15 17:03:48.898 INFO: Atomic Energies not in training file, using command line argument E0s
2024-02-15 17:03:48.898 INFO: Computing average Atomic Energies using least squares regression
2024-02-15 17:03:49.098 INFO: Atomic energies: [-1.8136614652719867, -5.115625774505135, -5.649542038228693, -3.975118056170312, -4.235264403697451, -10.440083603305574]
2024-02-15 17:03:52.187 INFO: WeightedEnergyForcesLoss(energy_weight=1.000, forces_weight=10.000)
2024-02-15 17:03:52.847 INFO: Average number of neighbors: 36.90208147728193
2024-02-15 17:03:52.848 INFO: Selected the following outputs: {'energy': True, 'forces': True, 'virials': False, 'stress': True, 'dipoles': False}
2024-02-15 17:03:52.849 INFO: Building model
2024-02-15 17:03:52.855 INFO: Using large mace-mp-0 settings. Hidden irreps: 128x0e+128x1o+128x2e
2024-02-15 17:57:48.889 INFO: MACE version: 0.3.4
2024-02-15 17:57:48.898 INFO: Configuration: Namespace(name='mace_fine_tunning_100', seed=1, log_dir='logs', model_dir='.', checkpoints_dir='checkpoints', results_dir='results', downloads_dir='downloads', device='cuda', default_dtype='float64', log_level='INFO', error_table='PerAtomMAE', model='MACE', r_max=5.0, radial_type='bessel', num_radial_basis=8, num_cutoff_basis=5, pair_repulsion=False, distance_transform=False, interaction='RealAgnosticResidualInteractionBlock', interaction_first='RealAgnosticResidualInteractionBlock', max_ell=3, correlation=3, num_interactions=2, MLP_irreps='16x0e', radial_MLP='[64, 64, 64]', hidden_irreps='128x0e + 128x1o', num_channels=None, max_L=None, gate='silu', scaling='rms_forces_scaling', avg_num_neighbors=1, compute_avg_num_neighbors=True, compute_stress=True, compute_forces=True, train_file='training_data.xyz', valid_file=None, valid_fraction=0.05, test_file='test_data.xyz', E0s='average', energy_key='energy', forces_key='forces', virials_key='virials', stress_key='stress', dipole_key='dipole', charges_key='charges', loss='weighted', forces_weight=10.0, swa_forces_weight=100.0, energy_weight=1.0, swa_energy_weight=1000.0, virials_weight=1.0, swa_virials_weight=10.0, stress_weight=100.0, swa_stress_weight=10.0, dipole_weight=1.0, swa_dipole_weight=1.0, config_type_weights='{"Default":1.0}', huber_delta=0.01, optimizer='adam', batch_size=2, valid_batch_size=4, lr=0.005, swa_lr=0.001, weight_decay=1e-08, amsgrad=True, scheduler='ReduceLROnPlateau', lr_factor=0.8, scheduler_patience=5, lr_scheduler_gamma=0.9993, swa=False, start_swa=None, ema=True, ema_decay=0.995, max_num_epochs=100, patience=20, foundation_model='large', foundation_model_readout=True, eval_interval=1, keep_checkpoints=True, restart_latest=True, save_cpu=True, clip_grad=100.0, wandb=False, wandb_project='', wandb_entity='', wandb_name='', wandb_log_hypers=['num_channels', 'max_L', 'correlation', 'lr', 'swa_lr', 'weight_decay', 'batch_size', 'max_num_epochs', 
'start_swa', 'energy_weight', 'forces_weight'])
2024-02-15 17:57:48.955 INFO: CUDA version: 11.3, CUDA device: 0
2024-02-15 17:57:50.074 INFO: Loaded 1402 training configurations from '/home/bruno.focassio/mace_large_model/uip/train_universal_foundation/jpca_full_training_data.xyz'
2024-02-15 17:57:50.076 INFO: Using random 5.0% of training set for validation
2024-02-15 17:57:50.186 INFO: Loaded 164 test configurations from '/home/bruno.focassio/mace_large_model/uip/train_universal_foundation/jpca_full_test_data.xyz'
2024-02-15 17:57:50.187 INFO: Total number of configurations: train=1332, valid=70, tests=[Default: 164]
2024-02-15 17:57:50.203 INFO: AtomicNumberTable: (3, 14, 28, 29, 32, 42)
2024-02-15 17:57:50.204 INFO: Atomic Energies not in training file, using command line argument E0s
2024-02-15 17:57:50.204 INFO: Computing average Atomic Energies using least squares regression
2024-02-15 17:57:50.558 INFO: Atomic energies: [-1.8136614652719867, -5.115625774505135, -5.649542038228693, -3.975118056170312, -4.235264403697451, -10.440083603305574]
2024-02-15 17:57:53.275 INFO: WeightedEnergyForcesLoss(energy_weight=1.000, forces_weight=10.000)
2024-02-15 17:57:53.989 INFO: Average number of neighbors: 36.90208147728193
2024-02-15 17:57:53.990 INFO: Selected the following outputs: {'energy': True, 'forces': True, 'virials': False, 'stress': True, 'dipoles': False}
2024-02-15 17:57:53.990 INFO: Building model
2024-02-15 17:57:53.997 INFO: Using large mace-mp-0 settings. Hidden irreps: 128x0e+128x1o+128x2e
2024-02-15 17:58:03.137 INFO: CUDA version: 11.3, CUDA device: 0
2024-02-15 17:58:03.145 INFO: Using foundation model large as initial checkpoint.
2024-02-15 17:58:03.183 WARNING: Cannot find checkpoint with tag 'mace_fine_tunning_100_run-1' in 'checkpoints'
2024-02-15 17:58:03.190 INFO: ScaleShiftMACE(
  (node_embedding): LinearNodeEmbeddingBlock(
    (linear): Linear(6x0e -> 128x0e | 768 weights)
  )
  (radial_embedding): RadialEmbeddingBlock(
    (bessel_fn): BesselBasis(r_max=6.0, num_basis=10, trainable=False)
    (cutoff_fn): PolynomialCutoff(p=5.0, r_max=6.0)
  )
  (spherical_harmonics): SphericalHarmonics()
  (atomic_energies_fn): AtomicEnergiesBlock(energies=[-1.8137, -5.1156, -5.6495, -3.9751, -4.2353, -10.4401])
  (interactions): ModuleList(
    (0): RealAgnosticResidualInteractionBlock(
      (linear_up): Linear(128x0e -> 128x0e | 16384 weights)
      (conv_tp): TensorProduct(128x0e x 1x0e+1x1o+1x2e+1x3o -> 128x0e+128x1o+128x2e+128x3o | 512 paths | 512 weights)
      (conv_tp_weights): FullyConnectedNet[10, 64, 64, 64, 512]
      (linear): Linear(128x0e+128x1o+128x2e+128x3o -> 128x0e+128x1o+128x2e+128x3o | 65536 weights)
      (skip_tp): FullyConnectedTensorProduct(128x0e x 6x0e -> 128x0e+128x1o+128x2e | 98304 paths | 98304 weights)
      (reshape): reshape_irreps()
    )
    (1): RealAgnosticResidualInteractionBlock(
      (linear_up): Linear(128x0e+128x1o+128x2e -> 128x0e+128x1o+128x2e | 49152 weights)
      (conv_tp): TensorProduct(128x0e+128x1o+128x2e x 1x0e+1x1o+1x2e+1x3o -> 384x0e+640x1o+640x2e+512x3o | 2176 paths | 2176 weights)
      (conv_tp_weights): FullyConnectedNet[10, 64, 64, 64, 2176]
      (linear): Linear(384x0e+640x1o+640x2e+512x3o -> 128x0e+128x1o+128x2e+128x3o | 278528 weights)
      (skip_tp): FullyConnectedTensorProduct(128x0e+128x1o+128x2e x 6x0e -> 128x0e | 98304 paths | 98304 weights)
      (reshape): reshape_irreps()
    )
  )
  (products): ModuleList(
    (0): EquivariantProductBasisBlock(
      (symmetric_contractions): SymmetricContraction(
        (contractions): ModuleList(
          (0): Contraction(
            (contractions_weighting): ModuleList(
              (0): GraphModule()
              (1): GraphModule()
            )
            (contractions_features): ModuleList(
              (0): GraphModule()
              (1): GraphModule()
            )
            (weights): ParameterList(
                (0): Parameter containing: [torch.cuda.DoubleTensor of size 6x4x128 (GPU 0)]
                (1): Parameter containing: [torch.cuda.DoubleTensor of size 6x1x128 (GPU 0)]
            )
            (graph_opt_main): GraphModule()
          )
          (1): Contraction(
            (contractions_weighting): ModuleList(
              (0): GraphModule()
              (1): GraphModule()
            )
            (contractions_features): ModuleList(
              (0): GraphModule()
              (1): GraphModule()
            )
            (weights): ParameterList(
                (0): Parameter containing: [torch.cuda.DoubleTensor of size 6x6x128 (GPU 0)]
                (1): Parameter containing: [torch.cuda.DoubleTensor of size 6x1x128 (GPU 0)]
            )
            (graph_opt_main): GraphModule()
          )
          (2): Contraction(
            (contractions_weighting): ModuleList(
              (0): GraphModule()
              (1): GraphModule()
            )
            (contractions_features): ModuleList(
              (0): GraphModule()
              (1): GraphModule()
            )
            (weights): ParameterList(
                (0): Parameter containing: [torch.cuda.DoubleTensor of size 6x7x128 (GPU 0)]
                (1): Parameter containing: [torch.cuda.DoubleTensor of size 6x1x128 (GPU 0)]
            )
            (graph_opt_main): GraphModule()
          )
        )
      )
      (linear): Linear(128x0e+128x1o+128x2e -> 128x0e+128x1o+128x2e | 49152 weights)
    )
    (1): EquivariantProductBasisBlock(
      (symmetric_contractions): SymmetricContraction(
        (contractions): ModuleList(
          (0): Contraction(
            (contractions_weighting): ModuleList(
              (0): GraphModule()
              (1): GraphModule()
            )
            (contractions_features): ModuleList(
              (0): GraphModule()
              (1): GraphModule()
            )
            (weights): ParameterList(
                (0): Parameter containing: [torch.cuda.DoubleTensor of size 6x4x128 (GPU 0)]
                (1): Parameter containing: [torch.cuda.DoubleTensor of size 6x1x128 (GPU 0)]
            )
            (graph_opt_main): GraphModule()
          )
        )
      )
      (linear): Linear(128x0e -> 128x0e | 16384 weights)
    )
  )
  (readouts): ModuleList(
    (0): LinearReadoutBlock(
      (linear): Linear(128x0e+128x1o+128x2e -> 1x0e | 128 weights)
    )
    (1): NonLinearReadoutBlock(
      (linear_1): Linear(128x0e -> 16x0e | 2048 weights)
      (non_linearity): Activation [x] (16x0e -> 16x0e)
      (linear_2): Linear(16x0e -> 1x0e | 16 weights)
    )
  )
  (scale_shift): ScaleShiftBlock(scale=0.804154, shift=-0.001485)
)
2024-02-15 17:58:03.197 INFO: Number of parameters: 1008016
2024-02-15 17:58:03.197 INFO: Optimizer: Adam (
Parameter Group 0
    amsgrad: True
    betas: (0.9, 0.999)
    capturable: False
    eps: 1e-08
    foreach: None
    lr: 0.005
    maximize: False
    name: embedding
    weight_decay: 0.0

Parameter Group 1
    amsgrad: True
    betas: (0.9, 0.999)
    capturable: False
    eps: 1e-08
    foreach: None
    lr: 0.005
    maximize: False
    name: interactions_decay
    weight_decay: 1e-08

Parameter Group 2
    amsgrad: True
    betas: (0.9, 0.999)
    capturable: False
    eps: 1e-08
    foreach: None
    lr: 0.005
    maximize: False
    name: interactions_no_decay
    weight_decay: 0.0

Parameter Group 3
    amsgrad: True
    betas: (0.9, 0.999)
    capturable: False
    eps: 1e-08
    foreach: None
    lr: 0.005
    maximize: False
    name: products
    weight_decay: 1e-08

Parameter Group 4
    amsgrad: True
    betas: (0.9, 0.999)
    capturable: False
    eps: 1e-08
    foreach: None
    lr: 0.005
    maximize: False
    name: readouts
    weight_decay: 0.0

Parameter Group 5
    amsgrad: True
    betas: (0.9, 0.999)
    capturable: False
    eps: 1e-08
    foreach: None
    lr: 0.005
    maximize: False
    name: radial_embedding
    weight_decay: 0.0
)
2024-02-15 17:58:03.198 INFO: Using gradient clipping with tolerance=100.000
2024-02-15 17:58:03.198 INFO: Started training
2024-02-15 17:58:17.348 INFO: Epoch None: loss=2.2620, MAE_E_per_atom=1073.7 meV, MAE_F=112.0 meV / A
2024-02-15 18:02:18.942 INFO: Epoch 0: loss=0.1095, MAE_E_per_atom=26.6 meV, MAE_F=47.0 meV / A
2024-02-15 18:06:07.018 INFO: Epoch 1: loss=0.0982, MAE_E_per_atom=23.1 meV, MAE_F=45.6 meV / A
2024-02-15 18:09:54.945 INFO: Epoch 2: loss=0.0881, MAE_E_per_atom=19.8 meV, MAE_F=42.2 meV / A
2024-02-15 18:13:42.861 INFO: Epoch 3: loss=0.0777, MAE_E_per_atom=14.7 meV, MAE_F=38.6 meV / A
2024-02-15 18:17:30.534 INFO: Epoch 4: loss=0.0679, MAE_E_per_atom=13.0 meV, MAE_F=36.0 meV / A
2024-02-15 18:21:18.165 INFO: Epoch 5: loss=0.0621, MAE_E_per_atom=12.8 meV, MAE_F=34.5 meV / A
2024-02-15 18:25:05.720 INFO: Epoch 6: loss=0.0579, MAE_E_per_atom=11.4 meV, MAE_F=33.2 meV / A
2024-02-15 18:28:53.007 INFO: Epoch 7: loss=0.0555, MAE_E_per_atom=11.8 meV, MAE_F=32.7 meV / A
2024-02-15 18:32:40.528 INFO: Epoch 8: loss=0.0527, MAE_E_per_atom=10.6 meV, MAE_F=31.7 meV / A
2024-02-15 18:36:27.989 INFO: Epoch 9: loss=0.0532, MAE_E_per_atom=12.7 meV, MAE_F=31.8 meV / A
2024-02-15 18:40:15.416 INFO: Epoch 10: loss=0.0511, MAE_E_per_atom=10.2 meV, MAE_F=30.9 meV / A
2024-02-15 18:44:02.989 INFO: Epoch 11: loss=0.0497, MAE_E_per_atom=10.0 meV, MAE_F=30.5 meV / A
2024-02-15 18:47:50.506 INFO: Epoch 12: loss=0.0576, MAE_E_per_atom=13.5 meV, MAE_F=32.7 meV / A
2024-02-15 18:51:37.923 INFO: Epoch 13: loss=0.0488, MAE_E_per_atom=9.4 meV, MAE_F=29.8 meV / A
2024-02-15 18:55:25.431 INFO: Epoch 14: loss=0.0474, MAE_E_per_atom=8.4 meV, MAE_F=29.3 meV / A
2024-02-15 18:59:12.862 INFO: Epoch 15: loss=0.0458, MAE_E_per_atom=9.2 meV, MAE_F=28.9 meV / A
2024-02-15 19:03:00.278 INFO: Epoch 16: loss=0.0486, MAE_E_per_atom=8.7 meV, MAE_F=29.2 meV / A
2024-02-15 19:06:47.598 INFO: Epoch 17: loss=0.0458, MAE_E_per_atom=9.9 meV, MAE_F=29.0 meV / A
2024-02-15 19:10:35.052 INFO: Epoch 18: loss=0.0452, MAE_E_per_atom=9.3 meV, MAE_F=29.0 meV / A
2024-02-15 19:14:22.564 INFO: Epoch 19: loss=0.0439, MAE_E_per_atom=8.1 meV, MAE_F=27.9 meV / A
2024-02-15 19:18:10.252 INFO: Epoch 20: loss=0.0435, MAE_E_per_atom=7.9 meV, MAE_F=27.8 meV / A
2024-02-15 19:21:58.010 INFO: Epoch 21: loss=0.0431, MAE_E_per_atom=7.2 meV, MAE_F=27.6 meV / A
2024-02-15 19:25:45.505 INFO: Epoch 22: loss=0.0425, MAE_E_per_atom=8.8 meV, MAE_F=27.6 meV / A
2024-02-15 19:29:32.864 INFO: Epoch 23: loss=0.0420, MAE_E_per_atom=7.8 meV, MAE_F=27.3 meV / A
2024-02-15 19:33:20.227 INFO: Epoch 24: loss=0.0426, MAE_E_per_atom=7.2 meV, MAE_F=27.4 meV / A
2024-02-15 19:37:08.055 INFO: Epoch 25: loss=0.0436, MAE_E_per_atom=7.4 meV, MAE_F=27.7 meV / A
2024-02-15 19:40:55.733 INFO: Epoch 26: loss=0.0425, MAE_E_per_atom=7.0 meV, MAE_F=27.0 meV / A
2024-02-15 19:44:43.321 INFO: Epoch 27: loss=0.0432, MAE_E_per_atom=6.5 meV, MAE_F=27.3 meV / A
2024-02-15 19:48:30.803 INFO: Epoch 28: loss=0.0428, MAE_E_per_atom=7.1 meV, MAE_F=27.0 meV / A
2024-02-15 19:52:18.400 INFO: Epoch 29: loss=0.0419, MAE_E_per_atom=6.9 meV, MAE_F=26.8 meV / A
2024-02-15 19:56:06.107 INFO: Epoch 30: loss=0.0424, MAE_E_per_atom=7.6 meV, MAE_F=27.0 meV / A
2024-02-15 19:59:53.625 INFO: Epoch 31: loss=0.0423, MAE_E_per_atom=6.7 meV, MAE_F=26.8 meV / A
2024-02-15 20:03:41.053 INFO: Epoch 32: loss=0.0421, MAE_E_per_atom=6.8 meV, MAE_F=26.5 meV / A
2024-02-15 20:07:28.499 INFO: Epoch 33: loss=0.0494, MAE_E_per_atom=10.1 meV, MAE_F=30.5 meV / A
2024-02-15 20:11:15.947 INFO: Epoch 34: loss=0.0441, MAE_E_per_atom=8.3 meV, MAE_F=27.8 meV / A
2024-02-15 20:15:04.107 INFO: Epoch 35: loss=0.0429, MAE_E_per_atom=7.5 meV, MAE_F=26.7 meV / A
2024-02-15 20:18:51.550 INFO: Epoch 36: loss=0.0428, MAE_E_per_atom=7.0 meV, MAE_F=26.7 meV / A
2024-02-15 20:22:39.116 INFO: Epoch 37: loss=0.0429, MAE_E_per_atom=6.3 meV, MAE_F=26.5 meV / A
2024-02-15 20:26:26.573 INFO: Epoch 38: loss=0.0432, MAE_E_per_atom=6.2 meV, MAE_F=26.5 meV / A
2024-02-15 20:30:14.166 INFO: Epoch 39: loss=0.0427, MAE_E_per_atom=6.5 meV, MAE_F=26.4 meV / A
2024-02-15 20:34:01.645 INFO: Epoch 40: loss=0.0428, MAE_E_per_atom=6.5 meV, MAE_F=26.5 meV / A
2024-02-15 20:37:49.177 INFO: Epoch 41: loss=0.0428, MAE_E_per_atom=6.0 meV, MAE_F=26.3 meV / A
2024-02-15 20:41:36.573 INFO: Epoch 42: loss=0.0426, MAE_E_per_atom=6.0 meV, MAE_F=26.2 meV / A
2024-02-15 20:45:24.005 INFO: Epoch 43: loss=0.0425, MAE_E_per_atom=6.0 meV, MAE_F=26.1 meV / A
2024-02-15 20:49:11.417 INFO: Epoch 44: loss=0.0423, MAE_E_per_atom=5.9 meV, MAE_F=26.2 meV / A
2024-02-15 20:52:58.868 INFO: Epoch 45: loss=0.0428, MAE_E_per_atom=5.9 meV, MAE_F=26.1 meV / A
2024-02-15 20:56:46.500 INFO: Epoch 46: loss=0.0430, MAE_E_per_atom=6.1 meV, MAE_F=26.2 meV / A
2024-02-15 21:00:34.019 INFO: Epoch 47: loss=0.0427, MAE_E_per_atom=5.9 meV, MAE_F=26.2 meV / A
2024-02-15 21:04:21.534 INFO: Epoch 48: loss=0.0434, MAE_E_per_atom=5.5 meV, MAE_F=26.1 meV / A
2024-02-15 21:08:08.908 INFO: Epoch 49: loss=0.0431, MAE_E_per_atom=5.8 meV, MAE_F=26.1 meV / A
2024-02-15 21:08:08.910 INFO: Stopping optimization after 20 epochs without improvement
2024-02-15 21:08:08.911 INFO: Training complete
2024-02-15 21:08:08.913 INFO: Computing metrics for training, validation, and test sets
2024-02-15 21:08:08.954 INFO: Loading checkpoint: checkpoints/mace_fine_tunning_100_run-1_epoch-29.pt
2024-02-15 21:08:09.369 INFO: Loaded model from epoch 29
2024-02-15 21:08:12.051 INFO: Evaluating train ...
2024-02-15 21:09:16.921 INFO: Evaluating valid ...
2024-02-15 21:09:20.502 INFO: Evaluating Default ...
2024-02-15 21:09:28.334 INFO: 
+-------------+--------------------+-----------------+------------------+
| config_type | MAE E / meV / atom | MAE F / meV / A | relative F MAE % |
+-------------+--------------------+-----------------+------------------+
|    train    |        5.1         |       20.9      |       4.46       |
|    valid    |        6.9         |       26.8      |       5.59       |
|   Default   |        5.5         |       26.5      |       5.34       |
+-------------+--------------------+-----------------+------------------+
2024-02-15 21:09:28.336 INFO: Saving model to checkpoints/mace_fine_tunning_100_run-1.model
2024-02-15 21:09:28.707 INFO: Done

This raises a few questions:

  1. It seems the script loads the large model from large="http://tinyurl.com/5f5yavf3",  # MACE_MPtrj_2022.9.model; however, the 0.3.3 release uses the model from large="https://figshare.com/ndownloader/files/43117273".

In that regard, how can I use the most up-to-date model? I've tried replacing --foundation_model="large" with the path to several different models, including the one from https://figshare.com/ndownloader/files/43117273 and the one available on Hugging Face, 2024-01-07-mace-128-L2_epoch-199.model.

  2. The linear embedding block (the first one) shows:
    (node_embedding): LinearNodeEmbeddingBlock( (linear): Linear(6x0e -> 128x0e | 768 weights) )
    because my fine-tuning training set has only 6 elements, whereas the full large mace-mp-0 model has
    (node_embedding): LinearNodeEmbeddingBlock( (linear): Linear(89x0e -> 128x0e | 11392 weights) )
    How can I keep the original shape of the linear embedding and fine-tune only for the elements in my training set? Should I create dummy samples of single atoms with the average atomic energies?

  3. I have tried to use the 2024-01-07-mace-128-L2_epoch-199 model, but I run into very similar problems. When I try to use the checkpoint available on Hugging Face for that model, I get a size mismatch between the model I'm loading and the model in the checkpoint, which I suspect is related to the questions above.

  4. In the training that succeeded, is it using the already trained model with its current weights, or is it using the architecture of the large model to start a fresh training? I would expect the correct output to start from some epoch other than 0; otherwise the fine-tuning would just be training from scratch. Is that correct?

Any help is appreciated.
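The embedding shapes in the log follow from simple arithmetic: the node embedding is a linear map from a one-hot vector over the elements in the AtomicNumberTable to num_channels scalars, so its weight count is num_elements × num_channels (6 × 128 = 768 here, 89 × 128 = 11392 for the full model). A minimal sketch, assuming this one-hot-plus-linear picture; the helper names are hypothetical and not part of MACE's API:

```python
# Hypothetical helpers illustrating why the embedding shrinks to the training-set elements.
from typing import Dict, List, Sequence


def z_to_index(z_table: Sequence[int]) -> Dict[int, int]:
    """Map each atomic number in the table to its one-hot position."""
    return {z: i for i, z in enumerate(z_table)}


def one_hot(z: int, z_table: Sequence[int]) -> List[int]:
    """One-hot encode a single atomic number against the table."""
    index = z_to_index(z_table)
    if z not in index:
        raise KeyError(f"element Z={z} is not in the atomic number table")
    vec = [0] * len(z_table)
    vec[index[z]] = 1
    return vec


def embedding_weights(num_elements: int, num_channels: int) -> int:
    """Weights of a bias-free linear map from num_elements x 0e to num_channels x 0e."""
    return num_elements * num_channels


# Atomic numbers from the fine-tuning log: Li, Si, Ni, Cu, Ge, Mo
fine_tune_table = (3, 14, 28, 29, 32, 42)
print(one_hot(29, fine_tune_table))                  # [0, 0, 0, 1, 0, 0]
print(embedding_weights(len(fine_tune_table), 128))  # 768
print(embedding_weights(89, 128))                    # 11392
```

Because the one-hot vector only has slots for elements present in the table, an element missing from the training data has no row in the embedding at all.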

Hey @bfocassio,

  1. Indeed, you will need to pass the path of the model you want to fine-tune; however, if you do that, you also need to provide the full hyperparameters of the model. You can find them here: https://github.com/ACEsuit/mace-mp/tree/main/mace_mp_0.

  2. It selects only the elements in your training set, for two reasons: to avoid keeping a large model that is memory-intensive even for a small dataset, and because of the isolated-atom energies. If you want to keep an element, you need to add a dummy isolated atom with the correct isolated-atom energy.

  3. That is due to not providing the right input hyperparameters; see my first answer for where to find them.

  4. I am not sure I understand this question. The fine-tuning currently fine-tunes the whole mace-mp model. This might change if we find better fine-tuning protocols.
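Following point 2, the dummy isolated atoms could be appended to the training file as single-atom frames. A minimal sketch in plain Python, assuming extended-XYZ output with the energy stored under the key given by --energy_key (here 'energy') and a config_type=IsolatedAtom tag, which I believe is the convention MACE uses to recognise isolated-atom configurations; please verify this against your MACE version. The E0 values are copied from the E0s dictionary in the second command of this thread, and the symbol table is limited to the six elements of the fine-tuning set:

```python
# Sketch: write single-atom extended-XYZ frames carrying isolated-atom energies.
# Assumptions: energy stored under "energy" (matching --energy_key) and
# config_type=IsolatedAtom marking each frame as an isolated atom.
from typing import Dict

# Only the elements from the fine-tuning AtomicNumberTable; extend as needed.
SYMBOLS: Dict[int, str] = {3: "Li", 14: "Si", 28: "Ni", 29: "Cu", 32: "Ge", 42: "Mo"}


def isolated_atom_frames(e0s: Dict[int, float], energy_key: str = "energy") -> str:
    """Return one extxyz frame per element: a single atom at the origin with zero forces."""
    frames = []
    for z, e0 in sorted(e0s.items()):
        header = (
            "Properties=species:S:1:pos:R:3:forces:R:3 "
            f'{energy_key}={e0!r} config_type=IsolatedAtom pbc="F F F"'
        )
        atom_line = f"{SYMBOLS[z]} 0.0 0.0 0.0 0.0 0.0 0.0"
        frames.append(f"1\n{header}\n{atom_line}\n")
    return "".join(frames)


e0s = {3: -3.482100566595956, 14: -7.694793133351899, 28: -5.172747618813715,
       29: -3.2520726958619472, 32: -4.70845955030298, 42: -8.791678800595722}
text = isolated_atom_frames(e0s)
print(text.splitlines()[:3])
```

The returned text can be appended to training_data.xyz (for example with open("training_data.xyz", "a")).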

Hi @ilyes319,

OK, I understand.

So, to try fine-tuning the 2024-01-07-mace-128-L2_epoch-199.model, I've created dummy structures, each with 2 atoms, for all 89 atomic species in the MPtrj dataset.

Here is the command for training:

python /home/bruno.focassio/codes/mace-foundations/scripts/run_train.py \
    --name="2024-01-07-mace-128-L2" \
    --foundation_model="2024-01-07-mace-128-L2.model" \
    --train_file="training_data.xyz" \
    --test_file="test_data.xyz" \
    --valid_fraction=0.05 \
    --loss="universal" \
    --energy_weight=1 \
    --forces_weight=10 \
    --compute_stress=True \
    --stress_weight=100 \
    --stress_key='stress' \
    --eval_interval=1 \
    --error_table='PerAtomMAE' \
    --E0s="{1: -3.667168021358939, 2: -1.3320953124042916, 3: -3.482100566595956, 4: -4.736697230897597, 5: -7.724935420523256, 6: -8.405573550273285, 7: -7.360100452662763, 8: -7.28459863421322, 9: -4.896490881731322, 10: 1.3917755836700962e-12, 11: -2.7593613569762425, 12: -2.814047612069227, 13: -4.846881245288104, 14: -7.694793133351899, 15: -6.9632957911820235, 16: -4.672630400190884, 17: -2.8116892814008096, 18: -0.06259504416367478, 19: -2.6176454856894793, 20: -5.390461060484104, 21: -7.8857952163517675, 22: -10.268392986214433, 23: -8.665147785496703, 24: -9.233050763772013, 25: -8.304951520770791, 26: -7.0489865771593765, 27: -5.577439766222147, 28: -5.172747618813715, 29: -3.2520726958619472, 30: -1.2901611618726314, 31: -3.527082192997912, 32: -4.70845955030298, 33: -3.9765109025623238, 34: -3.886231055836541, 35: -2.5184940099633986, 36: 6.766947645687137, 37: -2.5634958965928316, 38: -4.938005211501922, 39: -10.149818838085771, 40: -11.846857579882572, 41: -12.138896361658485, 42: -8.791678800595722, 43: -8.78694939675911, 44: -7.78093221529871, 45: -6.850021409115055, 46: -4.891019073240479, 47: -2.0634296773864045, 48: -0.6395695518943755, 49: -2.7887442084286693, 50: -3.818604275441892, 51: -3.587068329278862, 52: -2.8804045971118897, 53: -1.6355986842433357, 54: 9.846723842807721, 55: -2.765284507132287, 56: -4.990956432167774, 57: -8.933684809576345, 58: -8.735591176647514, 59: -8.018966025544966, 60: -8.251491970213372, 61: -7.591719594359237, 62: -8.169659881166858, 63: -13.592664636171698, 64: -18.517523458456985, 65: -7.647396572993602, 66: -8.122981037851925, 67: -7.607787319678067, 68: -6.85029094445494, 69: -7.8268821327130365, 70: -3.584786591677161, 71: -7.455406192077973, 72: -12.796283502572146, 73: -14.108127281277586, 74: -9.354916969477486, 75: -11.387537567890853, 76: -9.621909492152557, 77: -7.324393429417677, 78: -5.3046964808341945, 79: -2.380092582080244, 80: 0.24948924158195362, 81: -2.3239789120665026, 82: -3.730042357127322, 
83: -3.438792347649683, 89: -5.062878214511315, 90: -11.02462566385297, 91: -12.265613551943261, 92: -13.855648206100362, 93: -14.933092020258243, 94: -15.282826131998245}" \
    --interaction_first="RealAgnosticResidualInteractionBlock" \
    --interaction="RealAgnosticResidualInteractionBlock" \
    --num_interactions=2 \
    --correlation=3 \
    --max_ell=3 \
    --r_max=6.0 \
    --max_L=2 \
    --num_channels=128 \
    --num_radial_basis=10 \
    --MLP_irreps="16x0e" \
    --scaling='rms_forces_scaling' \
    --lr=0.005 \
    --weight_decay=1e-8 \
    --ema \
    --ema_decay=0.995 \
    --scheduler_patience=5 \
    --batch_size=2 \
    --valid_batch_size=4 \
    --max_num_epochs=300 \
    --patience=30 \
    --amsgrad \
    --device="cuda" \
    --default_dtype="float64" \
    --seed=1 \
    --clip_grad=100 \
    --keep_checkpoints \
    --restart_latest \
    --save_cpu
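As a consistency check, --num_channels=128 and --max_L=2 should correspond to the hidden irreps printed for the large model in the first log (128x0e+128x1o+128x2e). A sketch of that mapping, assuming the hidden irreps are num_channels copies of each spherical-harmonic order l = 0..max_L with parity (-1)^l; this mirrors, but is not, MACE's own code:

```python
# Sketch: build a hidden-irreps string from num_channels and max_L.
# Assumption: one block of num_channels features per order l, with spherical-harmonic parity.


def hidden_irreps(num_channels: int, max_L: int) -> str:
    """Return e.g. '128x0e+128x1o+128x2e' for num_channels=128, max_L=2."""
    parts = []
    for l in range(max_L + 1):
        parity = "e" if l % 2 == 0 else "o"  # spherical harmonics have parity (-1)^l
        parts.append(f"{num_channels}x{l}{parity}")
    return "+".join(parts)


print(hidden_irreps(128, 2))  # 128x0e+128x1o+128x2e
```

Comparing this string with the "Hidden irreps" line in the log is a quick way to see whether the flags match the model before training starts.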

Note that I used the average atomic energies from the mace-mp training input (ACEsuit/mace-mp#1 (comment)).

I've also created a checkpoints folder containing the checkpoint downloaded from Hugging Face.
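Judging from the first run's log, checkpoints are looked up by a tag built from --name and --seed: the warning reads "Cannot find checkpoint with tag 'mace_fine_tunning_100_run-1' in 'checkpoints'", and the saved file is checkpoints/mace_fine_tunning_100_run-1_epoch-29.pt. The sketch below uses a naming pattern inferred from those two log lines, not from MACE's source, to show which filename --restart_latest would presumably look for with this command's --name and --seed=1. The Hugging Face filename used in the example is hypothetical:

```python
# Sketch: infer which checkpoint filename a restart would look for.
# Pattern inferred from the log: "<name>_run-<seed>_epoch-<epoch>.pt" in checkpoints_dir.
import re
from typing import Iterable, Optional


def checkpoint_tag(name: str, seed: int) -> str:
    """Tag as it appears in the 'Cannot find checkpoint with tag ...' warning."""
    return f"{name}_run-{seed}"


def latest_checkpoint(tag: str, filenames: Iterable[str]) -> Optional[str]:
    """Return the filename with the highest epoch for this tag, or None if none match."""
    pattern = re.compile(rf"^{re.escape(tag)}_epoch-(\d+)\.pt$")
    best, best_epoch = None, -1
    for fname in filenames:
        match = pattern.match(fname)
        if match and int(match.group(1)) > best_epoch:
            best, best_epoch = fname, int(match.group(1))
    return best


tag = checkpoint_tag("2024-01-07-mace-128-L2", seed=1)
print(tag)  # 2024-01-07-mace-128-L2_run-1
# Hypothetical downloaded name without the "_run-<seed>" part: not matched.
print(latest_checkpoint(tag, ["2024-01-07-mace-128-L2_epoch-199.pt"]))  # None
print(latest_checkpoint(tag, ["2024-01-07-mace-128-L2_run-1_epoch-199.pt"]))  # 2024-01-07-mace-128-L2_run-1_epoch-199.pt
```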

Here is the log with the error:

2024-02-19 16:15:27.752 INFO: MACE version: 0.3.4
2024-02-19 16:15:27.753 INFO: Configuration: Namespace(name='2024-01-07-mace-128-L2', seed=1, log_dir='logs', model_dir='.', checkpoints_dir='checkpoints', results_dir='results', downloads_dir='downloads', device='cuda', default_dtype='float64', log_level='INFO', error_table='PerAtomMAE', model='MACE', r_max=6.0, radial_type='bessel', num_radial_basis=10, num_cutoff_basis=5, pair_repulsion=False, distance_transform=False, interaction='RealAgnosticResidualInteractionBlock', interaction_first='RealAgnosticResidualInteractionBlock', max_ell=3, correlation=3, num_interactions=2, MLP_irreps='16x0e', radial_MLP='[64, 64, 64]', hidden_irreps='128x0e + 128x1o', num_channels=128, max_L=2, gate='silu', scaling='rms_forces_scaling', avg_num_neighbors=1, compute_avg_num_neighbors=True, compute_stress=True, compute_forces=True, train_file='/home/bruno.focassio/mace_large_model/uip/fine_tune_2024_example/training_data.xyz', valid_file=None, valid_fraction=0.05, test_file='/home/bruno.focassio/mace_large_model/uip/fine_tune_2024_example/test_data.xyz', E0s='{1: -3.667168021358939, 2: -1.3320953124042916, 3: -3.482100566595956, 4: -4.736697230897597, 5: -7.724935420523256, 6: -8.405573550273285, 7: -7.360100452662763, 8: -7.28459863421322, 9: -4.896490881731322, 10: 1.3917755836700962e-12, 11: -2.7593613569762425, 12: -2.814047612069227, 13: -4.846881245288104, 14: -7.694793133351899, 15: -6.9632957911820235, 16: -4.672630400190884, 17: -2.8116892814008096, 18: -0.06259504416367478, 19: -2.6176454856894793, 20: -5.390461060484104, 21: -7.8857952163517675, 22: -10.268392986214433, 23: -8.665147785496703, 24: -9.233050763772013, 25: -8.304951520770791, 26: -7.0489865771593765, 27: -5.577439766222147, 28: -5.172747618813715, 29: -3.2520726958619472, 30: -1.2901611618726314, 31: -3.527082192997912, 32: -4.70845955030298, 33: -3.9765109025623238, 34: -3.886231055836541, 35: -2.5184940099633986, 36: 6.766947645687137, 37: -2.5634958965928316, 38: -4.938005211501922, 39: 
-10.149818838085771, 40: -11.846857579882572, 41: -12.138896361658485, 42: -8.791678800595722, 43: -8.78694939675911, 44: -7.78093221529871, 45: -6.850021409115055, 46: -4.891019073240479, 47: -2.0634296773864045, 48: -0.6395695518943755, 49: -2.7887442084286693, 50: -3.818604275441892, 51: -3.587068329278862, 52: -2.8804045971118897, 53: -1.6355986842433357, 54: 9.846723842807721, 55: -2.765284507132287, 56: -4.990956432167774, 57: -8.933684809576345, 58: -8.735591176647514, 59: -8.018966025544966, 60: -8.251491970213372, 61: -7.591719594359237, 62: -8.169659881166858, 63: -13.592664636171698, 64: -18.517523458456985, 65: -7.647396572993602, 66: -8.122981037851925, 67: -7.607787319678067, 68: -6.85029094445494, 69: -7.8268821327130365, 70: -3.584786591677161, 71: -7.455406192077973, 72: -12.796283502572146, 73: -14.108127281277586, 74: -9.354916969477486, 75: -11.387537567890853, 76: -9.621909492152557, 77: -7.324393429417677, 78: -5.3046964808341945, 79: -2.380092582080244, 80: 0.24948924158195362, 81: -2.3239789120665026, 82: -3.730042357127322, 83: -3.438792347649683, 89: -5.062878214511315, 90: -11.02462566385297, 91: -12.265613551943261, 92: -13.855648206100362, 93: -14.933092020258243, 94: -15.282826131998245}', energy_key='energy', forces_key='forces', virials_key='virials', stress_key='stress', dipole_key='dipole', charges_key='charges', loss='universal', forces_weight=10.0, swa_forces_weight=100.0, energy_weight=1.0, swa_energy_weight=1000.0, virials_weight=1.0, swa_virials_weight=10.0, stress_weight=100.0, swa_stress_weight=10.0, dipole_weight=1.0, swa_dipole_weight=1.0, config_type_weights='{"Default":1.0}', huber_delta=0.01, optimizer='adam', batch_size=2, valid_batch_size=4, lr=0.005, swa_lr=0.001, weight_decay=1e-08, amsgrad=True, scheduler='ReduceLROnPlateau', lr_factor=0.8, scheduler_patience=5, lr_scheduler_gamma=0.9993, swa=False, start_swa=None, ema=True, ema_decay=0.995, max_num_epochs=300, patience=30, 
foundation_model='2024-01-07-mace-128-L2.model', foundation_model_readout=True, eval_interval=1, keep_checkpoints=True, restart_latest=True, save_cpu=True, clip_grad=100.0, wandb=False, wandb_project='', wandb_entity='', wandb_name='', wandb_log_hypers=['num_channels', 'max_L', 'correlation', 'lr', 'swa_lr', 'weight_decay', 'batch_size', 'max_num_epochs', 'start_swa', 'energy_weight', 'forces_weight'])
2024-02-19 16:15:27.814 INFO: CUDA version: 11.3, CUDA device: 0
2024-02-19 16:15:27.920 INFO: Loaded 89 training configurations from '/home/bruno.focassio/mace_large_model/uip/fine_tune_2024_example/training_data.xyz'
2024-02-19 16:15:27.921 INFO: Using random 5.0% of training set for validation
2024-02-19 16:15:27.957 INFO: Loaded 89 test configurations from '/home/bruno.focassio/mace_large_model/uip/fine_tune_2024_example/test_data.xyz'
2024-02-19 16:15:27.957 INFO: Total number of configurations: train=85, valid=4, tests=[Default: 89]
2024-02-19 16:15:27.958 INFO: AtomicNumberTable: (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 89, 90, 91, 92, 93, 94)
2024-02-19 16:15:27.958 INFO: Atomic Energies not in training file, using command line argument E0s
2024-02-19 16:15:27.959 INFO: Atomic energies: [-3.667168021358939, -1.3320953124042916, -3.482100566595956, -4.736697230897597, -7.724935420523256, -8.405573550273285, -7.360100452662763, -7.28459863421322, -4.896490881731322, 1.3917755836700962e-12, -2.7593613569762425, -2.814047612069227, -4.846881245288104, -7.694793133351899, -6.9632957911820235, -4.672630400190884, -2.8116892814008096, -0.06259504416367478, -2.6176454856894793, -5.390461060484104, -7.8857952163517675, -10.268392986214433, -8.665147785496703, -9.233050763772013, -8.304951520770791, -7.0489865771593765, -5.577439766222147, -5.172747618813715, -3.2520726958619472, -1.2901611618726314, -3.527082192997912, -4.70845955030298, -3.9765109025623238, -3.886231055836541, -2.5184940099633986, 6.766947645687137, -2.5634958965928316, -4.938005211501922, -10.149818838085771, -11.846857579882572, -12.138896361658485, -8.791678800595722, -8.78694939675911, -7.78093221529871, -6.850021409115055, -4.891019073240479, -2.0634296773864045, -0.6395695518943755, -2.7887442084286693, -3.818604275441892, -3.587068329278862, -2.8804045971118897, -1.6355986842433357, 9.846723842807721, -2.765284507132287, -4.990956432167774, -8.933684809576345, -8.735591176647514, -8.018966025544966, -8.251491970213372, -7.591719594359237, -8.169659881166858, -13.592664636171698, -18.517523458456985, -7.647396572993602, -8.122981037851925, -7.607787319678067, -6.85029094445494, -7.8268821327130365, -3.584786591677161, -7.455406192077973, -12.796283502572146, -14.108127281277586, -9.354916969477486, -11.387537567890853, -9.621909492152557, -7.324393429417677, -5.3046964808341945, -2.380092582080244, 0.24948924158195362, -2.3239789120665026, -3.730042357127322, -3.438792347649683, -5.062878214511315, -11.02462566385297, -12.265613551943261, -13.855648206100362, -14.933092020258243, -15.282826131998245]
2024-02-19 16:15:28.078 INFO: UniversalLoss(energy_weight=1.000, forces_weight=10.000, stress_weight=100.000)
2024-02-19 16:15:28.127 INFO: Average number of neighbors: nan
2024-02-19 16:15:28.128 INFO: Selected the following outputs: {'energy': True, 'forces': True, 'virials': False, 'stress': True, 'dipoles': False}
2024-02-19 16:15:28.128 INFO: Building model
2024-02-19 16:15:28.130 INFO: Hidden irreps: 128x0e+128x1o+128x2e
2024-02-19 16:15:28.289 WARNING: Standard deviation of the scaling is zero, Changing to no scaling
2024-02-19 16:15:31.924 INFO: Using foundation model 2024-01-07-mace-128-L2.model as initial checkpoint.
Traceback (most recent call last):
  File "/home/bruno.focassio/codes/mace-foundations/scripts/run_train.py", line 6, in <module>
    main()
  File "/home/bruno.focassio/anaconda3/envs/mace_foundations/lib/python3.10/site-packages/mace/cli/run_train.py", line 398, in main
    model = load_foundations(
  File "/home/bruno.focassio/anaconda3/envs/mace_foundations/lib/python3.10/site-packages/mace/tools/utils.py", line 173, in load_foundations
    indices_weights = [z_table.z_to_index(z) for z in new_z_table.zs]
  File "/home/bruno.focassio/anaconda3/envs/mace_foundations/lib/python3.10/site-packages/mace/tools/utils.py", line 173, in <listcomp>
    indices_weights = [z_table.z_to_index(z) for z in new_z_table.zs]
  File "/home/bruno.focassio/anaconda3/envs/mace_foundations/lib/python3.10/site-packages/mace/tools/utils.py", line 89, in z_to_index
    return self.zs.index(atomic_number)
ValueError: 1 is not in list

Attached are the dummy train and test files. The energies were predicted with the MACE calculator using this model.

dummy_data_files.zip
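The traceback ends in `z_to_index`, which calls `self.zs.index(atomic_number)` on the foundation model's table of atomic numbers. Python's `list.index` raises `ValueError` whenever an element of the new table is missing from that list. The following is a minimal pre-check, assuming both tables are available as plain lists of atomic numbers. The helper names `missing_elements` and `index_map` are hypothetical and not part of MACE.

```python
from typing import Iterable, List


def missing_elements(foundation_zs: Iterable[int], new_zs: Iterable[int]) -> List[int]:
    """Return atomic numbers present in new_zs but absent from foundation_zs.

    Hypothetical helper: any z returned here would make a lookup such as
    list(foundation_zs).index(z) raise ValueError, as in the traceback.
    """
    known = set(foundation_zs)
    return sorted(z for z in set(new_zs) if z not in known)


def index_map(foundation_zs: List[int], new_zs: List[int]) -> List[int]:
    """Map each new element to its row in the foundation table, failing early with a clear message."""
    missing = missing_elements(foundation_zs, new_zs)
    if missing:
        raise ValueError(f"elements {missing} are not in the foundation model's table")
    return [foundation_zs.index(z) for z in new_zs]


if __name__ == "__main__":
    # A foundation table lacking hydrogen (Z=1), like a mismatched model file.
    print(missing_elements([3, 8, 26], [1, 8, 26]))  # [1]
    print(index_map([1, 3, 8, 26], [1, 8, 26]))  # [0, 2, 3]
```

In recent MACE versions a loaded model appears to expose its table as an `atomic_numbers` buffer, which could feed `foundation_zs`. Treat that attribute name as an assumption and verify it against your installed version.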

Hi again @ilyes319

I realized the error above was caused by a model file that had been overwritten by one with a different table of atomic numbers. After downloading 2024-01-07-mace-128-L2.model again, the run gets past that error.

Now the trouble is, as I mentioned before, a shape mismatch. Can you please take a look?

python /home/bruno.focassio/codes/mace-foundations/scripts/run_train.py \
    --name="2024-01-07-mace-128-L2" \
    --foundation_model="2024-01-07-mace-128-L2.model" \
    --train_file="training_data.xyz" \
    --test_file="test_data.xyz" \
    --valid_fraction=0.05 \
    --loss="universal" \
    --energy_weight=1 \
    --forces_weight=10 \
    --compute_stress=True \
    --stress_weight=100 \
    --stress_key='stress' \
    --eval_interval=1 \
    --error_table='PerAtomMAE' \
    --E0s="{1: -3.667168021358939, 2: -1.3320953124042916, 3: -3.482100566595956, 4: -4.736697230897597, 5: -7.724935420523256, 6: -8.405573550273285, 7: -7.360100452662763, 8: -7.28459863421322, 9: -4.896490881731322, 10: 1.3917755836700962e-12, 11: -2.7593613569762425, 12: -2.814047612069227, 13: -4.846881245288104, 14: -7.694793133351899, 15: -6.9632957911820235, 16: -4.672630400190884, 17: -2.8116892814008096, 18: -0.06259504416367478, 19: -2.6176454856894793, 20: -5.390461060484104, 21: -7.8857952163517675, 22: -10.268392986214433, 23: -8.665147785496703, 24: -9.233050763772013, 25: -8.304951520770791, 26: -7.0489865771593765, 27: -5.577439766222147, 28: -5.172747618813715, 29: -3.2520726958619472, 30: -1.2901611618726314, 31: -3.527082192997912, 32: -4.70845955030298, 33: -3.9765109025623238, 34: -3.886231055836541, 35: -2.5184940099633986, 36: 6.766947645687137, 37: -2.5634958965928316, 38: -4.938005211501922, 39: -10.149818838085771, 40: -11.846857579882572, 41: -12.138896361658485, 42: -8.791678800595722, 43: -8.78694939675911, 44: -7.78093221529871, 45: -6.850021409115055, 46: -4.891019073240479, 47: -2.0634296773864045, 48: -0.6395695518943755, 49: -2.7887442084286693, 50: -3.818604275441892, 51: -3.587068329278862, 52: -2.8804045971118897, 53: -1.6355986842433357, 54: 9.846723842807721, 55: -2.765284507132287, 56: -4.990956432167774, 57: -8.933684809576345, 58: -8.735591176647514, 59: -8.018966025544966, 60: -8.251491970213372, 61: -7.591719594359237, 62: -8.169659881166858, 63: -13.592664636171698, 64: -18.517523458456985, 65: -7.647396572993602, 66: -8.122981037851925, 67: -7.607787319678067, 68: -6.85029094445494, 69: -7.8268821327130365, 70: -3.584786591677161, 71: -7.455406192077973, 72: -12.796283502572146, 73: -14.108127281277586, 74: -9.354916969477486, 75: -11.387537567890853, 76: -9.621909492152557, 77: -7.324393429417677, 78: -5.3046964808341945, 79: -2.380092582080244, 80: 0.24948924158195362, 81: -2.3239789120665026, 82: -3.730042357127322, 83: -3.438792347649683, 89: -5.062878214511315, 90: -11.02462566385297, 91: -12.265613551943261, 92: -13.855648206100362, 93: -14.933092020258243, 94: -15.282826131998245}" \
    --interaction_first="RealAgnosticResidualInteractionBlock" \
    --interaction="RealAgnosticResidualInteractionBlock" \
    --num_interactions=2 \
    --correlation=3 \
    --max_ell=3 \
    --r_max=6.0 \
    --max_L=2 \
    --num_channels=128 \
    --num_radial_basis=10 \
    --MLP_irreps="16x0e" \
    --scaling='rms_forces_scaling' \
    --lr=0.005 \
    --weight_decay=1e-8 \
    --ema \
    --ema_decay=0.995 \
    --scheduler_patience=5 \
    --batch_size=2 \
    --valid_batch_size=4 \
    --max_num_epochs=300 \
    --patience=30 \
    --amsgrad \
    --device="cuda" \
    --default_dtype="float64" \
    --seed=1 \
    --clip_grad=100 \
    --keep_checkpoints \
    --restart_latest \
    --save_cpu
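With `--num_channels=128` and `--max_L=2`, the log below reports `Hidden irreps: 128x0e+128x1o+128x2e`, even though the Namespace still lists the default `hidden_irreps='128x0e + 128x1o'`. The following is a minimal sketch of that mapping. It assumes the hidden irreps consist of `num_channels` copies of each rank `l = 0..max_L`, with even parity for even `l` and odd parity for odd `l`. This is consistent with the logged string but is not copied from MACE's source.

```python
def hidden_irreps(num_channels: int, max_L: int) -> str:
    """Build an e3nn-style irreps string such as '128x0e+128x1o+128x2e'.

    Assumption: every rank l in 0..max_L gets num_channels copies, with
    parity 'e' for even l and 'o' for odd l (spherical-harmonic parity).
    """
    parts = []
    for l in range(max_L + 1):
        parity = "e" if l % 2 == 0 else "o"
        parts.append(f"{num_channels}x{l}{parity}")
    return "+".join(parts)


print(hidden_irreps(128, 2))  # 128x0e+128x1o+128x2e
print(hidden_irreps(128, 0))  # 128x0e
```

Because these settings fix the shapes of many weight tensors, they are among the flags that have to agree with the model whose weights are being reused.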

And the log:

2024-02-20 08:42:06.620 INFO: MACE version: 0.3.4
2024-02-20 08:42:06.621 INFO: Configuration: Namespace(name='2024-01-07-mace-128-L2', seed=1, log_dir='logs', model_dir='.', checkpoints_dir='checkpoints', results_dir='results', downloads_dir='downloads', device='cuda', default_dtype='float64', log_level='INFO', error_table='PerAtomMAE', model='MACE', r_max=6.0, radial_type='bessel', num_radial_basis=10, num_cutoff_basis=5, pair_repulsion=False, distance_transform=False, interaction='RealAgnosticResidualInteractionBlock', interaction_first='RealAgnosticResidualInteractionBlock', max_ell=3, correlation=3, num_interactions=2, MLP_irreps='16x0e', radial_MLP='[64, 64, 64]', hidden_irreps='128x0e + 128x1o', num_channels=128, max_L=2, gate='silu', scaling='rms_forces_scaling', avg_num_neighbors=1, compute_avg_num_neighbors=True, compute_stress=True, compute_forces=True, train_file='/home/bruno.focassio/mace_large_model/uip/train_2024/training_data.xyz', valid_file=None, valid_fraction=0.05, test_file='/home/bruno.focassio/mace_large_model/uip/train_2024/test_data.xyz', E0s='{1: -3.667168021358939, 2: -1.3320953124042916, 3: -3.482100566595956, 4: -4.736697230897597, 5: -7.724935420523256, 6: -8.405573550273285, 7: -7.360100452662763, 8: -7.28459863421322, 9: -4.896490881731322, 10: 1.3917755836700962e-12, 11: -2.7593613569762425, 12: -2.814047612069227, 13: -4.846881245288104, 14: -7.694793133351899, 15: -6.9632957911820235, 16: -4.672630400190884, 17: -2.8116892814008096, 18: -0.06259504416367478, 19: -2.6176454856894793, 20: -5.390461060484104, 21: -7.8857952163517675, 22: -10.268392986214433, 23: -8.665147785496703, 24: -9.233050763772013, 25: -8.304951520770791, 26: -7.0489865771593765, 27: -5.577439766222147, 28: -5.172747618813715, 29: -3.2520726958619472, 30: -1.2901611618726314, 31: -3.527082192997912, 32: -4.70845955030298, 33: -3.9765109025623238, 34: -3.886231055836541, 35: -2.5184940099633986, 36: 6.766947645687137, 37: -2.5634958965928316, 38: -4.938005211501922, 39: -10.149818838085771, 40: 
-11.846857579882572, 41: -12.138896361658485, 42: -8.791678800595722, 43: -8.78694939675911, 44: -7.78093221529871, 45: -6.850021409115055, 46: -4.891019073240479, 47: -2.0634296773864045, 48: -0.6395695518943755, 49: -2.7887442084286693, 50: -3.818604275441892, 51: -3.587068329278862, 52: -2.8804045971118897, 53: -1.6355986842433357, 54: 9.846723842807721, 55: -2.765284507132287, 56: -4.990956432167774, 57: -8.933684809576345, 58: -8.735591176647514, 59: -8.018966025544966, 60: -8.251491970213372, 61: -7.591719594359237, 62: -8.169659881166858, 63: -13.592664636171698, 64: -18.517523458456985, 65: -7.647396572993602, 66: -8.122981037851925, 67: -7.607787319678067, 68: -6.85029094445494, 69: -7.8268821327130365, 70: -3.584786591677161, 71: -7.455406192077973, 72: -12.796283502572146, 73: -14.108127281277586, 74: -9.354916969477486, 75: -11.387537567890853, 76: -9.621909492152557, 77: -7.324393429417677, 78: -5.3046964808341945, 79: -2.380092582080244, 80: 0.24948924158195362, 81: -2.3239789120665026, 82: -3.730042357127322, 83: -3.438792347649683, 89: -5.062878214511315, 90: -11.02462566385297, 91: -12.265613551943261, 92: -13.855648206100362, 93: -14.933092020258243, 94: -15.282826131998245}', energy_key='energy', forces_key='forces', virials_key='virials', stress_key='stress', dipole_key='dipole', charges_key='charges', loss='universal', forces_weight=10.0, swa_forces_weight=100.0, energy_weight=1.0, swa_energy_weight=1000.0, virials_weight=1.0, swa_virials_weight=10.0, stress_weight=100.0, swa_stress_weight=10.0, dipole_weight=1.0, swa_dipole_weight=1.0, config_type_weights='{"Default":1.0}', huber_delta=0.01, optimizer='adam', batch_size=2, valid_batch_size=4, lr=0.005, swa_lr=0.001, weight_decay=1e-08, amsgrad=True, scheduler='ReduceLROnPlateau', lr_factor=0.8, scheduler_patience=5, lr_scheduler_gamma=0.9993, swa=False, start_swa=None, ema=True, ema_decay=0.995, max_num_epochs=300, patience=30, foundation_model='2024-01-07-mace-128-L2.model', 
foundation_model_readout=True, eval_interval=1, keep_checkpoints=True, restart_latest=True, save_cpu=True, clip_grad=100.0, wandb=False, wandb_project='', wandb_entity='', wandb_name='', wandb_log_hypers=['num_channels', 'max_L', 'correlation', 'lr', 'swa_lr', 'weight_decay', 'batch_size', 'max_num_epochs', 'start_swa', 'energy_weight', 'forces_weight'])
2024-02-20 08:42:06.654 INFO: CUDA version: 11.3, CUDA device: 0
2024-02-20 08:42:07.619 INFO: Loaded 1491 training configurations from '/home/bruno.focassio/mace_large_model/uip/train_2024/training_data.xyz'
2024-02-20 08:42:07.620 INFO: Using random 5.0% of training set for validation
2024-02-20 08:42:07.739 INFO: Loaded 253 test configurations from '/home/bruno.focassio/mace_large_model/uip/train_2024/test_data.xyz'
2024-02-20 08:42:07.740 INFO: Total number of configurations: train=1417, valid=74, tests=[Default: 253]
2024-02-20 08:42:07.756 INFO: AtomicNumberTable: (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 89, 90, 91, 92, 93, 94)
2024-02-20 08:42:07.756 INFO: Atomic Energies not in training file, using command line argument E0s
2024-02-20 08:42:07.757 INFO: Atomic energies: [-3.667168021358939, -1.3320953124042916, -3.482100566595956, -4.736697230897597, -7.724935420523256, -8.405573550273285, -7.360100452662763, -7.28459863421322, -4.896490881731322, 1.3917755836700962e-12, -2.7593613569762425, -2.814047612069227, -4.846881245288104, -7.694793133351899, -6.9632957911820235, -4.672630400190884, -2.8116892814008096, -0.06259504416367478, -2.6176454856894793, -5.390461060484104, -7.8857952163517675, -10.268392986214433, -8.665147785496703, -9.233050763772013, -8.304951520770791, -7.0489865771593765, -5.577439766222147, -5.172747618813715, -3.2520726958619472, -1.2901611618726314, -3.527082192997912, -4.70845955030298, -3.9765109025623238, -3.886231055836541, -2.5184940099633986, 6.766947645687137, -2.5634958965928316, -4.938005211501922, -10.149818838085771, -11.846857579882572, -12.138896361658485, -8.791678800595722, -8.78694939675911, -7.78093221529871, -6.850021409115055, -4.891019073240479, -2.0634296773864045, -0.6395695518943755, -2.7887442084286693, -3.818604275441892, -3.587068329278862, -2.8804045971118897, -1.6355986842433357, 9.846723842807721, -2.765284507132287, -4.990956432167774, -8.933684809576345, -8.735591176647514, -8.018966025544966, -8.251491970213372, -7.591719594359237, -8.169659881166858, -13.592664636171698, -18.517523458456985, -7.647396572993602, -8.122981037851925, -7.607787319678067, -6.85029094445494, -7.8268821327130365, -3.584786591677161, -7.455406192077973, -12.796283502572146, -14.108127281277586, -9.354916969477486, -11.387537567890853, -9.621909492152557, -7.324393429417677, -5.3046964808341945, -2.380092582080244, 0.24948924158195362, -2.3239789120665026, -3.730042357127322, -3.438792347649683, -5.062878214511315, -11.02462566385297, -12.265613551943261, -13.855648206100362, -14.933092020258243, -15.282826131998245]
2024-02-20 08:42:14.706 INFO: UniversalLoss(energy_weight=1.000, forces_weight=10.000, stress_weight=100.000)
2024-02-20 08:42:15.579 INFO: Average number of neighbors: 62.46425785800387
2024-02-20 08:42:15.580 INFO: Selected the following outputs: {'energy': True, 'forces': True, 'virials': False, 'stress': True, 'dipoles': False}
2024-02-20 08:42:15.580 INFO: Building model
2024-02-20 08:42:15.588 INFO: Hidden irreps: 128x0e+128x1o+128x2e
2024-02-20 08:42:21.455 INFO: Using foundation model 2024-01-07-mace-128-L2.model as initial checkpoint.
2024-02-20 08:42:21.523 WARNING: No SWA checkpoint found, while SWA is enabled. Compare the swa_start parameter and the latest checkpoint.
2024-02-20 08:42:21.524 INFO: Loading checkpoint: checkpoints/2024-01-07-mace-128-L2_run-1_epoch-199.pt
Traceback (most recent call last):
  File "/home/bruno.focassio/anaconda3/envs/mace_foundations/lib/python3.10/site-packages/mace/cli/run_train.py", line 525, in main
    opt_start_epoch = checkpoint_handler.load_latest(
  File "/home/bruno.focassio/anaconda3/envs/mace_foundations/lib/python3.10/site-packages/mace/tools/checkpoint.py", line 210, in load_latest
    result = self.io.load_latest(swa=swa, device=device)
  File "/home/bruno.focassio/anaconda3/envs/mace_foundations/lib/python3.10/site-packages/mace/tools/checkpoint.py", line 171, in load_latest
    path = self._get_latest_checkpoint_path(swa=swa)
  File "/home/bruno.focassio/anaconda3/envs/mace_foundations/lib/python3.10/site-packages/mace/tools/checkpoint.py", line 152, in _get_latest_checkpoint_path
    return latest_checkpoint_info.path
UnboundLocalError: local variable 'latest_checkpoint_info' referenced before assignment

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/bruno.focassio/codes/mace-foundations/scripts/run_train.py", line 6, in <module>
    main()
  File "/home/bruno.focassio/anaconda3/envs/mace_foundations/lib/python3.10/site-packages/mace/cli/run_train.py", line 531, in main
    opt_start_epoch = checkpoint_handler.load_latest(
  File "/home/bruno.focassio/anaconda3/envs/mace_foundations/lib/python3.10/site-packages/mace/tools/checkpoint.py", line 215, in load_latest
    self.builder.load_checkpoint(state=state, checkpoint=checkpoint, strict=strict)
  File "/home/bruno.focassio/anaconda3/envs/mace_foundations/lib/python3.10/site-packages/mace/tools/checkpoint.py", line 40, in load_checkpoint
    state.model.load_state_dict(checkpoint["model"], strict=strict)  # type: ignore
  File "/home/bruno.focassio/anaconda3/envs/mace_foundations/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1604, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for ScaleShiftMACE:
	size mismatch for interactions.0.skip_tp.weight: copying a param with shape torch.Size([1458176]) from checkpoint, the shape in current model is torch.Size([5832704]).
	size mismatch for interactions.0.skip_tp.output_mask: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([2048]).
	size mismatch for interactions.1.linear_up.weight: copying a param with shape torch.Size([16384]) from checkpoint, the shape in current model is torch.Size([49152]).
	size mismatch for interactions.1.linear_up.output_mask: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([1152]).
	size mismatch for interactions.1.conv_tp.output_mask: copying a param with shape torch.Size([2048]) from checkpoint, the shape in current model is torch.Size([9088]).
	size mismatch for interactions.1.conv_tp_weights.layer3.weight: copying a param with shape torch.Size([64, 512]) from checkpoint, the shape in current model is torch.Size([64, 2176]).
	size mismatch for interactions.1.linear.weight: copying a param with shape torch.Size([65536]) from checkpoint, the shape in current model is torch.Size([278528]).
	size mismatch for products.0.linear.weight: copying a param with shape torch.Size([16384]) from checkpoint, the shape in current model is torch.Size([49152]).
	size mismatch for products.0.linear.output_mask: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([1152]).
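The mismatched `output_mask` sizes can be read as irreps dimensions, where each `MxLp` block contributes M·(2L+1). With the hidden irreps logged for this run, 128x0e+128x1o+128x2e, that gives 128 × (1 + 3 + 5) = 1152. This matches the current model's size for `interactions.1.linear_up.output_mask` and `products.0.linear.output_mask`. The checkpoint's 128, by contrast, corresponds to scalars only (128x0e). This suggests, though the thread does not confirm it, that the stale checkpoint was written by a run with different architecture settings. The sketch below computes irreps dimensions and lists shape disagreements between two name-to-shape mappings. The helper names are hypothetical. With PyTorch, such a mapping could be built as `{k: tuple(v.shape) for k, v in state_dict.items()}`.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def irreps_dim(irreps: str) -> int:
    """Dimension of an e3nn-style irreps string: each 'MxLp' term adds M * (2L + 1)."""
    total = 0
    for term in irreps.replace(" ", "").split("+"):
        mul, ir = term.split("x")
        total += int(mul) * (2 * int(ir[:-1]) + 1)  # ir[:-1] strips the parity letter
    return total


def shape_mismatches(ckpt: Dict[str, Shape], model: Dict[str, Shape]) -> List[str]:
    """Describe parameters present in both mappings whose shapes differ."""
    return [
        f"{name}: checkpoint {ckpt[name]} vs model {model[name]}"
        for name in sorted(ckpt.keys() & model.keys())
        if ckpt[name] != model[name]
    ]


print(irreps_dim("128x0e+128x1o+128x2e"))  # 1152
print(irreps_dim("128x0e"))  # 128

# Shapes from the error above, plus one illustrative entry whose shape agrees.
ckpt = {"products.0.linear.output_mask": (128,), "illustrative.weight": (64,)}
model = {"products.0.linear.output_mask": (1152,), "illustrative.weight": (64,)}
print(shape_mismatches(ckpt, model))
```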

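I see that a checkpoint is being loaded here: INFO: Loading checkpoint: checkpoints/2024-01-07-mace-128-L2_run-1_epoch-199.pt
Because --restart_latest is set, the script resumes from the latest checkpoint saved under the same run name. Can you please change your run name so it does not load this checkpoint?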
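In the log, the run name `2024-01-07-mace-128-L2` and `seed=1` line up with the file `checkpoints/2024-01-07-mace-128-L2_run-1_epoch-199.pt`. The sketch below assumes the `<name>_run-<seed>_epoch-<N>.pt` pattern, inferred from this single filename rather than taken from MACE's code. It shows why a renamed run would find nothing to resume from. The `latest_checkpoint` helper is hypothetical.

```python
import re
from typing import Iterable, Optional


def latest_checkpoint(files: Iterable[str], name: str, seed: int) -> Optional[str]:
    """Pick the highest-epoch file matching '<name>_run-<seed>_epoch-<N>.pt'.

    The naming pattern is inferred from the log in this thread; the helper is hypothetical.
    """
    pattern = re.compile(rf"^{re.escape(name)}_run-{seed}_epoch-(\d+)\.pt$")
    best_epoch, best_file = -1, None
    for f in files:
        m = pattern.match(f)
        if m and int(m.group(1)) > best_epoch:
            best_epoch, best_file = int(m.group(1)), f
    return best_file


files = [
    "2024-01-07-mace-128-L2_run-1_epoch-199.pt",
    "2024-01-07-mace-128-L2_run-1_epoch-10.pt",
]
print(latest_checkpoint(files, "2024-01-07-mace-128-L2", 1))  # ..._epoch-199.pt
print(latest_checkpoint(files, "mace-128-L2-finetune", 1))  # None: a new name matches nothing
```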

Indeed, not loading the .pt checkpoint works. However, I was actually trying to load the checkpoint available on Hugging Face: https://huggingface.co/cyrusyc/mace-universal/tree/main/pretrained
If that checkpoint was generated while training the model, shouldn't it be compatible for continued training?

However, if you do that, you will need to pass the full set of hyperparameters of that model.

Is there a technical reason why it has to be this way? Could you find a way to save the hyperparameters with the model, so that this kind of problem happens less often?

It would be ideal to save all the hyperparameters in a kwargs dictionary and parse them flexibly, as in MACE(**kwargs). However, that seems hard with the current implementation, and the training CLI and argument parser are still evolving.
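The idea of keeping hyperparameters next to the model could be sketched as a JSON sidecar. The sidecar is written when the model is saved and compared against the command-line values before fine-tuning. This is a hypothetical illustration, not an existing MACE feature. The keys mirror argument names from the logged Namespace, the file name `model_hypers.json` is made up, and the mismatching values in the example are the defaults seen in the first run's log (`r_max=5.0`, `num_radial_basis=8`).

```python
import json
import tempfile
from pathlib import Path
from typing import Any, Dict

# Architecture-related argument names as they appear in the logged Namespace.
ARCH_KEYS = ("r_max", "num_radial_basis", "num_interactions", "correlation",
             "max_ell", "num_channels", "max_L", "MLP_irreps", "interaction")


def save_hypers(path: Path, args: Dict[str, Any]) -> None:
    """Write the architecture subset of args as JSON (hypothetical sidecar file)."""
    path.write_text(json.dumps({k: args[k] for k in ARCH_KEYS}, indent=2))


def hyper_mismatches(path: Path, args: Dict[str, Any]) -> Dict[str, tuple]:
    """Return {key: (saved, requested)} for every saved setting that differs."""
    saved = json.loads(path.read_text())
    return {k: (saved[k], args.get(k)) for k in saved if saved[k] != args.get(k)}


if __name__ == "__main__":
    trained = {"r_max": 6.0, "num_radial_basis": 10, "num_interactions": 2,
               "correlation": 3, "max_ell": 3, "num_channels": 128, "max_L": 2,
               "MLP_irreps": "16x0e",
               "interaction": "RealAgnosticResidualInteractionBlock"}
    with tempfile.TemporaryDirectory() as tmp:
        sidecar = Path(tmp) / "model_hypers.json"
        save_hypers(sidecar, trained)
        requested = dict(trained, r_max=5.0, num_radial_basis=8)  # forgot to override defaults
        print(hyper_mismatches(sidecar, requested))
        # {'r_max': (6.0, 5.0), 'num_radial_basis': (10, 8)}
```

A JSON file keeps the check independent of the model's pickled class, so it would still be readable if the training CLI's arguments change.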

I am curious: what is the function of --foundation_model_readout? Does it greatly influence the fine-tuning result?