GSK-AI / afterglow

A package for uncertainty estimation with PyTorch

Afterglow gives your PyTorch models uncertainty estimation capabilities. It is designed to work with any PyTorch model with a minimum of fuss, and it uses SWAG (Stochastic Weight Averaging-Gaussian) as its core uncertainty estimation method.
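
For reference, here is the SWAG posterior as described by Maddox et al. (2019). It is a Gaussian over the weights, built from the last T weight snapshots theta_i collected during training. The squares are elementwise, theta-bar_i is the running mean after i snapshots, and D-hat keeps only the K most recent deviations. This is the published method, not a transcription of afterglow's code. It is an assumption, not something this README states, that `max_cols` plays the role of K.

```latex
\begin{align*}
\bar{\theta} &= \frac{1}{T}\sum_{i=1}^{T} \theta_i, \qquad
\overline{\theta^2} = \frac{1}{T}\sum_{i=1}^{T} \theta_i^2, \qquad
\Sigma_{\mathrm{diag}} = \operatorname{diag}\!\left(\overline{\theta^2} - \bar{\theta}^2\right) \\
\hat{D} &= \left[\theta_{T-K+1} - \bar{\theta}_{T-K+1},\ \dots,\ \theta_T - \bar{\theta}_T\right] \in \mathbb{R}^{d \times K} \\
\theta &\sim \mathcal{N}\!\left(\bar{\theta},\ \tfrac{1}{2}\left(\Sigma_{\mathrm{diag}} + \tfrac{1}{K-1}\hat{D}\hat{D}^{\top}\right)\right) \\
\tilde{\theta} &= \bar{\theta} + \tfrac{1}{\sqrt{2}}\,\Sigma_{\mathrm{diag}}^{1/2} z_1 + \tfrac{1}{\sqrt{2(K-1)}}\,\hat{D} z_2,
\qquad z_1 \sim \mathcal{N}(0, I_d),\ z_2 \sim \mathcal{N}(0, I_K)
\end{align*}
```

The last line is how a single posterior sample is drawn without ever forming the full d-by-d covariance matrix.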

With afterglow, you can turn code that trains point-estimating models into code that trains uncertainty-estimating models with a single function call:

from afterglow import enable_swag
enable_swag(
    model,
    start_iteration=100 * len(train_dataloader), # start tracking at epoch 100
    update_period_in_iters=len(train_dataloader), # update posterior every epoch
    max_cols=20,
)
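
To make the three scheduling arguments concrete, here is a minimal, hypothetical tracker. It works over a flat list of weights, and `ToySwagTracker` is not part of afterglow. It assumes `start_iteration` is a burn-in after which snapshots are collected, `update_period_in_iters` is the spacing between snapshots, and `max_cols` caps how many deviation columns are kept. Those readings come from the README's inline comments, not from afterglow's source.

```python
from collections import deque


class ToySwagTracker:
    """Hypothetical SWAG moment tracker over a flat list of weights.

    Illustration only; not afterglow's implementation.
    """

    def __init__(self, start_iteration, update_period_in_iters, max_cols):
        self.start_iteration = start_iteration
        self.update_period_in_iters = update_period_in_iters
        self.n_models = 0      # number of snapshots averaged so far
        self.mean = None       # running first moment of the weights
        self.sq_mean = None    # running second moment of the weights
        # Deviation columns; once full, the oldest column is dropped.
        self.deviations = deque(maxlen=max_cols)

    def should_update(self, iteration):
        # No snapshots during burn-in, then one every update period.
        if iteration < self.start_iteration:
            return False
        return (iteration - self.start_iteration) % self.update_period_in_iters == 0

    def update(self, weights):
        n = self.n_models
        if self.mean is None:
            self.mean = [0.0] * len(weights)
            self.sq_mean = [0.0] * len(weights)
        # Incremental averages: new = (n * old + x) / (n + 1).
        self.mean = [(n * m + w) / (n + 1) for m, w in zip(self.mean, weights)]
        self.sq_mean = [(n * s + w * w) / (n + 1) for s, w in zip(self.sq_mean, weights)]
        # Record this snapshot's deviation from the updated running mean.
        self.deviations.append([w - m for w, m in zip(weights, self.mean)])
        self.n_models += 1
```

With the README's settings (`start_iteration=100 * len(train_dataloader)`, `update_period_in_iters=len(train_dataloader)`), this schedule would take one snapshot per epoch from epoch 100 onward, keeping at most 20 deviation columns.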

After training your model as usual, you can obtain uncertainty estimates like so:

mean, std = model.trajectory_tracker.predict_uncertainty(x, num_samples=30)
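
A plausible reading of `predict_uncertainty` is a Monte Carlo summary. It draws `num_samples` weight samples, predicts on `x` with each, and reports the elementwise mean and standard deviation. The sketch below shows only that summary step. `mc_mean_std` and the `sample_prediction` callable are hypothetical. Using the population standard deviation (ddof=0) is also an assumption, since the README does not say which estimator afterglow uses.

```python
import math
from typing import Callable, List, Sequence, Tuple


def mc_mean_std(
    sample_prediction: Callable[[], Sequence[float]], num_samples: int
) -> Tuple[List[float], List[float]]:
    """Elementwise mean and std over `num_samples` stochastic predictions.

    `sample_prediction` stands in for "sample weights, then run the model on x".
    """
    preds = [list(sample_prediction()) for _ in range(num_samples)]
    n_out = len(preds[0])
    means = [sum(p[j] for p in preds) / num_samples for j in range(n_out)]
    # Population standard deviation across samples (ddof=0, an assumption).
    stds = [
        math.sqrt(sum((p[j] - means[j]) ** 2 for p in preds) / num_samples)
        for j in range(n_out)
    ]
    return means, stds
```

More samples give a less noisy estimate of the predictive spread, at the cost of one forward pass per sample.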

You can also draw a single instance of the model from the SWAG posterior:

model.trajectory_tracker.sample_state()  # load a weight sample from the SWAG posterior into model
sample_at_x = model(x)  # predict with that sampled model
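
Drawing one weight vector from the SWAG Gaussian uses only the stored first and second moments and the deviation columns. It combines a diagonal term and a low-rank term, following the sampling rule from the SWAG paper. The function below is a hypothetical, pure-Python sketch over a flat weight list. afterglow's `sample_state` presumably does the equivalent on the model's parameter tensors, but that is an assumption.

```python
import math
import random
from typing import List, Optional, Sequence


def sample_swag_weights(
    mean: Sequence[float],
    sq_mean: Sequence[float],
    deviations: Sequence[Sequence[float]],
    rng: Optional[random.Random] = None,
) -> List[float]:
    """Draw one weight vector from a diagonal-plus-low-rank SWAG Gaussian (sketch)."""
    rng = rng or random.Random()
    k = len(deviations)
    sample = []
    for i, m in enumerate(mean):
        var = max(sq_mean[i] - m * m, 0.0)  # clamp small negative round-off
        # Diagonal part: sqrt(var / 2) * z1.
        sample.append(m + math.sqrt(var) * rng.gauss(0.0, 1.0) / math.sqrt(2.0))
    if k > 1:  # the low-rank term divides by (k - 1), so it needs at least two columns
        z2 = [rng.gauss(0.0, 1.0) for _ in range(k)]
        scale = 1.0 / math.sqrt(2.0 * (k - 1))
        for i in range(len(sample)):
            sample[i] += scale * sum(z * col[i] for z, col in zip(z2, deviations))
    return sample
```

Passing a seeded `random.Random` makes the draw reproducible, which helps when comparing predictions across runs.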

You can efficiently predict on an entire dataloader, drawing one sample for each pass over the dataset:

dataset_means, dataset_stds = model.trajectory_tracker.predict_uncertainty_on_dataloader(
    dataloader=dataloader, num_samples=30
)
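
One way to read "one sample for each pass over the dataset" is this: weights are resampled once per full pass, reused for every batch in that pass, and the per-datapoint predictions are then summarized across passes. The sketch below follows that reading. `predict_over_dataset`, `sample_state`, and `predict` are hypothetical stand-ins, not afterglow's API.

```python
import math
from typing import Callable, List, Sequence, Tuple


def predict_over_dataset(
    sample_state: Callable[[], None],
    predict: Callable[[float], float],
    dataset: Sequence[Sequence[float]],
    num_samples: int,
) -> Tuple[List[float], List[float]]:
    """Hypothetical sketch: one posterior sample per full pass over `dataset`."""
    per_pass = []
    for _ in range(num_samples):
        sample_state()  # draw weights once and reuse them for the whole pass
        per_pass.append([predict(x) for batch in dataset for x in batch])
    n = len(per_pass[0])
    means = [sum(p[j] for p in per_pass) / num_samples for j in range(n)]
    # Population std across passes (ddof=0, an assumption).
    stds = [
        math.sqrt(sum((p[j] - means[j]) ** 2 for p in per_pass) / num_samples)
        for j in range(n)
    ]
    return means, stds
```

Resampling once per pass rather than once per batch means any per-sample overhead (such as a batch-norm refresh) is paid `num_samples` times in total. That is likely where the efficiency comes from.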

If you pass a dataloader to enable_swag, afterglow takes care of the SWAG batch-norm update step for you, re-estimating batch-norm statistics whenever a new model is sampled:

from afterglow import enable_swag
enable_swag(
    model,
    start_iteration=100 * len(train_dataloader),
    update_period_in_iters=len(train_dataloader),
    max_cols=20,
    dataloader_for_batchnorm=train_dataloader, # now we'll do the bn update when we sample
)
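
The batch-norm update is needed because a sampled weight vector changes the activations feeding each batch-norm layer, so the running statistics stored during training no longer match. PyTorch provides `torch.optim.swa_utils.update_bn(loader, model)` for this on real modules. It resets the running statistics and makes one forward pass with cumulative (momentum=None) averaging. The sketch below mimics that averaging for a single scalar feature in pure Python. `recompute_bn_stats` is hypothetical and is not afterglow's code.

```python
from typing import Iterable, Sequence, Tuple


def recompute_bn_stats(batches: Iterable[Sequence[float]]) -> Tuple[float, float]:
    """Recompute one feature's running mean/variance from scratch (sketch).

    Each batch is weighted equally, as with momentum=None in PyTorch batch norm.
    Assumes every batch holds at least two values.
    """
    n_batches = 0
    running_mean = 0.0
    running_var = 0.0
    for batch in batches:
        b = len(batch)
        mu = sum(batch) / b
        var = sum((x - mu) ** 2 for x in batch) / (b - 1)  # unbiased batch variance
        n_batches += 1
        # Cumulative average of per-batch statistics.
        running_mean += (mu - running_mean) / n_batches
        running_var += (var - running_var) / n_batches
    return running_mean, running_var
```

In a real model, the values in each batch would be the activations reaching one batch-norm channel under the newly sampled weights.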

You can speed up online inference by limiting the number of datapoints used to update the batch-norm statistics:

from afterglow import enable_swag
enable_swag(
    model,
    start_iteration=100 * len(train_dataloader),
    update_period_in_iters=len(train_dataloader),
    max_cols=20,
    dataloader_for_batchnorm=train_dataloader,
    num_datapoints_for_bn_update=10, # now we'll only use 10 examples for the bn update
)
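
A natural way to cap the batch-norm update at `num_datapoints_for_bn_update` examples is to stop consuming the dataloader early and truncate the final batch. The generator below sketches that idea. `limit_datapoints` is hypothetical, and whether afterglow truncates this way is an assumption.

```python
from typing import Iterable, Iterator, List, Sequence


def limit_datapoints(
    batches: Iterable[Sequence[float]], num_datapoints: int
) -> Iterator[List[float]]:
    """Yield batches until `num_datapoints` examples are emitted, truncating the last one."""
    remaining = num_datapoints
    for batch in batches:
        if remaining <= 0:
            break  # stop early instead of iterating the whole dataloader
        chunk = list(batch[:remaining])
        remaining -= len(chunk)
        yield chunk
```

Fewer datapoints make each sample cheaper, but the re-estimated statistics become noisier, so the setting trades accuracy for speed.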

See the documentation, and the example app in /example, for more!

License: Apache License 2.0


Languages

Language:Python 100.0%