rotcx / BayeFormers

General API for Deep Bayesian Variational Inference by Backpropagation. The repository is designed to work with Transformer-like architectures and is compatible with HuggingFace Transformers models.

License: MIT | Python 3.6+ | PyTorch 1.4+

Setup

Install the required Python libraries with pip.

$ cd BayeFormers
$ (sudo) pip3 install -r requirements.txt

Usage

from bayeformers import to_bayesian

import bayeformers.nn as bnn
import torch
import torch.nn as nn
import torch.nn.functional as F


# Frequentist Model Definition
class Model(nn.Module):
    pass


# Train Frequentist Model
model = Model()

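# `inputs` and `labels` stand for one batch of training data;
# `predictions` are expected to be log-probabilities.
predictions = model(inputs)
loss = F.nll_loss(predictions, labels, reduction="sum")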

# Turn Frequentist Model to Bayesian Model (MOPED Initialization)
bayesian_model = to_bayesian(model, delta=0.05, freeze=True)

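# Train Bayesian Model
# `samples` (number of Monte Carlo weight samples), `batch_size` and
# `output_dim` (a tuple) are placeholders to be defined by the caller.
predictions = torch.zeros(samples, batch_size, *output_dim)
log_prior = torch.zeros(samples, batch_size)
log_variational_posterior = torch.zeros(samples, batch_size)

# Each forward pass draws a fresh set of weights
for s in range(samples):
    predictions[s] = bayesian_model(inputs)
    log_prior[s] = bayesian_model.log_prior()
    log_variational_posterior[s] = bayesian_model.log_variational_posterior()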

predictions = predictions.mean(0)
log_prior = log_prior.mean(0)
log_variational_posterior = log_variational_posterior.mean(0)

# `n_batches` is the number of mini-batches per epoch (complexity cost weighting)
nll = F.nll_loss(predictions, labels, reduction="sum")
loss = (log_variational_posterior - log_prior) / n_batches + nll
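
The snippet above leaves the model, the data and the sampling constants undefined. As a rough, self-contained illustration of the same training step, the sketch below does not use the bayeformers API. `GaussianLinear` is a hypothetical stand-in for a Bayesian layer, and the shapes, the optimizer and the constants are arbitrary choices made for the example. The MOPED-style initialization (posterior mean at the pretrained weight, standard deviation `delta * |w|`, prior `N(w, 1)`) follows one reading of Krishnan et al. and is an assumption about what `to_bayesian` does, not a description of it.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal


class GaussianLinear(nn.Module):
    """Hypothetical mean-field Gaussian linear layer (not part of bayeformers)."""

    def __init__(self, linear: nn.Linear, delta: float = 0.05):
        super().__init__()
        w = linear.weight.detach().clone()
        # MOPED-style init (assumed): mean at the pretrained weight,
        # standard deviation proportional to the weight magnitude.
        self.mu = nn.Parameter(w.clone())
        sigma = (delta * w.abs()).clamp_min(1e-6)
        # Inverse softplus, so that softplus(rho) == sigma at initialization.
        self.rho = nn.Parameter(torch.log(torch.expm1(sigma)))
        # Fixed prior N(w, 1) centered on the pretrained weight (assumed).
        self.register_buffer("prior_mu", w)
        self.bias = nn.Parameter(linear.bias.detach().clone())
        self._log_prior = torch.tensor(0.0)
        self._log_q = torch.tensor(0.0)

    def forward(self, x):
        sigma = F.softplus(self.rho)
        # Reparameterization trick: weight = mu + sigma * eps, eps ~ N(0, 1).
        weight = self.mu + sigma * torch.randn_like(sigma)
        self._log_q = Normal(self.mu, sigma).log_prob(weight).sum()
        self._log_prior = Normal(self.prior_mu, 1.0).log_prob(weight).sum()
        return F.linear(x, weight, self.bias)

    def log_prior(self):
        return self._log_prior

    def log_variational_posterior(self):
        return self._log_q


torch.manual_seed(0)
samples, batch_size, n_batches = 4, 8, 10  # arbitrary values for the sketch
inputs = torch.randn(batch_size, 5)
labels = torch.randint(0, 3, (batch_size,))

# The randomly initialized nn.Linear plays the role of a trained frequentist layer.
layer = GaussianLinear(nn.Linear(5, 3), delta=0.05)
optimizer = torch.optim.SGD(layer.parameters(), lr=1e-3)

predictions = torch.zeros(samples, batch_size, 3)
log_prior = torch.zeros(samples)
log_variational_posterior = torch.zeros(samples)

for s in range(samples):
    # Each forward pass samples new weights and records their log-densities.
    predictions[s] = F.log_softmax(layer(inputs), dim=-1)
    log_prior[s] = layer.log_prior()
    log_variational_posterior[s] = layer.log_variational_posterior()

# Monte Carlo estimate of the loss: weighted complexity cost plus data fit.
nll = F.nll_loss(predictions.mean(0), labels, reduction="sum")
complexity = log_variational_posterior.mean() - log_prior.mean()
loss = complexity / n_batches + nll

optimizer.zero_grad()
loss.backward()
optimizer.step()
```

Here the log-densities are stored as one scalar per weight sample, so `loss` is a scalar that can be backpropagated directly. Dividing the complexity term by `n_batches` spreads it evenly over the mini-batches of an epoch, which matches the weighting used in the loss above.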

Examples

$ python3 -m examples.mlp_mnist
$ python3 -m examples.bert_glue --help
$ python3 -m examples.bert_squad --help

References

Libraries

Papers

  • "Weight Uncertainty in Neural Networks", Blundell et al., ICML 2015, Arxiv
  • "Specifying Weight Priors in Bayesian Deep Neural Networks with Empirical Bayes", Krishnan et al., AAAI 2020, Arxiv

Articles

  • "Bayesian inference: How we are able to chase the Posterior", Ritchie Vink, Blog
  • "Weight Uncertainty in Neural Networks", Nitarshan Rajkumar, Blog
