import numpy as np
import matplotlib.pyplot as plt
from ax.service.ax_client import AxClient
from ax.modelbridge import (
Models,
)
from ax.modelbridge.generation_strategy import (
GenerationStep,
GenerationStrategy,
)
def sigmoid(x):
    """Return the logistic function 1 / (1 + exp(-x)), elementwise for arrays."""
    denominator = 1 + np.exp(-x)
    return 1 / denominator
%config InlineBackend.figure_formats = ['svg']
def _make_ax_client(second_parameter_name, second_parameter_bounds):
    """Build an AxClient whose experiment searches over total epochs plus one more integer.

    Both experiments share the GP-EI-only generation strategy, the
    ``num_epochs`` range [0, 26], and a maximized ``accuracy`` objective;
    they differ only in the second integer range parameter.
    """
    strategy = GenerationStrategy([GenerationStep(model=Models.GPEI, num_trials=-1)])
    client = AxClient(generation_strategy=strategy, verbose_logging=False)
    num_epochs_parameter = {
        "name": "num_epochs",
        "type": "range",
        "bounds": [0, 26],
        "value_type": "int",
    }
    second_parameter = {
        "name": second_parameter_name,
        "type": "range",
        "bounds": second_parameter_bounds,
        "value_type": "int",
    }
    client.create_experiment(
        name="step_size_epochs_experiment",
        parameters=[num_epochs_parameter, second_parameter],
        minimize=False,
        objective_name="accuracy",
    )
    return client


# Parameterization 1: (total epochs, epochs at the high learning rate).
ax_client1 = _make_ax_client("num_high_lr_epochs", [0, 26])
# Parameterization 2: (total epochs, epochs at the low learning rate).
ax_client2 = _make_ax_client("num_low_lr_epochs", [0, 2])
[INFO 06-26 10:22:02] ax.service.utils.instantiation: Created search space: SearchSpace(parameters=[RangeParameter(name='num_epochs', parameter_type=INT, range=[0, 26]), RangeParameter(name='num_high_lr_epochs', parameter_type=INT, range=[0, 26])], parameter_constraints=[]). [INFO 06-26 10:22:02] ax.service.utils.instantiation: Created search space: SearchSpace(parameters=[RangeParameter(name='num_epochs', parameter_type=INT, range=[0, 26]), RangeParameter(name='num_low_lr_epochs', parameter_type=INT, range=[0, 2])], parameter_constraints=[]).
# Accuracy grids, NaN wherever no value is defined.
#   f / f_observed: rows = # epochs at high learning rate, columns = total # epochs.
#   g / g_observed: rows = # epochs at low learning rate,  columns = total # epochs.
# The *_observed variants only receive the points treated as observed.
f, f_observed, g, g_observed = (
    np.full((26, 26), np.nan, dtype=float) for _ in range(4)
)
# Slope applied to the total epoch count inside the synthetic sigmoid accuracy curve.
k = 0.125
# Total-epoch counts whose grid points are treated as observed and reported to Ax.
observed_num_epochs = np.array([6, 12, 24])
# When True, raw_data reports each accuracy together with a zero standard error.
EXACT = False
def raw_data(acc, exact=None):
    """Wrap an accuracy value in the ``raw_data`` mapping passed to ``complete_trial``.

    Parameters
    ----------
    acc : float
        Accuracy to report for the ``accuracy`` metric.
    exact : bool or None, optional
        If true, report ``(acc, 0.0)`` (mean with zero standard error); if
        false, report the bare mean, which per the Ax ``complete_trial`` docs
        leaves the noise level for Ax to infer. ``None`` (the default) falls
        back to the module-level ``EXACT`` flag, preserving the old behavior.

    Returns
    -------
    dict
        ``{"accuracy": (acc, 0.0)}`` or ``{"accuracy": acc}``.
    """
    if exact is None:
        exact = EXACT
    if exact:
        return {"accuracy": (acc, 0.0)}
    return {"accuracy": acc}
