add tests for vbll model
brunzema committed Feb 21, 2025
1 parent 1b0f899 commit 2606ac3
Showing 2 changed files with 225 additions and 15 deletions.
54 changes: 39 additions & 15 deletions botorch_community/models/vblls.py
@@ -87,7 +87,7 @@ class VBLLNetwork(nn.Module):
     def __init__(
         self,
         in_features: int = 2,
-        hidden_features: int = 50,
+        hidden_features: int = 64,
         out_features: int = 1,
         num_layers: int = 3,
         prior_scale: float = 1.0,
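For context, a minimal construction sketch using the signature above: the default now yields 64-unit hidden layers, and other widths can still be requested explicitly (the argument values here are illustrative, not from the diff):

from botorch_community.models.vblls import VBLLNetwork

# Uses the new default of 64 hidden features per layer.
net = VBLLNetwork(in_features=2, out_features=1, num_layers=3)

# Explicit override of the hidden width.
wide_net = VBLLNetwork(in_features=2, hidden_features=128)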
@@ -206,6 +206,10 @@ def num_inputs(self):
     def device(self):
         return self.model.device
 
+    @property
+    def backbone(self):
+        return self.model.backbone
+
     def fit(
         self,
         train_X: Tensor,
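The new backbone property exposes the feature extractor of the wrapped network. A minimal sketch of using it to verify that a frozen backbone is left untouched by fit(), mirroring test_freezing_backbone below (the optimization-settings keys are the ones the tests exercise; unlisted keys presumably fall back to defaults):

import copy

import torch
from botorch_community.models.vblls import VBLLModel

model = VBLLModel(in_features=4, hidden_features=3, out_features=1, num_layers=1)
before = copy.deepcopy(model.backbone.state_dict())

model.fit(
    torch.randn(10, 4),
    torch.randn(10, 1),
    optimization_settings={"num_epochs": 5, "lr": 0.1, "freeze_backbone": True},
)

# With freeze_backbone=True, the backbone weights should be unchanged.
after = model.backbone.state_dict()
assert all(torch.allclose(before[k], after[k]) for k in before)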
@@ -308,12 +312,13 @@ def fit(
         early_stop = False
         best_model_state = None  # To store the best model parameters
 
+        self.model.train()
+
         for epoch in range(1, optimization_settings["num_epochs"] + 1):
             # early stopping
             if early_stop:
                 break
 
-            self.model.train()
             running_loss = []
 
             for train_step, (x, y) in enumerate(dataloader):
@@ -323,9 +328,12 @@
                 loss = out.train_loss_fn(y)  # vbll layer will calculate the loss
 
                 loss.backward()
-                torch.nn.utils.clip_grad_norm_(
-                    self.model.parameters(), optimization_settings["clip_val"]
-                )
+
+                if optimization_settings["clip_val"] is not None:
+                    torch.nn.utils.clip_grad_norm_(
+                        self.model.parameters(), optimization_settings["clip_val"]
+                    )
+
                 optimizer.step()
                 running_loss.append(loss.item())
 
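Gradient clipping is now opt-out: a clip_val of None skips the clip_grad_norm_ call entirely. A hedged usage sketch (same settings dictionary as passed to fit; the exact default values live elsewhere in the file):

import torch
from botorch_community.models.vblls import VBLLModel

model = VBLLModel(in_features=2, hidden_features=8, out_features=1, num_layers=2)
X, y = torch.randn(10, 2), torch.randn(10, 1)

# Clip gradient norms at 1.0 during fitting ...
model.fit(X, y, optimization_settings={"num_epochs": 5, "lr": 0.01, "clip_val": 1.0})

# ... or pass None to disable clipping entirely (the new behavior).
model.fit(X, y, optimization_settings={"num_epochs": 5, "lr": 0.01, "clip_val": None})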
@@ -355,26 +363,42 @@ def posterior(
         self,
         X: Tensor,
         output_indices=None,
-        observation_noise=False,
+        observation_noise=None,
         posterior_transform=None,
     ) -> Posterior:
-        if len(X.shape) < 3:
-            B, D = X.shape
-            Q = 1
+        # Determine if the input is batched
+        batched = X.dim() == 3
+
+        if not batched:
+            N, D = X.shape
+            B = 1
         else:
-            B, Q, D = X.shape
-            X = X.reshape(B * Q, D)
+            B, N, D = X.shape
+            X = X.reshape(B * N, D)
 
         posterior = self.model(X).predictive
 
         # Extract mean and variance
-        mean = posterior.mean.squeeze(-1)
-        variance = posterior.variance.squeeze(-1)
+        mean = posterior.mean.squeeze()
+        variance = posterior.variance.squeeze()
         cov = torch.diag_embed(variance)
 
+        K = self.num_outputs
+        mean = mean.reshape(B, N * K)
+
+        # Cov must be `(B, N*K, N*K)`
+        cov = cov.reshape(B, N, K, B, N, K)
+        cov = torch.einsum("bqkbrl->bqkrl", cov)  # (B, N, K, N, K)
+        cov = cov.reshape(B, N * K, N * K)
+
+        # Remove fake batch dimension if not batched
+        if not batched:
+            mean = mean.squeeze(0)
+            cov = cov.squeeze(0)
+
         # pass as MultivariateNormal to GPyTorchPosterior
-        dist = MultivariateNormal(mean, cov)
-        post_pred = GPyTorchPosterior(dist)
+        mvn_dist = MultivariateNormal(mean, cov)
+        post_pred = GPyTorchPosterior(mvn_dist)
         return BLLPosterior(post_pred, self, X, self.num_outputs)
 
     @abstractmethod
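To make the covariance bookkeeping in posterior concrete, here is a self-contained sketch of the reshape-plus-einsum step for B=2 batches, N=3 query points, and K=1 outputs (plain tensors, no model involved):

import torch

B, N, K = 2, 3, 1
mean = torch.rand(B * N * K)      # flattened predictive means
variance = torch.rand(B * N * K)  # flattened predictive variances

cov = torch.diag_embed(variance)          # (B*N*K, B*N*K), diagonal
cov = cov.reshape(B, N, K, B, N, K)
cov = torch.einsum("bqkbrl->bqkrl", cov)  # keep only the block-diagonal over the batch dim
cov = cov.reshape(B, N * K, N * K)

mean = mean.reshape(B, N * K)
print(mean.shape, cov.shape)  # torch.Size([2, 3]) torch.Size([2, 3, 3])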
186 changes: 186 additions & 0 deletions test_community/models/test_vblls.py
@@ -0,0 +1,186 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import copy

import torch
from botorch.utils.testing import BotorchTestCase
from botorch_community.models.vblls import VBLLModel
from botorch_community.posteriors.bll_posterior import BLLPosterior


def _reg_data_singletask(d, n=10):
    X = torch.randn(n, d)
    y = torch.randn(n, 1)
    return X, y


def _get_fast_training_settings():
return {
"num_epochs": 3,
"lr": 0.01,
}


class TestVBLLModel(BotorchTestCase):
def test_initialization(self) -> None:
d, num_hidden, num_outputs, num_layers = 2, 3, 1, 4
model = VBLLModel(
in_features=d,
hidden_features=num_hidden,
num_layers=num_layers,
out_features=num_outputs,
)
self.assertEqual(model.num_inputs, d)
self.assertEqual(model.num_outputs, num_outputs)

hidden_layer_count = sum(
isinstance(layer, torch.nn.Linear)
for submodule in model.backbone[1:] # note that the first layer is excluded
for layer in (
submodule if isinstance(submodule, torch.nn.Sequential) else [submodule]
)
)
self.assertEqual(
hidden_layer_count,
num_layers,
f"Expected {num_layers} hidden layers, but got {hidden_layer_count}.",
)

def test_backbone_initialization(self) -> None:
d, num_hidden = 4, 3
test_backbone = torch.nn.Sequential(
torch.nn.Linear(d, num_hidden),
torch.nn.ReLU(),
torch.nn.Linear(num_hidden, num_hidden),
)
model = VBLLModel(backbone=test_backbone, hidden_features=num_hidden)

for key in test_backbone.state_dict():
self.assertTrue(
torch.allclose(
test_backbone.state_dict()[key],
model.backbone.state_dict()[key],
atol=1e-6,
),
f"Mismatch of backbone state_dict for key: {key}",
)

def test_freezing_backbone(self) -> None:
d, num_hidden = 4, 3
for freeze_backbone in (True, False):
test_backbone = torch.nn.Sequential(
torch.nn.Linear(d, num_hidden),
torch.nn.ReLU(),
torch.nn.Linear(num_hidden, num_hidden),
torch.nn.ELU(),
)

model = VBLLModel(
                backbone=copy.deepcopy(test_backbone),  # keep the original for comparison
hidden_features=num_hidden, # match the output of the backbone
)

X, y = _reg_data_singletask(d)
optim_settings = {
"num_epochs": 10,
"lr": 5.0, # large lr to make sure that the weights change
"freeze_backbone": freeze_backbone,
}
model.fit(X, y, optimization_settings=optim_settings)

if freeze_backbone:
# Ensure all parameters remain unchanged
all_params_unchanged = all(
torch.allclose(
test_backbone.state_dict()[key],
model.backbone.state_dict()[key],
atol=1e-6,
)
for key in test_backbone.state_dict()
)
self.assertTrue(
all_params_unchanged,
f"Expected all parameters to remain unchanged, but some changed with freeze_backbone={freeze_backbone}",
)
else:
# Ensure at least one parameter has changed
any_param_changed = any(
not torch.allclose(
test_backbone.state_dict()[key],
model.backbone.state_dict()[key],
atol=1e-6,
)
for key in test_backbone.state_dict()
)
self.assertTrue(
any_param_changed,
f"Expected at least one parameter to change, but all remained the same with freeze_backbone={freeze_backbone}",
)

def test_update_of_reg_weight(self) -> None:
kl_scale = 2.0
d = 2
model = VBLLModel(
in_features=d,
hidden_features=3,
out_features=1,
num_layers=1,
kl_scale=kl_scale,
)
self.assertEqual(
model.model.head.regularization_weight,
1.0,
"Regularization weight should be 1.0 after init.",
)

X, y = _reg_data_singletask(d)

optim_settings = _get_fast_training_settings()
model.fit(X, y, optimization_settings=optim_settings)

self.assertEqual(
model.model.head.regularization_weight,
kl_scale / len(y),
f"Regularization weight should be {kl_scale}/{len(y)}, but got {model.model.head.regularization_weight}.",
)

def test_shape_of_predictions(self) -> None:
d = 4
model = VBLLModel(
in_features=d, hidden_features=4, out_features=1, num_layers=1
)
X, y = _reg_data_singletask(d)
optim_settings = _get_fast_training_settings()

model.fit(X, y, optimization_settings=optim_settings)

for batch_shape in (torch.Size([2]), torch.Size()):
X = torch.rand(batch_shape + torch.Size([3, d]))
expected_shape = batch_shape + torch.Size([3, 1])

post = model.posterior(X)

# check that the posterior is an instance of BLLPosterior
self.assertIsInstance(
post,
BLLPosterior,
"Expected posterior to be an instance of BLLPosterior.",
)

# mean prediction
self.assertEqual(
post.mean.shape,
expected_shape,
f"Expected mean predictions to have shape {expected_shape}, but got {post.mean.shape}.",
)

# variance prediction
self.assertEqual(
post.variance.shape,
expected_shape,
f"Expected variance predictions to have shape {expected_shape}, but got {post.mean.shape}.",
)
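
Putting the pieces together, a minimal end-to-end sketch of fitting and querying the model, mirroring test_shape_of_predictions above (shapes and settings taken from the tests):

import torch
from botorch_community.models.vblls import VBLLModel

model = VBLLModel(in_features=4, hidden_features=4, out_features=1, num_layers=1)
model.fit(
    torch.randn(10, 4),
    torch.randn(10, 1),
    optimization_settings={"num_epochs": 3, "lr": 0.01},
)

# Batched query of shape (2, 3, 4) -> mean and variance of shape (2, 3, 1).
post = model.posterior(torch.rand(2, 3, 4))
print(post.mean.shape, post.variance.shape)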
