baal-org / baal

Bayesian active learning library for research and industrial use cases.

Home Page: https://baal.readthedocs.io


`train_on_dataset` much slower when using `ActiveLearningDataset` compared to torch `Dataset`

arthur-thuy opened this issue

Describe the bug
When the entire pool is labelled (i.e. training on the entire training set), the `train_on_dataset` function is much slower when using an `ActiveLearningDataset` than when using a regular torch `Dataset`. In the MNIST experiment below, it is 17x slower (!).

I suspect the discrepancy grows with the size of the labelled pool, as there is no noticeable difference when only 20 samples are labelled.

To Reproduce
In this gist, a LeNet-5 model with MC Dropout is trained on the full MNIST training set for 1 epoch. Note that the script does not perform any active learning: no acquisitions are made and the pool set is empty. The script was written to compare training times across AL packages.

The script has a `--use-ald` option, which passes an `ActiveLearningDataset` to the `train_on_dataset` function instead of the regular torch `Dataset`. Please refer to lines 83-94 in the gist for the relevant code.
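
For context, a minimal sketch of roughly what such a toggle looks like (this is not the actual gist; the `use_ald` variable and the transform setup are illustrative):

# Minimal sketch of the --use-ald toggle (illustrative, not the actual gist).
from torchvision import transforms
from torchvision.datasets import MNIST
from baal.active.dataset import ActiveLearningDataset

use_ald = True  # in the real script this comes from the --use-ald CLI flag

base_set = MNIST(root="/tmp", train=True, download=True,
                 transform=transforms.ToTensor())
train_set = base_set

if use_ald:
    train_set = ActiveLearningDataset(base_set)
    train_set.label_randomly(len(base_set))  # label the entire pool up front

# train_set (wrapped or not) is then passed to ModelWrapper.train_on_dataset,
# so the only difference between the two runs is the dataset wrapper.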

Results are as follows:

  • python baal_mnist.py --use-ald => "Elapsed training time: 0:1:23"
  • python baal_mnist.py => "Elapsed training time: 0:0:5"

Here is the full output:

> python baal_mnist.py --use-ald
Use GPU: NVIDIA RTX A5000 for training
labelling 60000 observations
[1538876-MainThread] [baal.modelwrapper:train_on_dataset:83] 2023-06-22T09:35:08.962235Z [info     ] Starting training              dataset=60000 epoch=1
[1538876-MainThread] [baal.modelwrapper:train_on_dataset:94] 2023-06-22T09:36:32.093871Z [info     ] Training complete              train_loss=0.21927499771118164
Elapsed training time: 0:1:23
[1538876-MainThread] [baal.modelwrapper:test_on_dataset:123] 2023-06-22T09:36:32.101867Z [info     ] Starting evaluating            dataset=10000
[1538876-MainThread] [baal.modelwrapper:test_on_dataset:133] 2023-06-22T09:36:35.126586Z [info     ] Evaluation complete            test_loss=0.04848730191588402
{'dataset_size': 60000,
 'test_accuracy': 0.9842716455459595,
 'test_loss': 0.04848730191588402,
 'train_accuracy': 0.9318974018096924,
 'train_loss': 0.21927499771118164}
Elapsed total time: 0:1:27
> python baal_mnist.py
Use GPU: NVIDIA RTX A5000 for training
[1538621-MainThread] [baal.modelwrapper:train_on_dataset:83] 2023-06-22T09:34:51.757774Z [info     ] Starting training              dataset=60000 epoch=1
[1538621-MainThread] [baal.modelwrapper:train_on_dataset:94] 2023-06-22T09:34:56.868404Z [info     ] Training complete              train_loss=0.21591344475746155
Elapsed training time: 0:0:5
[1538621-MainThread] [baal.modelwrapper:test_on_dataset:123] 2023-06-22T09:34:56.874050Z [info     ] Starting evaluating            dataset=10000
[1538621-MainThread] [baal.modelwrapper:test_on_dataset:133] 2023-06-22T09:34:59.894939Z [info     ] Evaluation complete            test_loss=0.04452119022607803
{'dataset_size': 60000,
 'test_accuracy': 0.985236644744873,
 'test_loss': 0.04452119022607803,
 'train_accuracy': 0.9333688616752625,
 'train_loss': 0.21591344475746155}
Elapsed total time: 0:0:9

Expected behavior
I would expect the training time with ActiveLearningDataset to be a few percent slower, but not 17x slower.

Version (please complete the following information):

  • OS: Ubuntu 20.04
  • Python: 3.9.16
  • Baal version: 1.7.0

Additional context
I want to use active learning in my experiments, so just using the torch Dataset is not an appropriate solution.

Any ideas why this is the case and whether this could be fixed?
Thank you!

Hello!

I was able to reproduce it with this example:

# Test that iterating over the ActiveLearningDataset is as fast as iterating
# over the raw dataset once the whole pool is labelled.
from torchvision.datasets import CIFAR10
from baal.active.dataset import ActiveLearningDataset

dataset = CIFAR10(root='/tmp', train=True, download=True)
al_dataset = ActiveLearningDataset(dataset)
al_dataset.label_randomly(len(dataset))  # label the entire pool

# IPython magics: time a full pass with and without the wrapper
%timeit [x for x in al_dataset]
%timeit [x for x in dataset]

I have a possible fix where we cache the result of ActiveLearningDataset.get_indices_for_active_step.
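
To illustrate the idea, a rough sketch of that kind of caching is below (this is not the actual #265 patch; everything except the get_indices_for_active_step name is made up for illustration):

# Rough sketch of the caching idea (not the actual #265 patch).
import numpy as np

class LabelledPool:
    def __init__(self, n_pool):
        self.labelled = np.zeros(n_pool, dtype=bool)  # labelled/unlabelled mask
        self._cached_indices = None  # memoized result of get_indices_for_active_step

    def label(self, index):
        self.labelled[index] = True
        self._cached_indices = None  # labelling invalidates the cache

    def get_indices_for_active_step(self):
        # Recompute the labelled-index list only when the mask has changed,
        # instead of on every access (e.g. on every __getitem__ call).
        if self._cached_indices is None:
            self._cached_indices = np.flatnonzero(self.labelled)
        return self._cached_indices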

I'll try to merge this quickly and make a release, but I'm away for the long weekend; I'll be back on Monday.

I opened #265. I'm not super happy with the solution, but it's the best I can do for now: it is now "only" 2x slower. I'll revisit next week, but feel free to use the branch fix/al_dataset_speed for your experiments.

Thank you for the fix! I would be happy with the "only 2x slower" training time.

Hello!
Are you comfortable installing Baal from source on the branch fix/al_dataset_speed?

I want to be sure that we are not blocking you. If we are, I'll merge immediately and ship a minor release asap.
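
For reference, one way to install from that branch (assuming the repo lives at github.com/baal-org/baal) would be something like `pip install git+https://github.com/baal-org/baal.git@fix/al_dataset_speed`.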

I’m currently on holiday and need to work on a paper revision when I return to the office (not related to Baal). As such, I will not be working with Baal for the next 4 weeks, so the fix is not urgent for me.

If the minor release is not done by then, I’ll install it from source. Thank you for your message.