def _record_point(num_epochs, num_low_lr_epochs, acc):
    """Store one ground-truth accuracy and, if its total is observed, report it to Ax.

    The value is written into both parameterizations: ``f`` is indexed by
    epochs at the high learning rate, ``g`` by epochs at the low learning
    rate (column = total epochs in both). Only totals listed in
    ``observed_num_epochs`` are copied into the ``*_observed`` grids and
    attached as completed trials to both Ax clients.
    """
    num_high_lr_epochs = num_epochs - num_low_lr_epochs
    f[num_high_lr_epochs, num_epochs] = acc
    g[num_low_lr_epochs, num_epochs] = acc
    if num_epochs not in observed_num_epochs:
        return
    f_observed[num_high_lr_epochs, num_epochs] = acc
    g_observed[num_low_lr_epochs, num_epochs] = acc
    # Report the same measurement to both experiments, each in its own parameterization.
    for client, config in (
        (ax_client1, dict(num_epochs=num_epochs, num_high_lr_epochs=num_high_lr_epochs)),
        (ax_client2, dict(num_epochs=num_epochs, num_low_lr_epochs=num_low_lr_epochs)),
    ):
        _, trial_index = client.attach_trial(config)
        client.complete_trial(trial_index=trial_index, raw_data=raw_data(acc))


# Synthetic accuracy model: max_acc grows with the total number of epochs;
# ending with exactly one low-LR epoch reaches it, zero or two reach only half.
# Entries are (# epochs at low learning rate, fraction of max_acc), in the
# order the points are recorded.
_LOW_LR_ACCURACY_FRACTIONS = ((2, 0.5), (1, 1.0), (0, 0.5))

for num_epochs in range(1, 26):
    max_acc = (sigmoid(num_epochs * k) - 0.5) * 2
    for num_low_lr_epochs, fraction in _LOW_LR_ACCURACY_FRACTIONS:
        if num_low_lr_epochs > num_epochs:
            continue  # cannot spend more epochs at the low LR than in total
        _record_point(num_epochs, num_low_lr_epochs, max_acc * fraction)
# Ask each client for one candidate and discard it immediately: this forces Ax
# to build its model from the attached trials before predictions are queried.
for client in (ax_client1, ax_client2):
    _, trial_index = client.get_next_trial()
    client.abandon_trial(trial_index)
# Ground-truth accuracy under both parameterizations, on a shared 0..1 color scale.
fig, ax = plt.subplots(1, 2, figsize=(8, 4), gridspec_kw={"width_ratios": [3, 1]})
high_lr_axis, low_lr_axis = ax

high_lr_axis.imshow(f.T, origin="lower", cmap="Reds", vmin=0, vmax=1)
high_lr_axis.set_title("Parameterization 1")
high_lr_axis.set_xlabel("# of epochs at high learning rate")
high_lr_axis.set_ylabel("total # of epochs")
high_lr_axis.set_xlim(-0.5, 25.5)
high_lr_axis.set_ylim(-0.5, 25.5)

pos = low_lr_axis.imshow(g.T, origin="lower", cmap="Reds", vmin=0, vmax=1)
low_lr_axis.set_title("Parameterization 2")
low_lr_axis.set_xlabel("# of epochs at low learning rate")
low_lr_axis.set_ylabel("total # of epochs")
# Ticks first, then the explicit x-limits, so the final view is -0.5..2.5.
low_lr_axis.set_xticks([0, 1, 2])
low_lr_axis.set_ylim(-0.5, 25.5)
low_lr_axis.set_xlim(-0.5, 2.5)

# A single colorbar attached to both axes (both images use vmin=0, vmax=1).
clb = fig.colorbar(pos, shrink=0.5, label="Accuracy", ax=ax, location="right")
fig.savefig("two-parameterizations.svg")
plt.show()
def _plot_observed_panel(observed, xlabel, xlim, xticks=None):
    """Draw the observed-points heatmap in the left half of the current figure."""
    plt.subplot(1, 2, 1)
    plt.imshow(observed.T, origin="lower", cmap="Reds", vmin=0, vmax=1)
    plt.ylabel("total # of epochs")
    plt.xlabel(xlabel)
    if xticks is not None:
        plt.xticks(xticks)
    plt.ylim(-0.5, 25.5)
    plt.xlim(*xlim)
    plt.title("Observed function")


