The goal of this project is to achieve zero-code, from-configuration-only training of PyTorch models using PyTorch Lightning. This is achieved through a configuration dictionary that specifies the model, the dataset, the data loaders, etc., which is then used to build all required objects. Currently, this leads to an average lines-of-code reduction of 15% compared to a standard PyTorch Lightning setup, while improving customizability and reproducibility and maintaining the same flexibility as the original code.
To install the package, you can use the following command:
pip install git+https://github.com/V0XNIHILI/autolightning.git@main
Or, if you want to install the package in editable mode, you can use the following command:
git clone git@github.com:V0XNIHILI/autolightning.git
cd autolightning
# Make sure you have pip 23 or higher
pip install -e .
To define a complete configuration, you can use the following top-level keys:
cfg = {
"learner": {...},
"criterion": {...},
"lr_scheduler": {...}, # Optional
"model": {...},
"optimizer": {...},
"training": {...}, # Optional
"seed": ..., # Optional
"dataset": {...}, # Optional
"dataloaders": {...}, # Optional
}
For example, to train a LeNet5 on MNIST with early stopping and learning rate stepping, the configuration can be defined in one of the following ways:
Note that I use DotMap here to define the configuration, but you can use any other dictionary-like object, or use tools like OmegaConf or Hydra to define the configuration.
# filename: main.py
# -----------------
from dotmap import DotMap
cfg = DotMap()
# Specify the learner and its configuration
# Without having any dots in the name, the import will be done from
# the `autolightning.lm` module
cfg.learner.name = "SupervisedLearner"
cfg.learner.cfg = {
# Indicate whether classification accuracy should be computed
"classification": True,
# Optionally specify for which ks the top-k accuracy should be computed
# "topk": [1, 5],
}
# Select the criterion and its configuration
cfg.criterion.name = 'CrossEntropyLoss'
# Optionally specify the configuration for the criterion:
# cfg.criterion.cfg = {
# 'reduction': 'mean'
# }
# Optionally specify the learning rate scheduler and its configuration
cfg.lr_scheduler.scheduler = {
"name": "StepLR",
"cfg": {
"step_size": 2,
"verbose": True,
},
}
# Specify the model and its configuration
cfg.model.name = 'torch_mate.models.LeNet5BNMaxPool'
cfg.model.cfg.num_classes = 10
# Optionally specify a compilation configuration
cfg.model.extra.compile = {
'name': 'torch.compile'
}
# Specify the optimizer and its configuration
cfg.optimizer.name = 'Adam'
cfg.optimizer.cfg = {"lr": 0.007}
# Specify the training configuration (passed directly to the PyTorch
# Lightning Trainer). The `early_stopping` configuration is optional
# and will be used to configure the early stopping callback.
cfg.training = {
'max_epochs': 100,
'early_stopping': {
'monitor': 'val/loss',
'patience': 10,
'mode': 'min'
},
}
# Set the seed for reproducibility
cfg.seed = 4223747124
# Specify the dataset and its configuration.
# Without having any dots in the name, the import will be done from
# the `autolightning.datasets` module
cfg.dataset.name = 'MagicData'
cfg.dataset.cfg = {
"name": "MNIST", # Can also be torchvision.datasets.MNIST for example
"val_percentage": 0.1
}
cfg.dataset.kwargs = {
"root": './data',
"download": True
}
# Specify the transforms and their configuration
# Note that you can specify .pre (common pre-transform), .train
# .val/.test/.predict (specific transforms for each split) and
# .post (common post-transform). The complete transforms will
# then be built automatically. The same goes for target_transforms
# via: cfg.dataset.target_transforms
cfg.dataset.transforms.pre = [
{'name': 'ToTensor'},
{'name': 'Resize', 'cfg': {'size': (28, 28)}},
]
# Optionally, specify a pre-device and post-device transfer
# batch transform via: cfg.dataset.batch_transforms.pre and
# cfg.dataset.batch_transforms.post in the same manner
# as for the other transforms.
# Specify the data loaders and their configuration (where default
# is the fallback configuration for all data loaders)
cfg.dataloaders = {
'default': {
'num_workers': 4,
'prefetch_factor': 16,
'persistent_workers': True,
'batch_size': 256,
},
'train': {
'batch_size': 512
}
}
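# With this configuration, the train loader uses batch_size=512 while inheriting
# num_workers, prefetch_factor and persistent_workers from the 'default' entry.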
Note that the configuration can also contain references to classes directly, without the relative import path. This is practical for example when you define a model class in the same file as the configuration. For example:
class LeNet5BNMaxPool(nn.Module):
def __init__(self, num_classes: int):
super(LeNet5BNMaxPool, self).__init__()
...
def forward(self, x):
...
cfg.model.name = LeNet5BNMaxPool
Finally, serialize the resulting configuration to a dictionary:
cfg = cfg.toDict()
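As an alternative to DotMap, a minimal sketch of the same idea with OmegaConf (only a subset of the keys above is shown; the final conversion to a plain dict mirrors the toDict() call):

from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "learner": {"name": "SupervisedLearner", "cfg": {"classification": True}},
    "criterion": {"name": "CrossEntropyLoss"},
    "model": {"name": "torch_mate.models.LeNet5BNMaxPool", "cfg": {"num_classes": 10}},
    "optimizer": {"name": "Adam", "cfg": {"lr": 0.007}},
    "training": {"max_epochs": 100},
    "seed": 4223747124,
})

cfg = OmegaConf.to_container(cfg, resolve=True)  # plain dict, like cfg.toDict() above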
Use config_all to configure the model, data and trainer in one go. This function returns the trainer, model and data objects, which can be used to train the model. Alternatively, you can also use config_model, config_data and config_model_data to only configure specific parts (a sketch of the latter follows the example below).
from lightning.pytorch.loggers import WandbLogger
from autolightning import config_all
trainer, model, data = config_all(cfg,
# Specify all keyworded arguments that are not part of the
# `cfg.training` dictionary for the PyTorch Lightning Trainer
{
"enable_progress_bar": True,
"accelerator": "mps",
"devices": 1,
"logger": WandbLogger(project="test_wandb_lightning")
}
)
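If you only want autolightning to build the model and the data module, a minimal sketch could look as follows. Note that the exact signature of config_model_data is an assumption here (a tuple of model and data built from the same cfg):

from lightning import Trainer
from autolightning import config_model_data

# Hypothetical usage; the exact return signature is an assumption
model, data = config_model_data(cfg)
trainer = Trainer(max_epochs=100, accelerator="auto")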
By creating the objects manually, you have more flexibility and can decide which objects are created with autolightning and which you create yourself. For example, in this way you can combine a custom model configured via autolightning with a regular PyTorch dataloader (see the sketch after the snippet below).
from autolightning.lm import SupervisedLearner
from autolightning.datasets import MagicData
from lightning import Trainer
# Create the model, data and trainer
model = SupervisedLearner(cfg)
data = MagicData(cfg)
trainer = Trainer(**cfg["training"])
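For example, a minimal sketch of a regular PyTorch dataloader to pair with the autolightning model above (the torchvision MNIST setup here is an illustration, not part of autolightning):

from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

# A plain PyTorch dataloader, built without autolightning
train_loader = DataLoader(
    MNIST("./data", train=True, download=True, transform=ToTensor()),
    batch_size=256,
    shuffle=True,
    num_workers=4,
)
# `train_loader` can then be passed to `trainer.fit(...)` in step 3 instead of `data`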
You can immediately continue with step 3: train the model.
Note that the autolightning CLI tool supports all the same arguments as the regular PyTorch Lightning CLI (as the AutoCLI is a subclass of the LightningCLI) but allows for two key differences:
- Configurations can also be specified as Python files (instead of in YAML files)
- The AutoCLI has additional torch flags that can be set in a configuration file to configure the PyTorch backend regarding debugging and performance. For example:

torch:
  autograd:
    set_detect_anomaly: False
    profiler:
      profile: False
      emit_nvtx: False
  set_float32_matmul_precision: high
  backends:
    cuda:
      matmul:
        allow_tf32: True
    cudnn:
      allow_tf32: True
      benchmark: True
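For reference, a sketch of the plain PyTorch equivalents of these flags (the profiler entries map to the torch.autograd.profiler.profile and torch.autograd.profiler.emit_nvtx context managers; how autolightning applies all of these internally is an assumption):

import torch

torch.autograd.set_detect_anomaly(False)
torch.set_float32_matmul_precision("high")
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True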
You can also split hyperparameter configuration from per-machine specific configuration by moving the latter into a separate (YAML) config file, for example:
# filename: local.yaml
# --------------------
data:
# Don't forget to remove these keys from main.py!
root: "../datasets/data"
download: True
trainer:
logger:
- class_path: WandbLogger
init_args:
project: name_of_wandb_project
log_model: true
callbacks:
- class_path: ModelCheckpoint
init_args:
monitor: val/accuracy
mode: max
save_on_train_epoch_end: false
save_top_k: 1
accelerator: gpu
check_val_every_n_epoch: 10
devices: [3]
To run training on the combined configuration (where only the values in main.py are stored as hyperparameters):
autolightning fit -c main.py -c local.yaml
Option 4.1: Only enable trainer configuration + data/model kwargs from CLI
This way, you still only have to provide the trainer configuration (via the --config flag) to the CLI, which often contains environment-specific settings such as GPU indices, while keeping experiment-specific settings fixed. To get more information on how this can be done, see here for a crisp overview of the PyTorch Lightning CLI.
# file: main.py
from autolightning import pre_cli
from autolightning.lm import SupervisedLearner
from autolightning.datasets import MagicData
from lightning.pytorch.cli import LightningCLI
def cli_main():
cfg = ... # Load or set the configuration in any way you want
LightningCLI(
pre_cli(SupervisedLearner, cfg), # All arguments after cfg are available to be set in the CLI
pre_cli(MagicData, cfg), # Same goes for the data module
trainer_defaults=cfg["training"] if "training" in cfg else None,
seed_everything_default=cfg["seed"] if "seed" in cfg else True,
)
if __name__ == "__main__":
cli_main()
An example configuration for the trainer in this case could be:
# file: config.yaml
trainer:
logger:
- class_path: WandbLogger
init_args:
project: test_autolightning
callbacks:
- class_path: ModelCheckpoint
init_args:
dirpath: ./nets
monitor: val/accuracy
save_top_k: 1
accelerator: gpu
check_val_every_n_epoch: 1
log_every_n_steps: 20
data:
root: ./data
Option 4.2 (not recommended): Enable model, data and trainer configuration from CLI
In this way, you store all the training, model and data configuration in one file. However, to stay consistent with the original Lightning CLI API, we use variable interpolation to avoid duplicate values in the YAML file (to enable this, we set parser_kwargs={"parser_mode": "omegaconf"}).
# file: main.py
from autolightning import ConfigurableLightningModule, ConfigurableLightningDataModule
from lightning.pytorch.cli import LightningCLI
def cli_main():
LightningCLI(
ConfigurableLightningModule,
ConfigurableLightningDataModule,
subclass_mode_model=True,
subclass_mode_data=True,
parser_kwargs={"parser_mode": "omegaconf"}
)
if __name__ == "__main__":
cli_main()
# file: config.yaml
trainer:
max_epochs: ${model.init_args.cfg.training.max_epochs}
...
model:
class_path: ${model.init_args.cfg.learner.name}
init_args:
# All variables in the cfg variable below will be saved as hyperparameters,
# and can be accessed in the model via self.hparams. None of the other variables
# in this file will be saved as hyperparameters.
cfg:
criterion:
dataloaders:
dataset:
name: your_module.YourDataModule
learner:
name: autolightning.lm.SupervisedLearner
lr_scheduler:
model:
optimizer:
seed: 4223747124
training:
max_epochs: 100
...
data:
class_path: ${model.init_args.cfg.dataset.name}
init_args:
cfg: ${model.init_args.cfg}
...
seed_everything: ${model.init_args.cfg.seed}
If you have the model, data and trainer instantiated, you can train the model using the following code:
trainer.fit(model, data)
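If your data module also provides a test split, evaluation works the same way (standard PyTorch Lightning API):

trainer.test(model, data)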
Or if you want to use the AutoLightning CLI, you can run the following command:
autolightning fit --config main.py --config local.yaml
Finally, in case you used the original PyTorch Lightning CLI in your own code, you can run the following command:
python main.py fit --config ./config.yaml
In case you want to add to or override the default behavior selected by autolightning, you can do so using hooks. autolightning adds a few new hooks in addition to the ones provided by PyTorch Lightning:
- configure_configuration(self, cfg: Dict) - Return the configuration that should be used. This configuration can be accessed at self.hparams.
- config_model(self) - Return the model that should be trained. This model can be accessed with get_model(self).
- compile_model(self, model: nn.Module) - Compile the model and return it. This is called after the model is built and can be used to change the compilation behavior.
- configure_criteria(self) - Return the criteria that should be used.
- configure_optimizers_only(self) - Return the optimizers that should be used.
- configure_schedulers(self, optimizers: List[optim.Optimizer]) - Return the schedulers that should be used.
- shared_step(self, batch, batch_idx, phase: str) - Function that is called by training_step(...), validation_step(...), test_step(...) and predict_step(...) from the AutoModule with the corresponding phase argument (train/val/test/predict).
import torch.nn as nn
from autolightning import AutoModule
class MyModel(AutoModule):
def config_model(self):
# Can put any logic here and can access the configuration
# via self.hparams
return nn.Linear(100, 10)
def configure_criteria(self):
return nn.MSELoss()
def shared_step(self, batch, batch_idx, phase: str):
X, y = batch
model = self.get_model()
criterion = self.criteria
loss = criterion(model(X), y)
self.log(f"{phase}/loss", loss)
return loss
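Note that, as the example shows, the criteria returned by configure_criteria are available via self.criteria, and the configured model via get_model(self).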
Similar to models, you can customize the data loading behavior by using hooks. autolightning adds the following new hooks:
- configure_configuration(self, cfg: Dict)
- get_common_transform(self, moment: str)
- get_common_target_transform(self, moment: str)
- get_common_batch_transform(self, moment: str)
- get_transform(self, stage: str)
- get_target_transform(self, stage: str)
- get_batch_transform(self, moment: str)
- get_dataloader_kwargs(self, stage: str)
- get_dataset(self, phase: str)
- get_transformed_dataset(self, phase: str)
- get_dataloader(self, phase: str)
from autolightning import AutoDataModule

class MyDataModule(AutoDataModule):
    def get_dataset(self, phase: str):
        # Can put any logic here and can access the configuration
        # via self.hparams
        return MyDataset(phase)  # MyDataset: your own torch Dataset class
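Similarly, a minimal sketch of overriding get_dataloader_kwargs to adjust loader settings per stage, assuming the base implementation returns the keyword arguments derived from the dataloaders configuration:

from autolightning import AutoDataModule

class MyTunedDataModule(AutoDataModule):
    def get_dataloader_kwargs(self, stage: str):
        # Start from the configured defaults (assumption: the base class
        # returns the kwargs derived from cfg["dataloaders"])
        kwargs = super().get_dataloader_kwargs(stage)
        if stage == "train":
            kwargs["shuffle"] = True  # e.g. always shuffle the training loader
        return kwargs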