EDGSCOUT / pytorch-energizer

An active learning library for PyTorch based on PyTorch Lightning.

Pytorch-Energizer

An active learning library for PyTorch

Features

Energizer allows training PyTorch models using active learning. Because it is built on PyTorch Lightning, it can scale to multi-node and multi-GPU settings. Just as importantly, abiding by the lightweight PyTorch Lightning API gives the library a tidy interface and avoids boilerplate training code.

The core principle underlying Energizer is composability. Everything in the library revolves around the EnergizerStrategy, which puts together a base_learner (the model we want to actively train), an inference_module (how the model should behave when run on the pool dataset), and the active learning loop hyper-parameters.

For example, assume you have the following model:

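import torch
from pytorch_lightning import LightningModule
from transformers import AutoModelForSequenceClassification


class Model(LightningModule):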
    def __init__(self):
        super().__init__()
        self.backbone = AutoModelForSequenceClassification.from_pretrained("prajjwal1/bert-tiny", num_labels=4)
        self.loss = torch.nn.CrossEntropyLoss()

    def forward(self, batch):
        return self.backbone(**batch).logits

    def step(self, batch, *args, **kwargs):
        y = batch.pop("labels")
        y_hat = self(batch)
        return self.loss(y_hat, y)

    def training_step(self, batch, *args, **kwargs):
        loss = self.step(batch, *args, **kwargs)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, *args, **kwargs):
        loss = self.step(batch, *args, **kwargs)
        self.log("val_loss", loss)
        return loss

    def test_step(self, batch, *args, **kwargs):
        loss = self.step(batch, *args, **kwargs)
        self.log("test_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)


base_learner = Model()

The first step is to let Energizer know how this model should behave at inference time on the pool dataset. This is done by wrapping it in an EnergizerInference module. Say you want to use MC-Dropout. You can then do:

from energizer.inference import MCDropout

inference_module = MCDropout(
    num_inference_iters=10,  # int
    consistent=False,        # bool
    prob=0.1,                # Optional[float]
    inplace=True,            # bool
)

# this will patch all Dropout layers
inference_module.connect(base_learner)  # NOTE: when used inside an `EnergizerStrategy`
                                        # this will be done automatically

Now, whenever inference_module(x) is called, it performs num_inference_iters forward passes with the dropout layers active and collects the resulting list of logits, as prescribed by the MC-Dropout technique. To specify how to score instances from the pool and how to select indices, wrap the inference module in an EnergizerStrategy. For this example, assume you want the entropy strategy, which selects the instances with the highest entropy of the logits:

from energizer.strategies import EntropyStrategy

al_strategy = EntropyStrategy(inference_module=inference_module)
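
To make the scoring concrete, here is a minimal, library-independent sketch of entropy scoring over several stochastic forward passes. It is not Energizer's implementation: the helpers softmax, entropy, and expected_entropy are hypothetical names, plain Python lists stand in for tensors, the logits are made up, and "expected entropy" is assumed to mean the average of the per-pass entropies.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Convert logits to probabilities (shifted by the max for numerical stability)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(logits: Sequence[float]) -> float:
    """Shannon entropy (in nats) of the softmax distribution over the logits."""
    return -sum(p * math.log(p) for p in softmax(logits) if p > 0.0)


def expected_entropy(mc_logits: Sequence[Sequence[float]]) -> float:
    """Average the per-pass entropies over several stochastic forward passes."""
    return sum(entropy(logits) for logits in mc_logits) / len(mc_logits)


# Hypothetical logits for one instance: 3 stochastic passes over 4 classes
confident = [[8.0, 0.1, 0.2, 0.1], [7.5, 0.3, 0.1, 0.2], [8.2, 0.0, 0.1, 0.3]]
uncertain = [[1.0, 0.9, 1.1, 1.0], [0.8, 1.2, 1.0, 0.9], [1.1, 1.0, 0.9, 1.0]]

score_confident = expected_entropy(confident)
score_uncertain = expected_entropy(uncertain)
```

Under these assumptions, the instance whose passes produce nearly flat logits receives the higher score, so an entropy-based strategy would query it before the confidently classified one.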

Under the hood, this calls inference_module.forward(). Since our inference module performs MC-Dropout, it will automatically use the expected entropy. Each EnergizerStrategy is a LightningModule whose test_step has been tailored to score and select the instances to label. In practice, each batch from the pool dataset is scored, and the top-k scores are kept in memory at each iteration alongside their indices. This avoids scoring the entire pool first and then computing the top-k, which can be infeasible when the pool is very large and does not play nicely with distributed settings.
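
The batch-wise top-k bookkeeping described above can be sketched with a small min-heap. This is only an illustration of the idea, not the code Energizer runs: the function streaming_top_k and the (index, score) batch layout are assumptions made for this example.

```python
import heapq
from typing import Iterable, List, Tuple


def streaming_top_k(
    batches: Iterable[List[Tuple[int, float]]], k: int
) -> List[Tuple[int, float]]:
    """Return the k (index, score) pairs with the highest scores.

    At most k pairs are held in memory at any time, regardless of pool size.
    """
    if k <= 0:
        return []
    heap: List[Tuple[float, int]] = []  # min-heap keyed on score
    for batch in batches:
        for idx, score in batch:
            if len(heap) < k:
                heapq.heappush(heap, (score, idx))
            elif score > heap[0][0]:
                # The new score beats the smallest kept score: swap it in
                heapq.heapreplace(heap, (score, idx))
    # Order the survivors from highest to lowest score
    return [(idx, score) for score, idx in sorted(heap, reverse=True)]


# Hypothetical pool of 6 instances, scored in 3 batches of 2
pool_batches = [
    [(0, 0.10), (1, 0.90)],
    [(2, 0.40), (3, 0.95)],
    [(4, 0.30), (5, 0.70)],
]
selected = streaming_top_k(pool_batches, k=2)
```

Because only k entries survive each batch, memory use stays constant as the pool grows. In a distributed run, each process could keep its own top-k and the results could be merged at the end, though how Energizer handles that step is not stated here.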

The missing piece is the actual active learning loop definition. In Energizer this is handled by the ActiveLearningLoop, which is a subclass of the Lightning FitLoop. It can be defined as shown below. The example also shows how to do everything (define an inference module, strategy, and loop) in one go.

from energizer.inference import MCDropout
from energizer.loops import ActiveLearningLoop
from energizer.strategies import EntropyStrategy
from pytorch_lightning import Trainer


# define model
base_learner = Model()

# define active learning loop, strategy, and inference module (no need to call `.connect()`)
active_learning_loop = ActiveLearningLoop(
    al_strategy=EntropyStrategy(
        inference_module=MCDropout(
            num_inference_iters=10,  # int
            consistent=False,        # bool
            prob=0.1,                # Optional[float]
            inplace=True,            # bool
        )
    ),
    query_size=2,             # number of instances to label at each round
    reset_weights=True,       # should we reset the model weights after each iteration?
    label_epoch_frequency=3,  # how many training epochs on the labelled data
)

trainer = Trainer(max_epochs=10)

# Connect to the default fit_loop of the trainer to extract info, e.g. max_epochs
# NOTE: there is no need to call `.connect()` on the strategy or on the inference module,
# everything is handled by this `.connect()` call
active_learning_loop.connect(trainer)

# replace the original fit_loop with the active_learning_loop
trainer.fit_loop = active_learning_loop

# fit model with active learning
# NOTE: `dm` is assumed to be a LightningDataModule defined elsewhere
trainer.fit(base_learner, datamodule=dm)
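
For readers new to active learning, the control flow that such a loop automates can be sketched in plain Python. Everything in this sketch is an assumption made for illustration: run_active_learning, score_fn, and train_fn are hypothetical names, the order of steps within a round may differ from what ActiveLearningLoop does, and the scores are fixed rather than recomputed by a retrained model.

```python
from typing import Callable, Dict, List


def run_active_learning(
    pool: List[int],
    score_fn: Callable[[int], float],
    train_fn: Callable[[List[int]], None],
    query_size: int,
    num_rounds: int,
) -> List[List[int]]:
    """Repeat: score the pool, label the top `query_size`, train on all labelled data."""
    pool = list(pool)          # unlabelled instance indices
    labelled: List[int] = []   # indices labelled so far
    history: List[List[int]] = []
    for _ in range(num_rounds):
        if not pool:
            break
        # 1. Score every pool instance and pick the highest-scoring ones
        queried = sorted(pool, key=score_fn, reverse=True)[:query_size]
        # 2. Move the queried instances from the pool to the labelled set
        pool = [i for i in pool if i not in queried]
        labelled.extend(queried)
        # 3. Train on everything labelled so far
        train_fn(list(labelled))
        history.append(queried)
    return history


# Hypothetical, fixed uncertainty scores for a pool of 6 instances
scores: Dict[int, float] = {0: 0.10, 1: 0.90, 2: 0.40, 3: 0.95, 4: 0.30, 5: 0.70}
train_sizes: List[int] = []  # records how much labelled data each round trained on

history = run_active_learning(
    pool=list(scores),
    score_fn=scores.__getitem__,
    train_fn=lambda labelled: train_sizes.append(len(labelled)),
    query_size=2,
    num_rounds=2,
)
```

In this toy run the labelled set grows by query_size instances per round, which is the role the query_size argument plays in the loop definition above.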

Credits

This package was created with Cookiecutter and the waynerv/cookiecutter-pypackage project template.

License: MIT License

