(2023-11-01 update: refreshed the charts, made other small tweaks, and posted reproducible experiments.)

This post has two parts:

  1. A fun point of view that I find compelling
  2. A project that follows from that point of view, with some early results

Point of view

When you tune an AI model’s design and its training regime, you are exploring a search space. Bayesian Optimization is a framework that naturally arises when you try to automate your intuitive process of exploring a search space. The promise of Bayesian Optimization is: “If I can tell a computer my beliefs about this search space, then the computer can perform the search better than I can.”

From what I’ve seen, AI / ML people tend to feel burned by their experiences with Bayesian Optimization. Often I hear stories where they once invested a bunch of time into trying it on toy projects, but it never managed to cross over and be useful in their actual work, where they continue to use manual search and unsophisticated brute-force searches. Often they feel guilty about all the time they wasted trying Bayesian Optimization.

Here is a slightly contrarian view on why Bayesian Optimization has not yet become ubiquitous.

Lots of people have tried to increase adoption of Bayesian Optimization by making it easier, creating products and libraries that hide the messy details. My hunch is that existing Bayesian Optimization solutions block / discourage the user from taking ownership of the process. In particular, there is a lot of unrealized potential in giving users more powerful ways to state their beliefs about a search space. I think what is needed is a set of tools / documentation / cookbooks designed for people who are willing to put in the time to master their tools. Existing tools could grow to fill this role.

I think Bayesian Optimization ought to be less one-size-fits-all. I think it should be more hands-on.

If you use a one-size-fits-all Bayesian model, your search will be data-intensive and will often be much less efficient than a manual search. Results will be much better if you invest some time telling the model your priors about your parameters. This is more fun than you might expect.

Consider: when you manually explore a search space, you have some intuition about which details of the model are “orthogonal”. For example, hyperparameters describing a neural network’s architecture seem like they are roughly independent of hyperparameters that describe the training regime, and we can afford to optimize these groups of parameters almost independently. Meanwhile, hyperparameters like “momentum” and “learning rate” are intertwined, and it is important to explore their joint space. A one-size-fits-all Bayesian Optimization model totally lacks this intuition. Fortunately, it is easy to input these beliefs to a Gaussian Process (GP). If you have two groups of parameters that are independent, simply dedicate a kernel to each one, then add the kernels together. If they are dependent, multiply the kernels together. If it’s somewhere in between, do both, and allow the model to adapt to the data. If you run your own Bayesian Optimization code locally, this is all trivial to implement, but it is impossible in any Bayesian Optimization products that I’ve seen. (Some products use a fully AutoML approach to infer these dependencies, but this is a data-intensive process.)
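
For example, here is a minimal GPyTorch sketch of these compositions, assuming a hypothetical four-dimensional space where dimensions 0 and 1 describe the architecture and dimensions 2 and 3 describe the training regime:

import gpytorch

def matern(dims):
    # A Matern 5/2 kernel over a subset of the input dimensions.
    return gpytorch.kernels.MaternKernel(
        nu=2.5, ard_num_dims=len(dims), active_dims=dims)

# Independent groups: dedicate a kernel to each group, then add them.
additive = gpytorch.kernels.ScaleKernel(matern((0, 1)) + matern((2, 3)))

# Dependent groups: multiply the kernels.
joint = gpytorch.kernels.ScaleKernel(matern((0, 1)) * matern((2, 3)))

# Somewhere in between: include both terms and let the fitted scale
# parameters decide how much weight each term receives.
blended = (gpytorch.kernels.ScaleKernel(matern((0, 1)) + matern((2, 3)))
           + gpytorch.kernels.ScaleKernel(matern((0, 1)) * matern((2, 3))))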

Another important degree of freedom in owning your search is choosing useful features for the GP kernel. Your raw hyperparameters might not be a good basis for a GP. You might want to configure your GP to transform the hyperparameter vector into a basis that lets the GP make better predictions from fewer data points, perhaps simply by appending a few redundant-but-informative features to the hyperparameter vector. This insight can be combined with the “orthogonal” trick above to greatly reduce the effective size of the search space.
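
As a toy sketch of this idea (the column layout here is hypothetical): given two per-layer log weight decays, you might append their mean and each layer’s offset from that mean, giving the GP “overall decay” and “relative decay” as separate, nearly orthogonal features:

import torch

def add_derived_features(X):
    # Assumed layout: columns 1 and 2 hold two per-layer log weight
    # decays. Append their mean ("overall decay") and each column's
    # offset from that mean ("relative decay"). The new features are
    # redundant, but they give the kernel a better basis to work in.
    log_mean_wd = X[..., 1:3].mean(dim=-1, keepdim=True)
    rel_wd = X[..., 1:3] - log_mean_wd
    return torch.cat([X, log_mean_wd, rel_wd], dim=-1)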

Opportunities also lie in having fine-grained control over the optimization loop. For example, I have found it useful to create a “tick-tock” optimization loop that alternates between multiple objectives (accuracy and training time).
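
Here is a minimal sketch of such a loop; evaluate and propose are stand-ins for the real training run and for fitting a GP and maximizing an acquisition function:

import random

def evaluate(config):
    # Stand-in for a real training run; returns both objectives.
    return {"accuracy": random.random(), "training_time": random.random()}

def propose(history, objective):
    # Stand-in for fitting a GP to `objective` on the observed history
    # and maximizing an acquisition function over the search space.
    return {"lr": random.uniform(1e-4, 1e-1)}

history = []
for i in range(20):
    # Tick-tock: even iterations optimize for accuracy, odd iterations
    # for training time, so neither objective starves the other.
    objective = "accuracy" if i % 2 == 0 else "training_time"
    config = propose(history, objective)
    history.append((config, evaluate(config)))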

The project: Is hands-on Bayesian Optimization worthwhile?

This question has multiple facets, each involving wearing a different hat:

  • Is it usable? What is the right user interface? What is the overall on-ramp / learning process for a person motivated to learn hands-on Bayesian Optimization?
  • Is it fast? I am proposing we use more complicated kernels. Is that tractable, from a performance standpoint?
  • Does it work? Are there considerable benefits in hands-on Bayesian Optimization, relative to a one-size-fits-all approach? Is it much better than random search?

Is it usable?

There is a user interface problem here that needs solving. The first UI to build is a first-class programming interface. Later we can consider whether a graphical UI is needed.

I draw inspiration from D3, a library for data visualization. In a world of people trying to build “easy” visualization tools, D3 took a different approach. D3 doesn’t generate visualizations; it helps you program your own. Rather than exposing some magic function that outputs a customized pre-built visualization, D3 provides a useful set of primitives, then gets out of the way and lets you write code. It takes time to learn, but once you have learned it, you are powerful. It’s a tortoise-and-hare phenomenon: the person who chooses D3 ends up winning the race (and winning with style).

I think BoTorch (with GPyTorch) is a good start at filling this role for Bayesian Optimization, but it’s not finished. It provides the core algorithmic tools of Bayesian Optimization, but it doesn’t provide a way to describe a search space, for example. The BoTorch tutorials essentially recommend that you not use BoTorch directly, but instead use a wrapper like Ax; Ax, however, is more of a one-size-fits-all framework, adding a lot of unnecessary friction to hands-on Bayesian Optimization. I would rather see BoTorch grow outward into a usable standalone library, giving you all the pieces you need to write your own Bayesian Optimization. It could aim to be the D3 of Bayesian Optimization. I found myself needing to “finish” BoTorch with my own makeshift library that adds search spaces with conditional parameters, an expanded set of input transform utilities, and random search utilities.

Is it fast?

I am proposing that we take one-size-fits-all GP kernels and replace them with customized compositions of kernels. I was worried that these compositions would be slow, and, yes, by default they are. But if you write them efficiently, they are very fast. I ended up writing a WeightedSPSMaternKernel, i.e. a “weighted sum of products of sums” Matern kernel, which runs many kernels in parallel. This satisfied my needs.
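
In other words, the kernel has the shape

    k(x1, x2) = sum_i w_i * prod_j ( sum_k k_ijk(x1, x2) )

where each inner k_ijk is a Matern kernel over some slice of the feature vector, and the whole expression is evaluated as one batched computation rather than as a tree of separate kernel calls.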

I’m also proposing that we perform transformations on feature vectors before they reach the kernel. This is already a supported use case and it is fast.
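
For illustration, the derived-feature idea from earlier can be packaged as a custom BoTorch input transform. This is a sketch with an assumed column layout, not my actual implementation:

import torch
from botorch.models.transforms.input import InputTransform

class AddDerivedFeatures(InputTransform, torch.nn.Module):
    # Hypothetical transform appending derived columns to each input.
    def __init__(self):
        super().__init__()
        self.transform_on_train = True
        self.transform_on_eval = True
        self.transform_on_fantasize = True

    def transform(self, X):
        mean = X[..., 1:3].mean(dim=-1, keepdim=True)
        return torch.cat([X, mean, X[..., 1:3] - mean], dim=-1)

# Passed to a model via, e.g.,
# SingleTaskGP(train_X, train_Y, input_transform=AddDerivedFeatures())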

Optional reading: in case you want details, here is my code describing a 22-dimensional hyperparameter space and how it is transformed, first into a search space (for random search / optimization) and then into a feature space (for the GP). Vexpr and outerloop are libraries I created: Vexpr makes it possible to write fast, readable compositional GP kernels, and outerloop attempts to “finish” BoTorch.

from functools import partial

import botorch
import gpytorch
import outerloop as ol
import outerloop.vexpr.torch as ovt
import torch
import vexpr as vp
import vexpr.torch as vtorch
import vexpr.custom.torch as vctorch

parameter_space = [
    ol.Choice("optimizer", ["adam", "sgd"]),
    ol.Choice("nesterov", [True, False],
              condition=lambda choices: choices["optimizer"] == "sgd"),
    ol.Int("epochs", 2, 60),
    ol.Int("batch_size", 16, 4096),

    ol.Scalar("conv1_weight_decay", 1e-7, 3e-1),
    ol.Scalar("conv2_weight_decay", 1e-7, 3e-1),
    ol.Scalar("conv3_weight_decay", 1e-7, 3e-1),
    ol.Scalar("dense1_weight_decay", 1e-7, 3e-1),
    ol.Scalar("dense2_weight_decay", 1e-7, 3e-1),

    ol.Scalar("1cycle_initial_lr_pct", 1/80, 1/2),
    ol.Scalar("1cycle_final_lr_pct", 1/30000, 1/100),
    ol.Scalar("1cycle_pct_warmup", 0.01, 0.5),
    ol.Scalar("1cycle_max_lr", 0.01, 20.0),
    ol.Scalar("1cycle_max_momentum", 0, 0.9999,
              condition=lambda choices: choices["optimizer"] == "sgd"),
    ol.Scalar("1cycle_min_momentum_pct", 0.0, 1.0,
              condition=lambda choices: choices["optimizer"] == "sgd"),

    ol.Int("conv1_channels", 4, 64),
    ol.Int("conv2_channels", 8, 128),
    ol.Int("conv3_channels", 16, 256),
    ol.Int("dense1_units", 8, 256),
]

# Transform to log-space. (Some names below, e.g. the beta1 / beta2
# damping factors, come from parameters and transforms not shown in
# this excerpt.)
xform = ol.transforms.ToScalarSpace(
    parameter_space,
    ol.transforms.log({
        "epochs": "log_epochs",
        "batch_size": "log_batch_size",
        "1cycle_initial_lr_pct": "log_1cycle_initial_lr_pct",
        "1cycle_final_lr_pct": "log_1cycle_final_lr_pct",
        "1cycle_max_lr": "log_1cycle_max_lr",
        "1cycle_pct_warmup": "log_1cycle_pct_warmup",
        "1cycle_momentum_max_damping_factor":
        "log_1cycle_momentum_max_damping_factor",
        "1cycle_momentum_min_damping_factor_pct":
        "log_1cycle_momentum_min_damping_factor_pct",
        "1cycle_beta1_max_damping_factor":
        "log_1cycle_beta1_max_damping_factor",
        "1cycle_beta1_min_damping_factor_pct":
        "log_1cycle_beta1_min_damping_factor_pct",
        "beta2_damping_factor": "log_beta2_damping_factor",
        "conv1_weight_decay": "log_conv1_weight_decay",
        "conv2_weight_decay": "log_conv2_weight_decay",
        "conv3_weight_decay": "log_conv3_weight_decay",
        "dense1_weight_decay": "log_dense1_weight_decay",
        "dense2_weight_decay": "log_dense2_weight_decay",
        "conv1_channels": "log_conv1_channels",
        "conv2_channels": "log_conv2_channels",
        "conv3_channels": "log_conv3_channels",
        "dense1_units": "log_dense1_units",
    })
)


class VexprHandsOnLossModel(botorch.models.SingleTaskGP):
    def __init__(self, ...):

        # Transforms for the kernel

        xforms += [
            ol.transforms.append_mean(
                ["log_conv1_channels", "log_conv2_channels",
                 "log_conv3_channels", "log_dense1_units"],
                "log_gmean_channels_and_units"),
            ol.transforms.subtract(
                {"log_conv1_channels": "log_conv1_channels_div_gmean",
                 "log_conv2_channels": "log_conv2_channels_div_gmean",
                 "log_conv3_channels": "log_conv3_channels_div_gmean",
                 "log_dense1_units": "log_dense1_units_div_gmean"},
                "log_gmean_channels_and_units"),
            ol.transforms.add(
                {"log_1cycle_initial_lr_pct": "log_1cycle_initial_lr",
                 "log_1cycle_final_lr_pct": "log_1cycle_final_lr"},
                "log_1cycle_max_lr"),
            ol.transforms.add(
                {"log_1cycle_momentum_min_damping_factor_pct":
                 "log_1cycle_momentum_min_damping_factor"},
                "log_1cycle_momentum_max_damping_factor"),
            ol.transforms.add(
                {"log_1cycle_beta1_min_damping_factor_pct":
                 "log_1cycle_beta1_min_damping_factor"},
                "log_1cycle_beta1_max_damping_factor"),
            ol.transforms.append_mean(
                ["log_conv1_weight_decay", "log_conv2_weight_decay",
                 "log_conv3_weight_decay", "log_dense1_weight_decay",
                 "log_dense2_weight_decay"],
                "log_gmean_weight_decay"),
            ol.transforms.subtract(
                {"log_conv1_weight_decay": "log_conv1_wd_div_gmean",
                 "log_conv2_weight_decay": "log_conv2_wd_div_gmean",
                 "log_conv3_weight_decay": "log_conv3_wd_div_gmean",
                 "log_dense1_weight_decay": "log_dense1_wd_div_gmean",
                 "log_dense2_weight_decay": "log_dense2_wd_div_gmean"},
                "log_gmean_weight_decay"),
            partial(ol.transforms.ChoiceNHotProjection,
                    out_name=N_HOT_PREFIX)
        ]

        # ...


def make_handson_kernel(space, batch_shape=()):
    """
    This kernel attempts to group parameters into orthogonal groups, while
    also always allowing for the model to learn to use the joint space.
    """
    zero_one_exclusive = partial(gpytorch.constraints.Interval,
                                 1e-6,
                                 1 - 1e-6)

    state = State(batch_shape)

    ialloc = IndexAllocator()
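    # (Assumption from how it is used below: IndexAllocator hands out
    # indices into the single shared `lengthscale` vector, which is
    # allocated near the end of this function once the total count is
    # known.)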

    lengthscale = vp.symbol("lengthscale")
    x1 = vp.symbol("x1")
    x2 = vp.symbol("x2")

    def index_for_name(name):
        return next(i for i, p in enumerate(space) if p.name == name)

    def scalar_kernel(names):
        ls_indices = ialloc.allocate(len(names))
        indices = [index_for_name(name) for name in names]
        return ovt.matern(
            vtorch.cdist(x1[..., indices] / lengthscale[ls_indices],
                         x2[..., indices] / lengthscale[ls_indices],
                         p=2),
            nu=2.5)

    def choice_kernel(names):
        ls_indices = ialloc.allocate(len(names))
        indices = [index_for_name(name) for name in names]
        return ovt.matern(
            vtorch.cdist(x1[..., indices] / lengthscale[ls_indices],
                         x2[..., indices] / lengthscale[ls_indices],
                         p=1),
            nu=2.5)


    def scalar_factorized_and_joint(names, suffix):
        w_additive = vp.symbol("w_additive" + suffix)
        alpha_factorized_or_joint = vp.symbol("alpha_factorized_or_joint"
                                              + suffix)
        state.allocate(w_additive, (len(names),),
                       zero_one_exclusive(),
                       ol.priors.DirichletPrior(torch.full((len(names),), 2.0)))
        state.allocate(alpha_factorized_or_joint, (),
                       zero_one_exclusive(),
                       ol.priors.BetaPrior(4.0, 1.0))
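        # Blend a factorized kernel (a weighted sum of per-parameter
        # kernels) with a joint kernel over the whole group.
        # heads_tails(alpha) evidently expands alpha into the weight
        # pair (alpha, 1 - alpha), making this a convex combination.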
        return vtorch.sum(
            vctorch.heads_tails(alpha_factorized_or_joint)
            * vtorch.stack([
                vtorch.sum(
                    w_additive
                    * vtorch.stack([scalar_kernel([name])
                                    for name in names],
                                   dim=-1),
                    dim=-1),
                scalar_kernel(names),
            ], dim=-1),
            dim=-1
        )

    def regime_kernels():
        return [
            # kernel: regime choice parameters
            choice_kernel([f"{N_HOT_PREFIX}{i}"
                           for i in range(4)]),

            # kernel: lr schedule
            scalar_factorized_and_joint(
                ["log_1cycle_initial_lr", "log_1cycle_final_lr",
                 "log_1cycle_max_lr", "log_1cycle_pct_warmup"],
                "_lr"),

            # kernel: momentum schedule
            scalar_factorized_and_joint(
                ["log_1cycle_momentum_max_damping_factor",
                 "log_1cycle_momentum_min_damping_factor",
                 "log_1cycle_beta1_max_damping_factor",
                 "log_1cycle_beta1_min_damping_factor",
                 "log_beta2_damping_factor"],
                "_momentum"),

            # kernel: relative weight decay
            scalar_factorized_and_joint(
                ["log_conv1_wd_div_gmean", "log_conv2_wd_div_gmean",
                 "log_conv3_wd_div_gmean", "log_dense1_wd_div_gmean",
                 "log_dense2_wd_div_gmean"],
                "_wd"),
        ]

    regime_joint_names = ["log_epochs", "log_batch_size",
                          "log_gmean_weight_decay"]

    def architecture_kernels():
        return [
            # kernel: relative channels and units
            scalar_factorized_and_joint(["log_conv1_channels_div_gmean",
                                         "log_conv2_channels_div_gmean",
                                         "log_conv3_channels_div_gmean",
                                         "log_dense1_units_div_gmean"],
                                        "_units_channels"),
        ]

    architecture_joint_names = ["log_gmean_channels_and_units"]

    regime_kernel = vctorch.fast_prod_positive(
        vtorch.stack(([scalar_kernel(regime_joint_names)]
                      + regime_kernels()),
                     dim=-1),
        dim=-1)
    architecture_kernel = vctorch.fast_prod_positive(
        vtorch.stack(([scalar_kernel(architecture_joint_names)]
                      + architecture_kernels()),
                     dim=-1),
        dim=-1)
    joint_kernel = vctorch.fast_prod_positive(
        vtorch.stack(([scalar_kernel(regime_joint_names
                                     + architecture_joint_names)]
                      + regime_kernels()
                      + architecture_kernels()),
                     dim=-1),
        dim=-1)

    alpha_regime_vs_architecture = vp.symbol("alpha_regime_vs_architecture")
    alpha_factorized_vs_joint = vp.symbol("alpha_factorized_vs_joint")
    scale = vp.symbol("scale")

    state.allocate(alpha_regime_vs_architecture, (),
                   zero_one_exclusive(),
                   ol.priors.BetaPrior(2.0, 2.0))
    state.allocate(alpha_factorized_vs_joint, (),
                   zero_one_exclusive(),
                   ol.priors.BetaPrior(4.0, 1.0))
    state.allocate(scale, (),
                   gpytorch.constraints.GreaterThan(1e-4),
                   gpytorch.priors.GammaPrior(2.0, 0.15))

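    # Top-level composition: a convex blend of {a convex blend of the
    # regime kernel and the architecture kernel} with the full joint
    # kernel, multiplied by an overall scale.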
    kernel = (scale
              * vtorch.sum(vctorch.heads_tails(alpha_factorized_vs_joint)
                           * vtorch.stack(
                               [vtorch.sum(
                                   vctorch.heads_tails(
                                       alpha_regime_vs_architecture)
                                   * vtorch.stack([
                                       regime_kernel,
                                       architecture_kernel
                                   ], dim=-1),
                                   dim=-1),
                                joint_kernel],
                               dim=-1),
                           dim=-1))

    state.allocate(lengthscale, (ialloc.count,),
                   gpytorch.constraints.GreaterThan(1e-4),
                   gpytorch.priors.GammaPrior(3.0, 6.0))

    return kernel, state.modules

You can read the rest of the code here.

Does it work?

My initial results are promising, but my baseline (the default BoTorch model) has some issues that still need to be ironed out.

I ran 160 semi-random MNIST LeNet training experiments over a 22-dimensional hyperparameter space. I fed the results to a Gaussian Process and checked its predictive power on held-out experiments, performing leave-one-out cross-validation.
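
To make the procedure concrete, here is a sketch of the evaluation loop, where fit_model is a stand-in for constructing and fitting either model, returning something that exposes BoTorch’s posterior(X) API:

import torch

def leave_one_out_metrics(X, y, fit_model):
    # X: an (n, d) feature matrix; y: an (n,) vector of results.
    log_probs, abs_errors = [], []
    for i in range(len(X)):
        mask = torch.arange(len(X)) != i
        model = fit_model(X[mask], y[mask])
        posterior = model.posterior(X[i:i + 1])
        mean = posterior.mean.reshape(())
        std = posterior.variance.reshape(()).sqrt()
        dist = torch.distributions.Normal(mean, std)
        log_probs.append(dist.log_prob(y[i]).item())   # probability metric
        abs_errors.append((mean - y[i]).abs().item())  # point-prediction metric
    return log_probs, abs_errors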

For a baseline one-size-fits-all model, I was hoping to use BoTorch’s MixedSingleTaskGP off the shelf, but it does not place priors on its parameters, so it is prone to overfitting and to assigning extremely low log probabilities to held-out results. I added priors from one of their other models, and this partially solved the issue.
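
(For concreteness, attaching priors in GPyTorch looks like the following sketch. The Gamma parameters shown are the defaults BoTorch has used in its SingleTaskGP, which is the sort of thing I grafted onto the baseline.)

import gpytorch

# Priors like BoTorch's SingleTaskGP defaults: a lengthscale prior and
# an outputscale prior on the kernel, and a noise prior on the
# likelihood, all discouraging degenerate fits on small datasets.
kernel = gpytorch.kernels.ScaleKernel(
    gpytorch.kernels.MaternKernel(
        nu=2.5,
        ard_num_dims=22,
        lengthscale_prior=gpytorch.priors.GammaPrior(3.0, 6.0)),
    outputscale_prior=gpytorch.priors.GammaPrior(2.0, 0.15))

likelihood = gpytorch.likelihoods.GaussianLikelihood(
    noise_prior=gpytorch.priors.GammaPrior(1.1, 0.05))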

For the hands-on model, I designed a kernel off the top of my head, trying to capture some of my intuition about my hyperparameters. I did this without any iteration: I guessed at a kernel that might work, chose priors that seemed sensible, and here I am sharing my first results.

As a sanity check, here is the raw data from cross-validating 60 of the 160 experiments. To generate each point, the model is trained on 59 configurations and their results, then must predict the 60th result, given its configuration.

Here are the results in aggregate, using two different success metrics, plotted for different subsets of the 160 experiments. I gathered results on 50 different shuffles of the data and plot the mean.

On average, for small-to-medium datasets, the “hands-on” model assigned higher probability density to the “correct” output for the held-out data. (It is important to use the geometric mean, i.e. to take the mean of the log probability, so that very low probabilities are properly penalized.) When the model is used in MLE mode (maximum likelihood estimate) to output a single prediction, the hands-on model has lower prediction error on average.

Toward the end, the hands-on model falls behind the baseline. The MLE predictions aren’t obviously worse, but the mean log probability is, suggesting that the hands-on model gives more confident predictions than it ought to. So there is room for improvement. Still, it is promising that my first try was this good. I suspect a little iteration on the kernel would lead to one that is superior for all dataset sizes.

Early in the chart, the baseline displays a strange phenomenon: it fails to improve as it receives more data. It is worth doing some due diligence here and trying to understand what is going on. Maybe the baseline has a low-hanging improvement available, and maybe that improvement could also be brought to the hands-on model.

My conclusions from this experiment are:

  • The initial results are promising, but not conclusive.
  • The weird baseline phenomenon illustrates my point that you should be careful about treating Bayesian Optimization as a black box. I suspect Bayesian Optimization is only worthwhile if we figure out how to get users to take ownership of their search.

(This project is supported by a GCP cloud compute grant from ML Collective, which has been super helpful.)