dgcnz / relaxed-equivariance-dynamics

Code for "Effect of equivariance on training dynamics"

[Meta Issue] Wang 2022 Extension: Cross evaluation by changing model equivariance (alpha) and dataset equivariance

dgcnz opened this issue

Description

This issue concerns the extension of Figure 4 of Wang 2022, which consists of fixing a dataset's equivariance level (levels = [full, partial]) and training a relaxed equivariant model with k different values of alpha.

This gives a total of 2 * k runs per model, where 2 is the number of datasets (one per equivariance level) and k is the number of alpha values tested.
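As a quick sanity check on the run count, the per-model grid can be enumerated as below. This is a minimal sketch: the values in ALPHAS are hypothetical placeholders, since the issue does not fix k or which alpha values to test.

```python
from itertools import product

# Dataset equivariance levels listed in this issue
EQUIVARIANCE_LEVELS = ["full", "partial"]
# Hypothetical alpha values (k = 4 here); replace with the values actually tested
ALPHAS = [0.0, 0.01, 0.1, 1.0]


def runs_for_model(model, levels=EQUIVARIANCE_LEVELS, alphas=ALPHAS):
    """Return every (model, equivariance_level, alpha) run for one model."""
    return [(model, level, alpha) for level, alpha in product(levels, alphas)]


runs = runs_for_model("rsteer")
# One run per (dataset, alpha) pair: 2 * k runs per model
assert len(runs) == len(EQUIVARIANCE_LEVELS) * len(ALPHAS)
```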

Configs

SmokePlume configs:

  • Figure 4 Equivariance Test: configs/data/wang2022/equivariance_test.yaml (✅) (see Q1)

Model default configs (DO NOT MODIFY THESE FILES; other experiment files rely on the defaults set here):

  • ConvNet: configs/model/wang2022/convnet.yaml (✅)
  • RGroup: configs/model/wang2022/rgroup.yaml (❓)
  • RSteer: configs/model/wang2022/rsteer.yaml (✅)

Experiment configs:

  • ConvNet: configs/experiment/wang2022/equivariance_test/convnet.yaml (✅)
  • RGroup: configs/experiment/wang2022/equivariance_test/rgroup.yaml (❓)
  • RSteer: configs/experiment/wang2022/equivariance_test/rsteer.yaml (✅)

These configs must at least be tested with trainer.fast_dev_run to make sure the model processes data correctly. A fast_dev_run does not exercise model checkpointing or early stopping, so we will need to add separate tests for those. Examples can be found in the Makefile target test_wang2022_figure_4, which you can run with make test_wang2022_figure_4.

Example testing command:

python -m src.train experiment=wang2022/equivariance_test/rgroup +trainer.fast_dev_run=True data.batch_size=8
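To smoke-test all three experiment configs in one go, the command above can be looped over the ConvNet, RGroup and RSteer configs. This is a sketch, not part of the repository: the helper names smoke_test_command and run_smoke_tests are hypothetical, and it assumes it is run from the repository root so that src.train is importable. It uses sys.executable instead of a bare python so the same interpreter (and environment) is reused.

```python
import subprocess
import sys

# Experiment configs listed under "Experiment configs" in this issue
EXPERIMENTS = ["convnet", "rgroup", "rsteer"]


def smoke_test_command(model, batch_size=8):
    """Build the fast_dev_run command for one wang2022/equivariance_test experiment."""
    return [
        sys.executable, "-m", "src.train",
        f"experiment=wang2022/equivariance_test/{model}",
        "+trainer.fast_dev_run=True",
        f"data.batch_size={batch_size}",
    ]


def run_smoke_tests(models=EXPERIMENTS):
    """Run the fast_dev_run command per model; return the models whose run failed."""
    failed = []
    for model in models:
        result = subprocess.run(smoke_test_command(model))
        if result.returncode != 0:
            failed.append(model)
    return failed
```

Calling run_smoke_tests() returns an empty list when every config gets through a fast_dev_run. As noted above, this still says nothing about checkpointing or early stopping.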

Legend:

  • (❓): Parameters are educated guesses
  • (✅): Parameters fully reproducible from the authors' code

Tasks

  • Implement a script that takes (equivariance_level, alpha, model) as input and runs the corresponding experiment
  • Implement a script that takes all desired combinations of equivariance_level, alpha and model and launches all SLURM jobs (if Snellius is back online in time)
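A possible starting point for both tasks is sketched below: experiment_command builds the training command for one (equivariance_level, alpha, model) triple, and launch_all wraps every combination in an sbatch --wrap submission. The function names are hypothetical, and so are the Hydra override keys data.equivariance_level and model.alpha; they must be replaced with whatever keys the configs actually expose. Cluster-specific sbatch options (partition, GPUs, time limit) are deliberately left out.

```python
import shlex
import subprocess
from itertools import product


def experiment_command(equivariance_level, alpha, model):
    """Build the training command for one (equivariance_level, alpha, model) run.

    NOTE: data.equivariance_level and model.alpha are hypothetical override keys.
    """
    return [
        "python", "-m", "src.train",
        f"experiment=wang2022/equivariance_test/{model}",
        f"data.equivariance_level={equivariance_level}",
        f"model.alpha={alpha}",
    ]


def sbatch_command(equivariance_level, alpha, model):
    """Wrap one training command in a SLURM submission using sbatch --wrap."""
    job_name = f"{model}-{equivariance_level}-alpha{alpha}"
    train_cmd = shlex.join(experiment_command(equivariance_level, alpha, model))
    return ["sbatch", f"--job-name={job_name}", f"--wrap={train_cmd}"]


def launch_all(levels, alphas, models, dry_run=True):
    """Build one sbatch command per combination; submit them only if dry_run is False."""
    commands = [
        sbatch_command(level, alpha, model)
        for model, level, alpha in product(models, levels, alphas)
    ]
    if not dry_run:
        for cmd in commands:
            subprocess.run(cmd, check=True)
    return commands
```

Keeping dry_run=True as the default makes it possible to inspect the generated commands before anything is submitted to the queue.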

Questions

  • Q1: Should we use the 10 remaining time steps as the test set instead of testing on the validation set?

Feel free to add more questions or tasks

Example of downloading a checkpoint from wandb and loading it:

import wandb
from src.utils.wandb import get_all_checkpoints, download_artifact
from src.models.wang2022_module import Wang2022LightningModule
from pathlib import Path

# wandb run to fetch checkpoints from
entity = "uva-dl2"
project = "wang2022"
run_id = "bpiojh4w"

# get all checkpoint artifacts logged by this run
checkpoints = get_all_checkpoints(run_id, project, entity)
print(checkpoints)
# download the first checkpoint artifact and load the model from it
with wandb.init(project=project, entity=entity, job_type="run-evaluation-test") as run:
    artifact_dir = download_artifact(run, checkpoints[0], project, entity)
    model = Wang2022LightningModule.load_from_checkpoint(Path(artifact_dir) / "model.ckpt")
    print(model)

done