pytorch / ignite

High-level library to help with training and evaluating neural networks in PyTorch flexibly and transparently.

Home Page:https://pytorch-ignite.ai

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

FastaiLRFinder usage

leej3 opened this issue · comments

📚 FastaiLRFinder Documentation

I was trying out the FastaiLRFinder and I got a confusing error message. The error message or documentation could probably be improved a little for this. I'm happy to contribute a fix but I'm not sure whether the docs or the error report would be a better choice.

The error can be reproduced by modifying the output of the ignite code generator for an object classification task. I made some changes to use the finder to set the learning rate.

See diff
diff --git a/main.py b/main.py
index d177d6b..244f869 100644
--- a/main.py
+++ b/main.py
@@ -4,7 +4,7 @@ from typing import Any
 import ignite.distributed as idist
 from data import setup_data
 from ignite.engine import Events
-from ignite.handlers import PiecewiseLinear
+from ignite.handlers import PiecewiseLinear, FastaiLRFinder
 from ignite.metrics import Accuracy, Loss
 from ignite.utils import manual_seed
 from models import setup_model
@@ -13,7 +13,7 @@ from trainers import setup_evaluator, setup_trainer
 from utils import *


-def run(local_rank: int, config: Any):
+def run(local_rank: int, config: Any,use_finder=True):
     # make a certain seed
     rank = idist.get_rank()
     manual_seed(config.seed + rank)
@@ -33,19 +33,41 @@ def run(local_rank: int, config: Any):
     model = idist.auto_model(setup_model(config.model))
     optimizer = idist.auto_optim(optim.Adam(model.parameters(), lr=config.lr))
     loss_fn = nn.CrossEntropyLoss().to(device=device)
+    # trainer and evaluator
+    trainer = setup_trainer(config, model, optimizer, loss_fn, device, dataloader_train.sampler)
+    evaluator = setup_evaluator(config, model, device)
+
+    if use_finder:
+        lr_finder = FastaiLRFinder()
+        to_save = {"model": model, "optimizer": optimizer}
+
+        with lr_finder.attach(trainer, to_save=to_save) as trainer_with_lr_finder:
+            trainer_with_lr_finder.run(dataloader_train)
+
+        # Get lr_finder results
+        lr_finder.get_results()
+
+        # Plot lr_finder results (requires matplotlib)
+        lr_finder.plot()
+
+        # get lr_finder suggestion for lr
+        finder_rate = lr_finder.lr_suggestion()
+        print(f"{finder_rate}!!!!!!!!!!")
+        optimizer = idist.auto_optim(optim.Adam(model.parameters()),lr=finder_rate)
+
+    learning_rate = finder_rate if use_finder else config.lr
+
     milestones_values = [
         (0, 0.0),
         (
             len(dataloader_train),
-            config.lr,
+            learning_rate,
         ),
         (config.max_epochs * len(dataloader_train), 0.0),
     ]
     lr_scheduler = PiecewiseLinear(optimizer, "lr", milestones_values=milestones_values)

-    # trainer and evaluator
-    trainer = setup_trainer(config, model, optimizer, loss_fn, device, dataloader_train.sampler)
-    evaluator = setup_evaluator(config, model, device)
+

     # attach metrics to evaluator
     accuracy = Accuracy(device=device)

This fails because the training loss is returned as a dict. The error reported is:



Current run is terminating due to exception: output of the engine should be of type float or 0d torch.Tensor or 1d torch.Tensor with 1 element, but got output of type dict
Engine run is terminating due to exception: output of the engine should be of type float or 0d torch.Tensor or 1d torch.Tensor with 1 element, but got output of type dict
Traceback (most recent call last):
  File "/workspace/ignite/ignite-template-vision-classification/main.py", line 159, in <module>
    main()
  File "/workspace/ignite/ignite-template-vision-classification/main.py", line 155, in main
    p.run(run, config=config)
  File "/opt/conda/lib/python3.10/site-packages/ignite/distributed/launcher.py", line 316, in run
    func(local_rank, *args, **kwargs)
  File "/workspace/ignite/ignite-template-vision-classification/main.py", line 45, in run
    trainer_with_lr_finder.run(dataloader_train)
  File "/opt/conda/lib/python3.10/site-packages/ignite/engine/engine.py", line 898, in run
    return self._internal_run()
  File "/opt/conda/lib/python3.10/site-packages/ignite/engine/engine.py", line 941, in _internal_run
    return next(self._internal_run_generator)
  File "/opt/conda/lib/python3.10/site-packages/ignite/engine/engine.py", line 999, in _internal_run_as_gen
    self._handle_exception(e)
  File "/opt/conda/lib/python3.10/site-packages/ignite/engine/engine.py", line 644, in _handle_exception
    raise e
  File "/opt/conda/lib/python3.10/site-packages/ignite/engine/engine.py", line 965, in _internal_run_as_gen
    epoch_time_taken += yield from self._run_once_on_dataset_as_gen()
  File "/opt/conda/lib/python3.10/site-packages/ignite/engine/engine.py", line 1093, in _run_once_on_dataset_as_gen
    self._handle_exception(e)
  File "/opt/conda/lib/python3.10/site-packages/ignite/engine/engine.py", line 644, in _handle_exception
    raise e
  File "/opt/conda/lib/python3.10/site-packages/ignite/engine/engine.py", line 1075, in _run_once_on_dataset_as_gen
    self._fire_event(Events.ITERATION_COMPLETED)
  File "/opt/conda/lib/python3.10/site-packages/ignite/engine/engine.py", line 431, in _fire_event
    func(*first, *(event_args + others), **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/ignite/handlers/lr_finder.py", line 166, in _log_lr_and_loss
    raise TypeError(
TypeError: output of the engine should be of type float or 0d torch.Tensor or 1d torch.Tensor with 1 element, but got output of type dict

The finder does not expect the format that is returned ({"train_loss": xx.xxx}). This can be fixed by using an output transform for the finder:

with lr_finder.attach(trainer, to_save=to_save,output_transform=lambda x:x["train_loss"]) as trainer_with_lr_finder:
    trainer_with_lr_finder.run(dataloader_train)

Thanks for reporting, John! I would say it would be helpful to improve the error message so that users who see it know directly what to do. Changing the API reference docs can be optional, I would say, but I will be happy to approve if you change both. So, if you are eager to contribute this improvement, you are more than welcome ;)