def _plot_prediction_panel(position, client, num_low_lr_epochs, params_for, truth, legend=False):
    """Plot model predictions vs. ground truth for a fixed number of low-LR epochs.

    Draws into ``plt.subplot(3, 2, position)`` of the current figure.

    Parameters
    ----------
    position : int
        Subplot index in a 3x2 grid.
    client : AxClient
        Client whose model is queried with ``get_model_predictions``.
    num_low_lr_epochs : int
        Epochs at the low learning rate held fixed in this panel. Totals start
        at ``max(num_low_lr_epochs, 1)`` so every plotted configuration is feasible.
    params_for : callable
        Maps a total epoch count (int) to the parameterization dict for ``client``.
    truth : callable
        Maps an integer array of total epoch counts to the true accuracies.
    legend : bool
        Whether to draw a legend on this panel.
    """
    all_num_epochs = np.arange(max(num_low_lr_epochs, 1), 26)
    # Use plain ints as keys/values and read results back by key, so the
    # plotted order follows all_num_epochs regardless of the returned dict's order.
    totals = [int(n) for n in all_num_epochs]
    predictions = client.get_model_predictions(
        parameterizations={n: params_for(n) for n in totals}
    )
    means, errors = zip(*(predictions[n]["accuracy"] for n in totals))

    plt.subplot(3, 2, position)
    plt.ylim(0.0, 1.0)
    plt.errorbar(all_num_epochs, means, yerr=errors, fmt=".", zorder=5, label="predictions")
    plt.plot(all_num_epochs, truth(all_num_epochs), "--", color="C1", label="actual function")
    plt.plot(observed_num_epochs, truth(observed_num_epochs), "o", color="C1", zorder=10,
             label="observations")
    plt.xlabel("total # of epochs")
    plt.ylabel("accuracy")
    unit = "epoch" if num_low_lr_epochs == 1 else "epochs"
    plt.title(f"Predictions: {num_low_lr_epochs} {unit} at low learning rate")
    if legend:
        plt.legend()


# (subplot position, # of epochs at low learning rate) for the three prediction panels.
_PREDICTION_PANELS = ((2, 0), (4, 1), (6, 2))

# Parameterization 1: fixing the low-LR epoch count walks a diagonal of the grid.
plt.figure(figsize=(8, 6))
_plot_observed_panel(f_observed, "# of epochs at high learning rate", (-0.5, 25.5))
for position, low in _PREDICTION_PANELS:
    _plot_prediction_panel(
        position,
        ax_client1,
        low,
        params_for=lambda n, low=low: dict(num_epochs=n, num_high_lr_epochs=n - low),
        truth=lambda ns, low=low: f[ns - low, ns],
        legend=(position == 2),
    )
plt.tight_layout()
plt.savefig("diagonal-predictions.svg")
plt.show()

# Parameterization 2: fixing the low-LR epoch count walks a column of the grid.
plt.figure(figsize=(8, 6))
_plot_observed_panel(g_observed, "# of epochs at low learning rate", (-0.5, 2.5), xticks=[0, 1, 2])
for position, low in _PREDICTION_PANELS:
    _plot_prediction_panel(
        position,
        ax_client2,
        low,
        params_for=lambda n, low=low: dict(num_epochs=n, num_low_lr_epochs=low),
        truth=lambda ns, low=low: g[low, ns],
        legend=(position == 2),
    )
plt.tight_layout()
plt.savefig("vertical-predictions.svg")
plt.show()
# Observed points only, under both parameterizations, on the same 0..1 color scale.
fig, ax = plt.subplots(1, 2, figsize=(8, 4), gridspec_kw={"width_ratios": [3, 1]})
high_lr_axis, low_lr_axis = ax

high_lr_axis.imshow(f_observed.T, origin="lower", cmap="Reds", vmin=0, vmax=1)
high_lr_axis.set_title("Parameterization 1")
high_lr_axis.set_xlabel("# of epochs at high learning rate")
high_lr_axis.set_ylabel("total # of epochs")
high_lr_axis.set_xlim(-0.5, 25.5)
high_lr_axis.set_ylim(-0.5, 25.5)

pos = low_lr_axis.imshow(g_observed.T, origin="lower", cmap="Reds", vmin=0, vmax=1)
low_lr_axis.set_title("Parameterization 2")
low_lr_axis.set_xlabel("# of epochs at low learning rate")
low_lr_axis.set_ylabel("total # of epochs")
# Ticks first, then the explicit x-limits, so the final view is -0.5..2.5.
low_lr_axis.set_xticks([0, 1, 2])
low_lr_axis.set_ylim(-0.5, 25.5)
low_lr_axis.set_xlim(-0.5, 2.5)

fig.savefig("two-parameterizations-observed.svg")
plt.show()