Library package for my line of PyTorch work. Includes convenience functions that are semantically oriented and improve readability. Also sports a lightweight framework that abstracts away standard routines and allows the user to quickly implement and compare various experiments.
To get a sense of just how handy the framework provided by the `Experiment` class and library is for whipping up quick experiments, you may wish to compare our code against PyTorch's tutorial for the same CIFAR10 classifier experiment.
Every individual experiment is identified by a directory. Experiment settings are fully described by the `config.yml` within that directory. The `example.config.yml` documents all the configurations supported out of the box:
```yaml
---
evaluation_metrics:           # [OPTIONAL] The metrics which will be calculated for every epoch
  - loss*                     # (DEFAULT VALUE) Can be omitted as loss is a compulsory metric. The asterisk marks the criterion metric for selecting the best epoch
  - accuracy                  # Also computes accuracy for every batch
activation:                   # Activation function
  choice: ReLU                # Same name as the function in torch.nn
  kwargs: {}                  # Parameters for the activation function call. {} indicates PyTorch defaults
optimization:                 # Optimization function
  choice: SGD                 # Same name as the function in torch.optim
  kwargs:                     # Parameters for the optimization function call
    lr: 0.001
    momentum: 0.9
loss:                         # Loss function
  choice: CrossEntropyLoss    # Same name as the function in torch.nn
  kwargs: {}                  # Parameters for the loss function call. {} indicates PyTorch defaults
regularization:               # [OPTIONAL] Regularization terms (currently only L2 is supported)
  L2: 0.001                   # Lagrange multiplier (i.e. lambda) value for L2
batch_size: 4                 # Mini-batch size for dataloaders
num_epochs: 5                 # Number of training epochs
checkpoints:                  # [OPTIONAL] Checkpoint-related configuration
  dir: .                      # (DEFAULT VALUE) Directory under which checkpoints are saved
  best_prefix: best           # (DEFAULT VALUE) Prefix of the best checkpoint, e.g. best.pth.tar
  last_prefix: last           # (DEFAULT VALUE) Prefix of the last checkpoint, e.g. last.pth.tar
  suffix: .pth.tar            # (DEFAULT VALUE) Filename suffix for checkpoint files
  stats_filename: stats.yml   # (DEFAULT VALUE) Filename for reviewing best and last checkpoint statistics, saved in the same directory as the checkpoints
  state_dict_mappings: []     # (DEFAULT VALUE) Key mappings to translate before loading the state dictionary. E.g. [(key_1_old, key_1_new), ...]
logs:                         # [OPTIONAL] Logging-related configuration
  logger: log.log             # (DEFAULT VALUE) Filename the logger will log to. It sits in the same directory as the experiment
  tensorboard: TB_logdir      # (DEFAULT VALUE) The logdir for TensorBoard. It sits in the same directory as the experiment
remarks: Great experiment!    # [OPTIONAL]
```
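The `choice`/`kwargs` convention above can be resolved generically: look the name up in the relevant module with `getattr` and instantiate it with the keyword arguments. How jinlib does this internally may differ; the sketch below uses a stand-in namespace instead of `torch.nn` so it runs without PyTorch installed:

```python
from types import SimpleNamespace

# Stand-in for torch.nn so this sketch runs without PyTorch; with PyTorch
# installed you would pass torch.nn or torch.optim as `module` instead.
class ReLU:
    def __init__(self, inplace=False):
        self.inplace = inplace

fake_nn = SimpleNamespace(ReLU=ReLU)

def resolve(module, config):
    """Instantiate config['choice'] from `module` with config['kwargs']."""
    cls = getattr(module, config["choice"])  # e.g. torch.nn.ReLU
    return cls(**config["kwargs"])           # {} falls through to the defaults

activation = resolve(fake_nn, {"choice": "ReLU", "kwargs": {}})
print(type(activation).__name__)  # → ReLU
```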
To read in this experiment configuration file, first create a class that subclasses `Experiment`. Your subclass must minimally override and implement the methods `_init_model`, `_init_dataset` and `_init_dataloaders`. Please refer to these methods to see which class attributes must be defined within them. Every `Experiment` instance exposes the methods `.train()`, `.validation()` and `.test()`, corresponding to the respective contexts of running the model. In addition, `analyze()` provides analytic outputs of a model's performance after training. Please refer to the CIFAR10 classifier for a simple example. The directory structure of `example` also reflects how an experiment is intended to be organized for different settings.
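In outline, a subclass looks something like the sketch below. Since jinlib itself is not imported here, a minimal stand-in base class mirrors the contract described above: the three `_init_*` hooks set attributes the framework consumes. The attribute names (`dataset`, `train_loader`, `model`) and the placeholder bodies are illustrative assumptions, not jinlib's actual API:

```python
class Experiment:  # stand-in mirroring the contract of jinlib's Experiment
    def __init__(self, experiment_dir):
        self.experiment_dir = experiment_dir
        # The real class would parse config.yml here, then invoke the
        # _init_* hooks that every subclass is required to implement.
        self._init_dataset()
        self._init_dataloaders()
        self._init_model()

    def _init_model(self):
        raise NotImplementedError

    def _init_dataset(self):
        raise NotImplementedError

    def _init_dataloaders(self):
        raise NotImplementedError

class CIFAR10Experiment(Experiment):
    def _init_dataset(self):
        self.dataset = list(range(8))     # placeholder for a real Dataset

    def _init_dataloaders(self):
        # placeholder batching; real code would build torch DataLoaders
        self.train_loader = [self.dataset[i:i + 4] for i in (0, 4)]

    def _init_model(self):
        self.model = "Net()"              # placeholder for an nn.Module

exp = CIFAR10Experiment(".")
print(len(exp.train_loader))  # → 2
```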
You may also override any of the methods in `Experiment` as you see fit. For instance, if you are using a custom loss function, simply override `_init_loss_fn` in your subclass. In the CIFAR10 classifier, we override `_update_iter_stats` to additionally keep track of a confusion matrix across the iterations in the analyze context. There is also nothing stopping you from adding more custom configurations and overriding the relevant methods to read from them.
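The confusion-matrix override mentioned above amounts to accumulating (true label, predicted label) counts each iteration. How `_update_iter_stats` receives the batch predictions is internal to jinlib, so this sketch shows only the accumulation itself, in plain Python:

```python
from collections import Counter

def update_confusion(confusion, targets, predictions):
    """Accumulate (true, predicted) label-pair counts across one batch."""
    confusion.update(zip(targets, predictions))
    return confusion

confusion = Counter()
update_confusion(confusion, targets=[0, 1, 1], predictions=[0, 1, 0])
update_confusion(confusion, targets=[1], predictions=[1])
print(confusion[(1, 1)])  # → 2 correct predictions for class 1
```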
Currently, only the best and last checkpoints are saved for every experiment, as this is sufficient for my line of work. Support for saving checkpoints every k iterations/epochs may be added in the future but is not a priority for now.
There are too many dependencies to list manually for now; please install the necessary dependencies yourself through trial and error.
Install directly via pip:

```sh
pip install git+https://github.com/jin-zhe/jinlib
```

Alternatively, clone this repo locally:

```sh
git clone git@github.com:jin-zhe/jinlib.git <dest>
```

Go to the directory you cloned the repo into:

```sh
cd <dest>
```

Install via pip:

```sh
pip install -e .
```

You may also do the following instead, but you will lose the ability to uninstall it via pip:

```sh
python setup.py install
```
Currently there is no documentation, but most functions are well commented for easy understanding. For a comprehensive overview of the main conveniences this library provides, please see the sample code in `example`.