Syne Tune

Large-scale and asynchronous hyperparameter optimization at your fingertips.

This package provides state-of-the-art algorithms for distributed hyperparameter optimization (HPO). Trials can be evaluated with several trial backend options: a local backend that evaluates trials on the local machine, a SageMaker backend that runs each trial as a separate SageMaker training job, and a simulation backend to quickly benchmark parallel asynchronous schedulers.
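
As an illustration, the sketch below shows how a local backend and a SageMaker backend might be constructed for the same training script. The SageMaker estimator settings (instance type, framework version, the role placeholder) are assumptions for illustration only, and the exact SageMakerBackend constructor arguments may differ across versions; consult the Syne Tune documentation for details.

# Minimal sketch: the same Tuner can be pointed at different trial backends.
from syne_tune.backend import LocalBackend, SageMakerBackend
from sagemaker.pytorch import PyTorch

# Evaluate trials as subprocesses on the local machine
local_backend = LocalBackend(entry_point='train_height.py')

# Evaluate each trial as a separate SageMaker training job (illustrative arguments)
sagemaker_backend = SageMakerBackend(
    sm_estimator=PyTorch(
        entry_point='train_height.py',
        instance_type='ml.m5.large',
        instance_count=1,
        role='<your-sagemaker-execution-role>',  # placeholder, replace with your role ARN
        framework_version='1.12',
        py_version='py38',
    ),
)

# A simulation backend, which replays tabulated benchmarks to quickly compare
# parallel asynchronous schedulers, is also available; see the documentation.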

Installing

To install Syne Tune via pip, simply run:

pip install 'syne-tune'

This installs a bare-bones version. To additionally install our own Gaussian process based optimizers, Ray Tune, or the Bore optimizer, run pip install 'syne-tune[X]', where X can be one of the following:

  • gpsearchers: For built-in Gaussian process based optimizers
  • raytune: For Ray Tune optimizers
  • benchmarks: For installing all dependencies required to run all benchmarks
  • extra: For installing all the above
  • bore: For Bore optimizer
  • kde: For KDE optimizer

For instance, pip install 'syne-tune[gpsearchers]' will install Syne Tune along with many built-in Gaussian process optimizers.
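
Multiple extras can be combined in the usual pip fashion, separated by commas. For example, to get both the Gaussian process searchers and the KDE optimizer:

pip install 'syne-tune[gpsearchers,kde]'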

To install the latest version from git, run the following:

pip install git+https://github.com/awslabs/syne-tune.git

For local development, we recommend the following setup, which lets you easily test your changes:

pip install --upgrade pip
git clone https://github.com/awslabs/syne-tune.git
cd syne-tune
pip install -e '.[extra]'

To run unit tests, simply run pytest in the root of this repository.

To run all tests whose names begin with test_async_scheduler, use the following:

pytest -k test_async_scheduler

Getting started

To enable tuning, you have to report metrics from your training script so that they can be communicated to Syne Tune. This is done by calling report(), e.g. report(epoch=epoch, loss=loss), as shown in the example below:

# train_height.py
import logging
import time

from syne_tune import Reporter
from argparse import ArgumentParser

if __name__ == '__main__':
    root = logging.getLogger()
    root.setLevel(logging.INFO)

    parser = ArgumentParser()
    parser.add_argument('--steps', type=int)
    parser.add_argument('--width', type=float)
    parser.add_argument('--height', type=float)

    args, _ = parser.parse_known_args()
    report = Reporter()

    for step in range(args.steps):
        dummy_score = (0.1 + args.width * step / 100) ** (-1) + args.height * 0.1
        # Feed the score back to Syne Tune.
        report(step=step, mean_loss=dummy_score, epoch=step + 1)
        time.sleep(0.1)
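
Since the script only relies on argparse and the Reporter, you can sanity-check it standalone before launching a tuning experiment, for example:

python train_height.py --steps 10 --width 5 --height 3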

Once you have a script reporting metrics, you can launch a tuning experiment as follows:

from syne_tune import Tuner, StoppingCriterion
from syne_tune.backend import LocalBackend
from syne_tune.config_space import randint
from syne_tune.optimizer.baselines import ASHA

# hyperparameter search space to consider
config_space = {
    'steps': 100,
    'width': randint(1, 20),
    'height': randint(1, 20),
}

tuner = Tuner(
    trial_backend=LocalBackend(entry_point='train_height.py'),
    scheduler=ASHA(
        config_space, metric='mean_loss', resource_attr='epoch', max_t=100,
        search_options={'debug_log': False},
    ),
    stop_criterion=StoppingCriterion(max_wallclock_time=15),
    n_workers=4,  # how many trials are evaluated in parallel
)
tuner.run()

The above example runs ASHA with 4 asynchronous workers on the local machine, stopping after 15 seconds of wallclock time.
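
After the tuner has finished, the stored results can be inspected. The sketch below assumes the load_experiment helper from syne_tune.experiments and the tuner.name attribute; exact names may differ across versions, so treat this as a sketch rather than the definitive API:

from syne_tune.experiments import load_experiment

# Results are stored under the tuner name; load_experiment retrieves them.
tuning_experiment = load_experiment(tuner.name)
print(tuning_experiment)   # summary, including the best configuration found
tuning_experiment.plot()   # best metric value over wallclock time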

Examples

The examples/ folder contains examples illustrating the different functionalities provided by Syne Tune.

FAQ and Tutorials

Check our FAQ to learn more about Syne Tune's functionality.

Do you want to know more? Here are a number of tutorials.

Security

See CONTRIBUTING for more information.

License

This project is licensed under the Apache-2.0 License.
