There are many components involved in training a PyTorch model, including model architectures, loss functions, hyperparameters, optimizers, dataloaders, and all of their arguments.
A standard training loop requires boilerplate code to connect all of these components, including training and validation steps, saving and loading checkpoints, and tracking metrics.
To simplify this process, the [🔥] template lets you specify the components above in a TOML file, while the remaining boilerplate is implemented as a minimal class in a single Python file.
- Define the configuration in a TOML file (e.g. `configs/example.toml`).
- Train, validate, and test the model with `python main.py configs/example.toml`.
The TOML file is read by the `Trainer` class (the only class that implements all the boilerplate code for training), which dynamically loads classes and their arguments using the `init` function:

    def init(module: object, class_args: dict):
        # Look up the class named by the "class" key and instantiate it
        # with the remaining keys as keyword arguments.
        class_name = class_args.pop("class")
        return getattr(module, class_name)(**class_args)
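Because `init` removes the `class` key with `pop`, the configuration dictionary is modified in place, so a second call on the same section would raise a `KeyError`. The sketch below shows a non-mutating variant. It uses the standard-library `fractions` module as a stand-in for a PyTorch module so that it runs without PyTorch, and the name `init_copy` is hypothetical and not part of the template:

```python
import fractions


def init(module: object, class_args: dict):
    # Same function as above: note that pop() mutates class_args.
    class_name = class_args.pop("class")
    return getattr(module, class_name)(**class_args)


def init_copy(module: object, class_args: dict):
    # Hypothetical variant: work on a shallow copy so the caller's
    # configuration dictionary is left untouched.
    kwargs = dict(class_args)
    class_name = kwargs.pop("class")
    return getattr(module, class_name)(**kwargs)


# Stand-in section; in the template this would come from the TOML file.
section = {"class": "Fraction", "numerator": 3, "denominator": 4}

value = init_copy(fractions, section)
print(value)               # 3/4
print("class" in section)  # True: the section is unchanged

init(fractions, section)
print("class" in section)  # False: the original init consumed the key
```

Whether the copy is worth it depends on whether a section is ever instantiated more than once. For a configuration that is read once per run, the original version is sufficient.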
Consider the following TOML configuration for the optimizer:

    [optimizer]
    class = "Adam"
    lr = 1e-3
    weight_decay = 0
From the `[optimizer]` section, [🔥] uses the `class` value to create a new instance of a `torch.optim.Adam` optimizer and passes all other values as keyword arguments to the new object (here `lr` and `weight_decay`). The optimizer also takes the model parameters (`params`) as its first positional argument, but this is already provided by the code in the `Learner` class.
You can also initialize your own custom classes from TOML:

    [model]
    class = "LeNet5"
    num_classes = 10

This configuration section initializes the `LeNet5` class defined in `models/models.py` as the model architecture.
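With custom classes, a typo in the `class` value surfaces as a bare `AttributeError` from `getattr`. The sketch below shows one way to report the available class names instead. The placeholder `LeNet5`, the in-memory `models` module, and the helper `init_checked` are all hypothetical and only imitate the layout described above:

```python
import inspect
import types


class LeNet5:
    # Hypothetical placeholder for the real model in models/models.py.
    def __init__(self, num_classes: int):
        self.num_classes = num_classes


# Stand-in for importing the module that holds the custom models.
models = types.ModuleType("models")
models.LeNet5 = LeNet5


def init_checked(module: object, class_args: dict):
    # Hypothetical variant of init that lists the available class names
    # when the requested one does not exist in the module.
    kwargs = dict(class_args)
    class_name = kwargs.pop("class")
    if not hasattr(module, class_name):
        available = sorted(
            name for name, _ in inspect.getmembers(module, inspect.isclass)
        )
        raise ValueError(f"Unknown class {class_name!r}; available: {available}")
    return getattr(module, class_name)(**kwargs)


model = init_checked(models, {"class": "LeNet5", "num_classes": 10})
print(model.num_classes)  # 10

try:
    init_checked(models, {"class": "LeNet6", "num_classes": 10})
except ValueError as error:
    print(error)
```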
You can see how a TOML file is loaded by `Trainer` and `Tester` by comparing `configs/example.toml` with the `__init__()` methods in `main.py`.
- Is [🔥] stable? No. I'm tweaking this template based on my experience and needs, so expect breaking changes. In any case, this is a template, so you might have to modify it heavily to fit your needs.
- Why the name [🔥]? It's a combination of the PyTorch flame and the square brackets that define sections in a TOML file.
- Why TOML? I think it's simpler than YAML and better than JSON for configuration. Moreover, the Python ecosystem is starting to embrace it: `tomllib` is in the standard library, and `pyproject.toml` is used for Python project configuration.
- victoresque/pytorch-template: PyTorch deep learning projects made easy.
- moemen95/Pytorch-Project-Template: A scalable template for PyTorch projects, with examples in Image Segmentation, Object classification, GANs and Reinforcement Learning.