
jinlib

Library package for my line of PyTorch work. It includes semantically oriented convenience functions that improve readability, and it also sports a lightweight framework that abstracts away standard routines so you can quickly implement and compare experiments.

To get a sense of just how handy the framework provided by the Experiment class and library is for whipping up quick experiments, you may wish to compare our code against PyTorch's guide for the same CIFAR10 classifier experiment.

Overview

Every individual experiment is identified by a directory. Experiment settings are fully described by the config.yml within that directory. The example.config.yml documents all configurations supported out of the box:

---
evaluation_metrics:         # [OPTIONAL] The metrics which will be calculated for every epoch
  - loss*                   # (DEFAULT VALUE) Can be omitted as loss is a compulsory metric. The asterisk marks it as the criterion metric for selecting the best epoch
  - accuracy                # Also computes accuracy for every batch
activation:                 # Activation function
  choice: ReLU              # Same name as function in torch.nn
  kwargs: {}                # Parameters for activation function call. {} to indicate PyTorch defaults
optimization:               # Optimization function
  choice: SGD               # Same name as function in torch.optim
  kwargs:                   # Parameters for optimization function call
    lr: 0.001
    momentum: 0.9
loss:                       # Loss function
  choice: CrossEntropyLoss  # Same name as function in torch.nn
  kwargs: {}                # Parameters for loss function call. {} to indicate PyTorch defaults
regularization:             # [OPTIONAL] Regularization terms (currently only supports L2)
  L2: 0.001                 # Lagrange multiplier (i.e. lambda) value for L2
batch_size: 4               # Mini-batch size for dataloaders
num_epochs: 5               # Number of training epochs
checkpoints:                # [OPTIONAL] Checkpoint related configuration
  dir: .                    # (DEFAULT VALUE) Directory under which checkpoints are saved.
  best_prefix: best         # (DEFAULT VALUE) Prefix of the best checkpoint. E.g. best.pth.tar
  last_prefix: last         # (DEFAULT VALUE) Prefix of the last checkpoint. E.g. last.pth.tar
  suffix: .pth.tar          # (DEFAULT VALUE) Filename suffix for checkpoint files
  stats_filename: stats.yml # (DEFAULT VALUE) Filename for reviewing best and last checkpoint statistics. It will be saved in the same directory as checkpoints
  state_dict_mappings: []   # (DEFAULT VALUE) Key mappings to translate before loading in the state dictionary. E.g. [(key_1_old, key_1_new), ...]
logs:                       # [OPTIONAL] Logging related configuration
  logger: log.log           # (DEFAULT VALUE) Filename which logger will log to. It sits in the same directory as the experiment
  tensorboard: TB_logdir    # (DEFAULT VALUE) The logdir for TensorBoard. It sits in the same directory as the experiment
remarks: Great experiment!  # [OPTIONAL]
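Each choice/kwargs pair in the config resolves a name into a constructed object. The following is a minimal sketch of how such an entry might be resolved; the build_from_config helper and the stand-in namespace are illustrative, not jinlib's actual API:

```python
from types import SimpleNamespace

def build_from_config(namespace, entry):
    """Look up entry['choice'] in the given namespace (e.g. torch.nn or
    torch.optim) and construct it with entry.get('kwargs', {})."""
    cls = getattr(namespace, entry["choice"])
    return cls(**entry.get("kwargs", {}))

# Example with a stand-in namespace instead of torch.nn:
class ReLU:
    def __init__(self, inplace=False):
        self.inplace = inplace

fake_nn = SimpleNamespace(ReLU=ReLU)
act = build_from_config(fake_nn, {"choice": "ReLU", "kwargs": {}})
```

In practice the namespace would be torch.nn for activation and loss entries and torch.optim for the optimization entry, which is why the choice values must match those module's class names exactly.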

To read in this experiment configuration file, you'll first need to create a class that subclasses Experiment. Your subclass must minimally override and implement the methods _init_model, _init_dataset and _init_dataloaders; please refer to these methods to see which class attributes have to be defined within them. Every Experiment instance exposes the methods .train(), .validation() and .test(), corresponding to the respective contexts in which the model runs. In addition, analyze() is meant for producing analytic output on a model's performance after training. Please refer to the CIFAR10 classifier for a simple example. The directory structure of example also reflects how an experiment is intended to be organized across different settings.
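The override contract can be pictured with a stripped-down stand-in for the base class. This is a sketch of the template-method pattern only; everything beyond the three _init_* hook names is illustrative and not jinlib's actual implementation:

```python
class ExperimentSketch:
    """Stand-in for jinlib's Experiment: the base class drives the
    setup routine and subclasses fill in the _init_* hooks."""
    def __init__(self):
        self._init_model()
        self._init_dataset()
        self._init_dataloaders()

    def _init_model(self):
        raise NotImplementedError  # subclass must set self.model

    def _init_dataset(self):
        raise NotImplementedError  # subclass must set self.dataset

    def _init_dataloaders(self):
        raise NotImplementedError  # subclass must set self.dataloaders

class CIFAR10Sketch(ExperimentSketch):
    def _init_model(self):
        self.model = "ConvNet"            # in practice: a torch.nn.Module
    def _init_dataset(self):
        self.dataset = ["train", "test"]  # in practice: torchvision datasets
    def _init_dataloaders(self):
        self.dataloaders = {"train": [], "test": []}  # torch DataLoaders

exp = CIFAR10Sketch()
```

Instantiating the subclass runs all three hooks, so a missing override surfaces immediately as a NotImplementedError rather than later during training.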

You may also override any of the methods in Experiment as you see fit. For instance, if you are using a custom loss function, you may simply override _init_loss_fn in your subclass. In the CIFAR10 classifier, we override _update_iter_stats to additionally keep track of a confusion matrix across iterations in the analyze context. There's also nothing stopping you from adding more custom configurations and overriding the relevant methods to read from them.
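The confusion-matrix bookkeeping mentioned above boils down to counting (target, prediction) pairs each iteration. Here is a framework-free sketch of that idea; jinlib's actual _update_iter_stats override works on tensors, and update_confusion_matrix is a hypothetical helper name:

```python
def update_confusion_matrix(cm, targets, predictions):
    """Accumulate counts into cm, where cm[t][p] counts samples whose
    true class is t and predicted class is p."""
    for t, p in zip(targets, predictions):
        cm[t][p] += 1
    return cm

num_classes = 3
cm = [[0] * num_classes for _ in range(num_classes)]
# One batch's worth of targets and model predictions:
cm = update_confusion_matrix(cm, targets=[0, 1, 2, 2], predictions=[0, 2, 2, 2])
```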

Currently, only the best and last checkpoints will be saved for every experiment as this is sufficient for my line of work. Support for saving checkpoints every k iterations/epochs might be added in the future but is not a priority for now.
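The best-vs-last policy amounts to one comparison per epoch against the criterion metric. A sketch under the assumption that loss-like criteria are minimized while metrics like accuracy are maximized; the is_new_best helper name is illustrative:

```python
def is_new_best(best_so_far, current, lower_is_better=True):
    """Decide whether the current epoch's criterion metric beats the
    best recorded so far (None means no checkpoint exists yet)."""
    if best_so_far is None:
        return True
    return current < best_so_far if lower_is_better else current > best_so_far

# The last checkpoint is overwritten every epoch; the best one only on improvement:
best = None
for epoch_loss in [0.9, 0.7, 0.8, 0.6]:
    if is_new_best(best, epoch_loss):
        best = epoch_loss  # here one would also save best.pth.tar
```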

Install instructions

Dependencies

Too many to list manually for now; please install the necessary dependencies yourself by trial and error as import errors surface.

Via pip

pip install git+https://github.com/jin-zhe/jinlib

Via local repo

Clone this repo locally:

git clone git@github.com:jin-zhe/jinlib.git <dest>

Go to the directory you cloned the repo in:

cd <dest>

Install via pip:

pip install -e .

Alternatively, you may run the following, but you will lose the ability to uninstall the package via pip:

python setup.py install

Documentation

Currently there is no formal documentation, but most functions are well commented for easy understanding. For a comprehensive overview of the main conveniences this library offers, please see the sample code in example.
