Hyperparameter Optimization

This document provides a walkthrough of the hyperparameter optimization example. To run the example, first install TensorFlow.

pip install tensorflow

The code for this example is in the ray/examples/hyperopt directory.

The simple script that processes results as they become available and launches new experiments can be run as follows.

python ray/examples/hyperopt/hyperopt_simple.py --trials=5 --steps=10

The variant that divides training into multiple segments and aggressively terminates poorly performing models can be run as follows.

python ray/examples/hyperopt/hyperopt_adaptive.py --num-starting-segments=5 \
                                                  --num-segments=10 \
                                                  --steps-per-segment=20

Machine learning algorithms often have a number of hyperparameters whose values must be chosen by the practitioner. For example, an optimization algorithm may have a step size, a decay rate, and a regularization coefficient. In a deep network, the network parameterization itself (e.g., the number of layers and the number of units per layer) can be considered a hyperparameter.
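
Here is a toy sketch, not part of the example code, that makes the distinction concrete. It runs gradient descent on a one-dimensional ridge-regression problem. The step size, decay rate, and regularization coefficient are arguments chosen by the caller, and only the weight w is learned. The function name fit_ridge_1d and the step_size / (1 + decay_rate * t) schedule are illustrative choices.

```python
def fit_ridge_1d(xs, ys, step_size, decay_rate, l2, num_steps=500):
    """Fit y ~ w * x by gradient descent on mean squared error plus l2 * w**2.

    step_size, decay_rate and l2 are hyperparameters picked by the caller;
    gradient descent only learns the weight w.
    """
    w = 0.0
    n = len(xs)
    for t in range(num_steps):
        # Gradient of (1/n) * sum((w*x - y)**2) + l2 * w**2 with respect to w.
        grad = sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / n + 2 * l2 * w
        # The step size shrinks over time according to the decay rate.
        lr = step_size / (1 + decay_rate * t)
        w -= lr * grad
    return w
```

With xs = [1, 2, 3, 4] and ys = [2, 4, 6, 8], a step size of 0.05 recovers w close to 2 when l2 = 0 and w close to 1.5 when l2 = 2.5. A step size of 0.2 with no decay makes the iterates diverge. Small changes in these values change the outcome qualitatively, which is why they are worth searching over.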

Choosing these hyperparameters can be challenging, so a common practice is to search over the space of possible values. One approach that works surprisingly well is random search: sample configurations at random from that space, evaluate each one, and keep the one that performs best.
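
Random search can be sketched generically in a few lines. This is a minimal sequential version. The names random_search, objective, and sample_config are hypothetical, and higher scores are assumed to be better.

```python
import random


def random_search(objective, sample_config, num_trials, seed=0):
    """Evaluate `num_trials` randomly sampled configurations and return the best.

    objective(config) -> score (higher is better); sample_config(rng) -> config.
    """
    rng = random.Random(seed)
    best_config, best_score = None, float("-inf")
    for _ in range(num_trials):
        config = sample_config(rng)
        score = objective(config)
        if score > best_score:
            best_config, best_score = config, score
    return best_config, best_score
```

Each call to objective is independent of the others, which is what makes random search easy to parallelize. In the Ray example, each trial becomes a remote task instead of an iteration of this loop.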

Problem Setup

Suppose that we want to train a convolutional network, but we aren’t sure how to choose the following hyperparameters:

  • the learning rate
  • the batch size
  • the dropout probability
  • the standard deviation of the distribution from which to initialize the network weights
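
The experiment loop later in this walkthrough calls a generate_hyperparameters() function whose definition is not shown. One possible version for the four hyperparameters above is sketched here. The dictionary keys and every range are illustrative assumptions, not values taken from the example scripts. The learning rate and the initialization scale are sampled on a log scale because plausible values span several orders of magnitude.

```python
import random


def generate_hyperparameters(rng=random):
    """Sample one configuration of the four hyperparameters (illustrative ranges)."""
    return {
        # Sample the exponent uniformly so each order of magnitude is equally likely.
        "learning_rate": 10 ** rng.uniform(-5, -1),
        "batch_size": rng.choice([16, 32, 64, 128, 256]),
        "dropout": rng.uniform(0.0, 0.7),
        # Standard deviation of the weight-initialization distribution, also log-scale.
        "stddev": 10 ** rng.uniform(-3, 0),
    }
```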

Suppose that we’ve defined a remote function train_cnn_and_compute_accuracy, which takes values for these hyperparameters as its input (along with the dataset), trains a convolutional network using those hyperparameters, and returns the accuracy of the trained model on a validation set.

import numpy as np
import ray

@ray.remote
def train_cnn_and_compute_accuracy(hyperparameters,
                                   train_images,
                                   train_labels,
                                   validation_images,
                                   validation_labels):
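  # Construct a deep network, train it, and return the accuracy on the
  # validation data. This placeholder skips training and returns a random
  # accuracy so that the example can run without a real model.
  return np.random.uniform(0, 1)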

Processing results as they become available

If we launch all of the experiments and then call ray.get on the full list of results, we have to wait for every experiment to finish before we can process any of them. Instead, we may want to process results as they become available, either to adaptively choose new experiments to run or simply to monitor how well the experiments are doing. The ray.wait primitive makes this possible: given a list of object IDs, it returns the IDs whose results are ready along with the IDs that are still pending.
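
The same pattern can be reproduced without a Ray cluster using Python's standard concurrent.futures module. There, wait(..., return_when=FIRST_COMPLETED) plays a role similar to ray.wait. This is a local stand-in, not Ray's API. The names fake_experiment and collect_in_completion_order are hypothetical, and sleeping stands in for training.

```python
import time
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait


def fake_experiment(name, seconds):
    """Stand-in for a training run: sleep, then report which run finished."""
    time.sleep(seconds)
    return name


def collect_in_completion_order(durations):
    """Start every run at once, then record each result as soon as it is ready."""
    order = []
    with ThreadPoolExecutor(max_workers=len(durations)) as pool:
        pending = {pool.submit(fake_experiment, name, seconds)
                   for name, seconds in durations.items()}
        while pending:
            # Roughly analogous to ray.wait(pending, num_returns=1): block until
            # at least one run is done (several may be returned at once).
            done, pending = wait(pending, return_when=FIRST_COMPLETED)
            for future in done:
                order.append(future.result())
    return order
```

Calling collect_in_completion_order({"slow": 0.5, "fast": 0.05, "medium": 0.25}) should return ["fast", "medium", "slow"]. Results arrive in the order the runs finish, not the order in which they were submitted.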

The simplest usage is the following. This example is implemented in more detail in driver.py.

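# Assumes ray.init() has been called, the datasets are loaded as NumPy arrays,
# and hyperparameter_configurations and generate_hyperparameters() are defined.

# Store the datasets in the object store once, so every task reuses the same
# copy instead of the arrays being serialized again for each call.
data_ids = [ray.put(train_images), ray.put(train_labels),
            ray.put(validation_images), ray.put(validation_labels)]

# Launch some experiments.
remaining_ids = []
for hyperparameters in hyperparameter_configurations:
  remaining_ids.append(train_cnn_and_compute_accuracy.remote(hyperparameters,
                                                             *data_ids))

# Whenever an experiment finishes, print its accuracy and start a new
# experiment in its place.
for _ in range(100):
  # Block until at least one experiment has finished.
  ready_ids, remaining_ids = ray.wait(remaining_ids, num_returns=1)
  accuracy = ray.get(ready_ids[0])
  print("Accuracy is {}".format(accuracy))
  # Start a new experiment.
  new_hyperparameters = generate_hyperparameters()
  remaining_ids.append(train_cnn_and_compute_accuracy.remote(new_hyperparameters,
                                                             *data_ids))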
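
The loop above stops after 100 iterations while the most recently launched experiments are still running. It also does not keep track of which configuration did best. One way to add a fixed experiment budget and best-so-far tracking is sketched below. To stay runnable without a cluster, it uses concurrent.futures from the Python standard library in place of Ray. The names fake_accuracy and budgeted_random_search and the toy score (which peaks at a dropout of 0.5) are hypothetical.

```python
import random
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait


def fake_accuracy(hyperparameters):
    """Hypothetical stand-in for training: a toy score that peaks at dropout = 0.5."""
    return 1.0 - abs(hyperparameters["dropout"] - 0.5)


def budgeted_random_search(num_parallel, total_experiments, seed=0):
    """Keep up to `num_parallel` experiments running until `total_experiments` finish.

    Returns (best_hyperparameters, best_accuracy, number_finished).
    """
    rng = random.Random(seed)
    best_hyperparameters, best_accuracy = None, float("-inf")
    launched = finished = 0
    with ThreadPoolExecutor(max_workers=num_parallel) as pool:
        pending = {}  # Maps each running future to its hyperparameters.

        def launch():
            nonlocal launched
            # Hypothetical search space: a single dropout probability.
            hp = {"dropout": rng.uniform(0.0, 1.0)}
            pending[pool.submit(fake_accuracy, hp)] = hp
            launched += 1

        for _ in range(min(num_parallel, total_experiments)):
            launch()
        while pending:
            # Wait until at least one running experiment has finished.
            done, _ = wait(pending, return_when=FIRST_COMPLETED)
            for future in done:
                hp = pending.pop(future)
                accuracy = future.result()
                finished += 1
                if accuracy > best_accuracy:
                    best_hyperparameters, best_accuracy = hp, accuracy
                # Start a replacement only while the budget allows, so the
                # loop ends with no experiments still running.
                if launched < total_experiments:
                    launch()
    return best_hyperparameters, best_accuracy, finished
```

Because new experiments are launched only while launched < total_experiments, the while loop drains naturally. Every experiment that was started is also collected before the function returns.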