[Meta Issue] Wang 2022 Extension: Cross evaluation by changing model equivariance (alpha) and dataset equivariance
dgcnz opened this issue
Description
This issue concerns the extension of Wang 2022 Figure 4, which consists of fixing a dataset's equivariance level (levels = [full, partial]) and training a relaxed equivariant model with `k` different values of `alpha`. So we have a total of `2 * k` runs per model, where 2 is the number of datasets and `k` is the number of different `alpha` values tested.
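For concreteness, a minimal sketch of the run grid per model (the `alpha` values below are placeholders, not the ones we will actually sweep):

```python
from itertools import product

# Placeholder sweep values; the actual alphas to test are still to be decided.
equivariance_levels = ["full", "partial"]  # the two dataset settings
alphas = [0.0, 0.1, 1.0]                   # k = 3 in this example

runs = list(product(equivariance_levels, alphas))  # 2 * k = 6 runs per model
print(runs)
```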
Configs
SmokePlume configs:
- Figure 4 Equivariance Test: `configs/data/wang2022/equivariance_test.yaml` (✅) (see Q1)

Model default configs (DO NOT MODIFY THESE FILES, other experiment files rely on the defaults set here):
- ConvNet: `configs/model/wang2022/convnet.yaml` (✅)
- RGroup: `configs/model/wang2022/rgroup.yaml` (❓)
- RSteer: `configs/model/wang2022/rsteer.yaml` (✅)
Experiment configs:
- ConvNet: `configs/experiment/wang2022/equivariance_test/convnet.yaml` (✅)
- RGroup: `configs/experiment/wang2022/equivariance_test/rgroup.yaml` (❓)
- RSteer: `configs/experiment/wang2022/equivariance_test/rsteer.yaml` (✅)
These configs have to be tested at least with `trainer.fast_dev_run` to ensure that the model processes data correctly. This doesn't exercise model checkpointing or early stopping, so we'll have to add tests for those. An example can be found in the Makefile's `test_wang2022_figure_4` target, which you can run with `make test_wang2022_figure_4`.
Example testing command:

```bash
python -m src.train experiment=wang2022/equivariance_test/rgroup +trainer.fast_dev_run=True data.batch_size=8
```
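To smoke-test all three configs in one go, a minimal sketch (not existing repo code) that loops the same command over the experiment files listed above:

```python
import subprocess

# Smoke-test each experiment config with a single fast_dev_run batch.
for model in ["convnet", "rgroup", "rsteer"]:
    subprocess.run(
        [
            "python", "-m", "src.train",
            f"experiment=wang2022/equivariance_test/{model}",
            "+trainer.fast_dev_run=True",
            "data.batch_size=8",
        ],
        check=True,  # raise if any config fails to run
    )
```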
Legend:
- (❓): Educated guesses were made to get the parameters
- (✅): The authors' code allowed full reproducibility
Tasks
- Implement a script that takes as input (`equivariance_level`, `alpha`, `model`) and runs the corresponding experiment
- Implement a script that takes all desired combinations of `equivariance_level`, `alpha`, and `model` and launches all SLURM jobs (if Snellius comes back online in time); see the sketch after this list
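One possible shape for these two scripts, as a minimal sketch: the `data.equivariance_level` and `model.alpha` override names and the `job.slurm` template are assumptions, not existing repo code.

```python
import subprocess
from itertools import product


def run_experiment(equivariance_level: str, alpha: float, model: str) -> None:
    """Run one (equivariance_level, alpha, model) combination locally."""
    subprocess.run(
        [
            "python", "-m", "src.train",
            f"experiment=wang2022/equivariance_test/{model}",
            f"data.equivariance_level={equivariance_level}",  # assumed override name
            f"model.alpha={alpha}",                           # assumed override name
        ],
        check=True,
    )


def launch_all(levels, alphas, models) -> None:
    """Submit one SLURM job per combination via a (hypothetical) job.slurm template."""
    for level, alpha, model in product(levels, alphas, models):
        subprocess.run(["sbatch", "job.slurm", model, level, str(alpha)], check=True)
```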
Questions
- Q1: Should we use the 10 remaining time steps for testing instead of testing on the validation set?

Feel free to add more questions or tasks.
Example usage of downloading and loading a checkpoint from wandb:

```python
import wandb
from src.utils.wandb import get_all_checkpoints, download_artifact
from src.models.wang2022_module import Wang2022LightningModule
from pathlib import Path

# Get all checkpoint artifacts from a run id.
entity = "uva-dl2"
project = "wang2022"
run_id = "bpiojh4w"
checkpoints = get_all_checkpoints(run_id, project, entity)
print(checkpoints)

# Download the first checkpoint and load the Lightning module from it.
with wandb.init(project=project, entity=entity, job_type="run-evaluation-test") as run:
    artifact_dir = download_artifact(run, checkpoints[0], project, entity)
    model = Wang2022LightningModule.load_from_checkpoint(Path(artifact_dir) / "model.ckpt")
    print(model)
```