
[Bug]: optimize_acqf_mixed does not respect parameter constraint with fixed features #2708

Open
ysellai opened this issue Jan 28, 2025 · 4 comments
Labels: bug (Something isn't working)
ysellai commented Jan 28, 2025

Hi everyone, I am facing an issue using inequality constraints and fixed features.


Issue Description

When using optimize_acqf_mixed in BoTorch with inequality constraints and fixed features, I encounter inconsistent behavior:

  1. Error with CandidateGenerationError:
    In some cases, I get the following error:

    CandidateGenerationError: Inequality constraint 2 not met with fixed_features.
    

    This indicates that the candidate generation failed because no valid point could satisfy both the constraints and the fixed features.

  2. Warning from scipy.optimize.minimize:
    In other cases, I see the following warning:

    OptimizationWarning('Optimization failed within `scipy.optimize.minimize` with status 8 and message Positive directional derivative for linesearch.')
    Trying again with a new set of initial conditions.
    /path/to/optimize.py:568: RuntimeWarning: Optimization failed on the second try, after generating a new set of initial conditions.
    return _optimize_acqf_batch(opt_inputs=opt_inputs)
    

    Despite the warning, a candidate is generated.

  3. No warning but invalid candidates:
    Occasionally, no warning or error is raised, but the generated candidate violates one or more of the inequality constraints. For example:

    Constraints:

    inequality_constraints = [
        (torch.tensor([0, 1, 2]), torch.tensor([1.0, 1.0, 1.0]), 5.4),  # TA >= 5.4
        (torch.tensor([0, 1, 2]), torch.tensor([-1.0, -1.0, -1.0]), -10.6),  # TA <= 10.6
        (torch.tensor([3, 4]), torch.tensor([1.0, 1.0]), 15.4),  # S >= 15.4
        (torch.tensor([3, 4]), torch.tensor([-1.0, -1.0]), -20.1),  # S <= 20.1
        (torch.tensor([5, 6]), torch.tensor([1.096, 1.802]), 7.4),  # C >= 7.4
        (torch.tensor([5, 6]), torch.tensor([-1.096, -1.802]), -9.6),  # C <= 9.6
    ]

    Generated candidate violating constraint 3 (S >= 15.4):

    Candidate: tensor([4.9535, 1.7567, 0.4469, 0.0000, 0.5659, 3.9041, 1.7311, 11.3403])
    Constraint 3 (S >= 15.4): FAIL, Rounded Sum = 0.5659
    
  4. Another weird behaviour:
    With fixed_feature =
    [{0: 0.0, 1: 0.0},
    {0: 0.0, 2: 0.0},
    {0: 0.0, 3: 0.0},
    {0: 0.0, 4: 0.0},
    {1: 0.0, 2: 0.0},
    {1: 0.0, 3: 0.0},
    {1: 0.0, 4: 0.0},
    {2: 0.0, 3: 0.0},
    {2: 0.0, 4: 0.0},
    {3: 0.0, 4: 0.0}]
    Error: CandidateGenerationError: Inequality constraint 2 not met with fixed_features.
    In the first case, with this fixed-features list, at least one of the entries (e.g. {0: 0.0, 4: 0.0}) should work, but BoTorch still returns the error.
    In the second case, with only the fixed feature {1: 0.0, 4: 0.0} in the list, it finds a point that respects the constraints.
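
For reference, the tuples above use BoTorch's `inequality_constraints` convention `(indices, coefficients, rhs)`, meaning `sum(x[indices] * coefficients) >= rhs`. Here is a minimal sketch of the check I run on returned candidates (my own helper, not a BoTorch API), using the violating candidate from item 3:

import torch

def violated_constraints(candidate, inequality_constraints, tol=1e-6):
    """Return indices of BoTorch-style (indices, coefficients, rhs) tuples
    for which sum(candidate[indices] * coefficients) >= rhs fails."""
    bad = []
    for k, (idx, coef, rhs) in enumerate(inequality_constraints):
        lhs = (candidate[idx] * coef).sum().item()
        if lhs < rhs - tol:  # small tolerance absorbs solver rounding noise
            bad.append(k)
    return bad

candidate = torch.tensor([4.9535, 1.7567, 0.4469, 0.0000, 0.5659, 3.9041, 1.7311, 11.3403])
constraints = [(torch.tensor([3, 4]), torch.tensor([1.0, 1.0]), 15.4)]  # S >= 15.4
print(violated_constraints(candidate, constraints))  # -> [0]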


Expected Behavior

The optimization should:

  • Generate candidates that respect all inequality constraints and fixed features.
  • Fail with a clear error (CandidateGenerationError) when no valid candidates can be generated.
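
As a workaround on my side (a sketch, not an existing BoTorch feature), each fixed-feature assignment can be pre-checked for compatibility with the linear constraints by solving a tiny LP feasibility problem before calling the optimizer:

import numpy as np
from scipy.optimize import linprog

def fixed_features_feasible(fixed_features, inequality_constraints, bounds):
    """LP feasibility check: does any x within bounds, with the given
    coordinates pinned, satisfy all (indices, coefficients, rhs) tuples?"""
    dim = bounds.shape[1]
    # linprog wants A_ub @ x <= b_ub, so negate the >=-style BoTorch tuples.
    A_ub, b_ub = [], []
    for idx, coef, rhs in inequality_constraints:
        row = np.zeros(dim)
        row[np.asarray(idx)] = -np.asarray(coef, dtype=float)
        A_ub.append(row)
        b_ub.append(-float(rhs))
    lp_bounds = [
        (fixed_features[j], fixed_features[j]) if j in fixed_features
        else (float(bounds[0, j]), float(bounds[1, j]))
        for j in range(dim)
    ]
    # Any feasible point will do, so minimize the zero objective.
    res = linprog(c=np.zeros(dim), A_ub=np.array(A_ub), b_ub=np.array(b_ub), bounds=lp_bounds)
    return res.success

Entries for which this returns False could then be dropped from fixed_features_list up front.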

Actual Behavior

  1. Inconsistent handling of constraints:
    • Sometimes BoTorch raises `CandidateGenerationError: Inequality constraint <n> not met with fixed_features`. In some cases this is justified, but in others at least one fixed feature in the list could satisfy the constraint and the error is raised anyway.
    • Sometimes BoTorch issues a warning from scipy.optimize.minimize but proceeds to generate candidates.
    • Sometimes candidates that violate constraints are generated without any warning or error.

I understand that back in 2022 BoTorch could return a point that doesn't respect the constraints because of scipy. But I don't know whether this has since been fixed, and I cannot explain the other behaviors.

Your help would be much appreciated.


Steps to Reproduce

Here is code to reproduce this case. With

fixed_features_list = [{0: 0.0, 1: 0.0},
                       {0: 0.0, 2: 0.0},
                       {0: 0.0, 3: 0.0},
                       {0: 0.0, 4: 0.0},
                       {1: 0.0, 2: 0.0},
                       {1: 0.0, 3: 0.0},
                       {1: 0.0, 4: 0.0},
                       {2: 0.0, 3: 0.0},
                       {2: 0.0, 4: 0.0},
                       {3: 0.0, 4: 0.0}]

the call fails with `CandidateGenerationError: Inequality constraint 2 not met with fixed_features`, even though at least one of the entries (e.g. {0: 0.0, 4: 0.0}) should be feasible. With only the fixed feature {1: 0.0, 4: 0.0} in the list, a point that respects the constraints is found.

  1. Define inequality constraints as shown above.
  2. Use optimize_acqf_mixed with a set of fixed features, like:
    fixed_features_list = [{0: 0.0}, {1: 0.0}]
  3. Run the optimization with a qExpectedImprovement acquisition function and parameters like:
     import os
     import numpy as np
     import torch
     from botorch.models import SingleTaskGP
     from botorch.fit import fit_gpytorch_mll
     from botorch.acquisition import qExpectedImprovement
     from botorch.optim import optimize_acqf_mixed
     from gpytorch.mlls import ExactMarginalLogLikelihood
     
     # Define the target function based on `model.predict`.
     def target_function(individual):
         # Convert points to a format compatible with model.predict
         individual = np.array(individual)
         
         # Predict values with prediction model (assume `model` is on CPU or GPU as required)
         predicted_value = model.predict(individual)
         
         # Convert to PyTorch tensor and move to device
         return torch.tensor(-abs(predicted_value - target_value), dtype=torch.float64), predicted_value
     
     def check_constraints(candidate):
         """
         Check whether a candidate satisfies the constraints and print the results.
     
         Args:
             candidate (torch.Tensor): A generated candidate.
     
         Returns:
             None
         """
         # Extract the group sums
         TA_sum = candidate[0:3].sum().item()  # TA1 + TA2 + TA3
         S_sum = candidate[3:5].sum().item()   # S1 + S2
         C_sum = (candidate[5] / 0.912 + candidate[6] / 0.555).item()  # C1/0.912 + C2/0.555
     
         # Print the results
         print(f"TA Sum: {TA_sum:.2f} (Constraint: 5.4 <= sum <= 10.6)")
         print(f"S Sum: {S_sum:.2f} (Constraint: 15.4 <= sum <= 20.1)")
         print(f"C Sum: {C_sum:.2f} (Constraint: 7.4 <= sum <= 9.6)")
         print(f"TA Valid: {5.4 <= TA_sum <= 10.6}")
         print(f"S Valid: {15.4 <= S_sum <= 20.1}")
         print(f"C Valid: {7.4 <= C_sum <= 9.6}\n")
     
     # Define the function for generating the next points with `optimize_acqf_mixed`
     def get_next_points_EI_with_mixed(init_x, init_y, best_init_y, bounds, n_points, fixed_features_list):
         """
         Perform optimization using `optimize_acqf_mixed` with fixed_features_list.
     
         Args:
             init_x (torch.Tensor): Initial design points.
             init_y (torch.Tensor): Corresponding initial objective values.
             best_init_y (float): Current best objective value.
             bounds (torch.Tensor): Bounds for the optimization.
             n_points (int): Number of points to propose.
             fixed_features_list (list): List of fixed features dictionaries.
     
         Returns:
             torch.Tensor: The best proposed candidate.
         """
         # Initialize the GP model
         gp = SingleTaskGP(init_x, init_y)
         mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
         fit_gpytorch_mll(mll)
     
         EI = qExpectedImprovement(model=gp, best_f=best_init_y)
     
         # Define constraints
         indices_TA = torch.tensor([0, 1, 2], dtype=torch.long)  # TA1, TA2, TA3
         coefficients_TA = torch.tensor([1.0, 1.0, 1.0], dtype=torch.float64)
     
         indices_S = torch.tensor([3, 4], dtype=torch.long)  # S1, S2
         coefficients_S = torch.tensor([1.0, 1.0], dtype=torch.float64)
     
         indices_C = torch.tensor([5, 6], dtype=torch.long)  # C1, C2
         coefficients_C = torch.tensor([1 / 0.912, 1 / 0.555], dtype=torch.float64)
     
         # Constraints reformulated for optimize_acqf
         inequality_constraints = [
             (indices_TA, coefficients_TA, 5.4),
             (indices_TA, -coefficients_TA, -10.6),
             (indices_S, coefficients_S, 15.4),
             (indices_S, -coefficients_S, -20.1),
             (indices_C, coefficients_C, 7.4),
             (indices_C, -coefficients_C, -9.6),
         ]
     
         # Use optimize_acqf_mixed with fixed features
         candidates, _ = optimize_acqf_mixed(
             acq_function=EI,
             bounds=bounds,
             fixed_features_list=fixed_features_list,
             q=n_points,
             num_restarts=5,
             raw_samples=20,
             inequality_constraints=inequality_constraints,
             options={
                 "batch_limit": 64,
                 "maxiter": 200,
             }
         )
     
         print("candidates: ", candidates)
     
         return candidates
     
     # Generate LHS points with constraints (at least 2 variables among the first 5 are 0)
     def generate_lhs_with_constraints(space, num_points, seed):
         """
         Generate LHS points while ensuring at least 2 variables among the first 5 are 0.
     
         Args:
             space (array-like): Bounds for the variables.
             num_points (int): Number of points to generate.
             seed (int): Seed for reproducibility.
     
         Returns:
             np.ndarray: Generated LHS points.
             np.ndarray: Corresponding objective values (placeholder).
         """
         from pyDOE import lhs
         import numpy as np
     
         np.random.seed(seed)
         n_dims = len(space)
     
         # Generate initial LHS samples
         lhs_samples = lhs(n_dims, samples=num_points)
         scaled_samples = np.zeros_like(lhs_samples)
     
         for i, bounds in enumerate(space):
             scaled_samples[:, i] = lhs_samples[:, i] * (bounds[1] - bounds[0]) + bounds[0]
     
         # Apply the constraint: At least 2 variables among the first 5 are 0
         for row in scaled_samples:
             zero_indices = np.random.choice(5, size=2, replace=False)  # Select 2 indices to set to 0
             row[zero_indices] = 0  # Set the selected variables to 0
     
         # Predict the values for each point using your model
         prediction_values = model.predict(scaled_samples).reshape(1, -1)[0]
     
         objective_values = -np.abs(prediction_values - target_value).reshape(-1, 1)
         
         return scaled_samples, objective_values
     
     torch.manual_seed(42)
     torch.set_num_threads(os.cpu_count())
     
     # Extract the search space (min and max for each variable)
     # space = np.array([[data_processed[col].min(), data_processed[col].max()] for col in X_columns])
     space = np.array([[ 0.   ,  5.5  ],
            [ 0.   ,  5.5  ],
            [ 0.   ,  5.5  ],
            [14.5  , 19.   ],
            [ 0.   ,  3.5  ],
            [ 0.   ,  7.5  ],
            [ 0.   ,  4.7  ],
            [ 8.795, 16.77 ]])
     
     num_points = len(space) + 1
     seed = 42
     max_iter = 50
     
     # Generate LHS points
     initial_points_EI, initial_responses_EI = generate_lhs_with_constraints(space, num_points, seed)
     
     # Convert initial responses for PyTorch optimization
     initial_objective_values_EI = initial_responses_EI.reshape(-1, 1)
     init_x_EI = torch.tensor(initial_points_EI, dtype=torch.float64)
     init_y_EI = torch.tensor(initial_objective_values_EI, dtype=torch.float64)
     
     # Define bounds for optimization
     bounds_EI = torch.tensor(space.T, dtype=torch.float64)
     
     # Define fixed features for pairs of dimensions (5 first dimensions)
     fixed_features_list = [{i: 0.0, j: 0.0} for i in range(5) for j in range(i + 1, 5)]
     # fixed_features_list = [{0: 1.0, 4: 0.0}]
     
     
     # Initial best value
     best_init_y_EI = init_y_EI.max().item()
     
     # Optimization loop
     responses_over_time_EI = []
     objective_values_EI = []
     candidates_EI = []
     candidates_prediction_EI = []
     
     for i in range(max_iter):
         print(f"Optimization Run: {i + 1}")
     
         # Generate new candidates
         new_candidate_EI = get_next_points_EI_with_mixed(
             init_x_EI,
             init_y_EI,
             best_init_y_EI,
             bounds_EI,
             n_points=1,
             fixed_features_list=fixed_features_list,
         )
     
         for candidate in new_candidate_EI:
             print("Checking constraints for candidate:")
             check_constraints(candidate)
             
         new_result_EI, predicted_values_EI = target_function(new_candidate_EI)
         new_result_EI = new_result_EI.unsqueeze(-1)
     
         # Store responses
         candidates_EI.append(new_candidate_EI)
         candidates_prediction_EI.append(predicted_values_EI)
         responses_over_time_EI.append(new_result_EI[0].item())
         objective_values_EI.append(-new_result_EI[0].item())
     
         # Update the data
         init_x_EI = torch.cat([init_x_EI, new_candidate_EI])
         init_y_EI = torch.cat([init_y_EI, new_result_EI])
     
         # Update best value
         best_init_y_EI, best_index_EI = init_y_EI.max(0)
         best_point_EI = init_x_EI[best_index_EI]
     
         print(f"New Candidate abs Response: {new_result_EI[0].item()}")
         print(f"New Candidate predicted Response: {predicted_values_EI}")
         print(f"Best global Point Performance: {best_init_y_EI.item()}")
         print(f"Best algorithm Point Performance: {max(responses_over_time_EI)} at iteration {np.argmax(responses_over_time_EI) + 1}")


BoTorch Version

0.12.0

Python Version

3.12

Operating System

Linux

Code of Conduct

  • I agree to follow BoTorch's Code of Conduct
ysellai added the bug label on Jan 28, 2025
Balandat (Contributor) commented:
Thanks for raising this issue. I'm having some trouble with the reproducible example though, generate_lhs_with_constraints is supposed to use a model but doesn't have access to one so this just throws an error. Before I take a wild guess what should be going on, could you update the example so it's end-to-end runnable? Thanks!

Balandat changed the title from "[Bug]: optimize_acqf Does not respect constraint" to "[Bug]: optimize_acqf_mixed does not respect parameter constraint with fixed features" on Jan 29, 2025
ysellai (Author) commented Jan 29, 2025

Hi, my bad. Here is corrected code that does not use the prediction model.

I modified my code to test each element of fixed_features_list individually in order to eliminate fixed features that are incompatible with the constraints. In this code, I kept just one fixed feature (one that does not respect the constraints on the S variables).

With optimize_acqf_mixed, if a fixed feature violated a constraint, the optimization would raise an error, allowing me to identify and filter it. I also noticed a bug related to this behavior (see the pending PR: #2614), but my current code does not yet incorporate this fix.

🚨 Main issue:
When a fixed feature is incompatible with a constraint, BoTorch returns a warning (outside of a notebook):

RuntimeWarning: Optimization failed in `gen_candidates_scipy` with the following warning(s): [OptimizationWarning('Optimization failed within `scipy.optimize.minimize` with status 8 and message Positive directional derivative for linesearch.')] Trying again with a new set of initial conditions.
However, despite this warning, BoTorch still returns a point that does not respect the constraint.

What I understand:
BoTorch does not enter the try-except block because it has not mathematically proven that the fixed feature is infeasible.
Scipy tries multiple initializations and eventually returns a point, even if it does not satisfy the constraint.
What I want:
👉 If Scipy detects that a fixed feature does not satisfy the constraint, then it should not return any point at all.
This would prevent me from having to manually check each generated point to ensure that it respects the constraint.

🔹 Another concern:
Even if a fixed feature is compatible with a constraint, I cannot be certain that BoTorch always generates a valid point. I need a solution where BoTorch never returns a point that violates constraints, regardless of whether the fixed feature is feasible or not.

import numpy as np
import torch
from botorch.models import SingleTaskGP
from botorch.fit import fit_gpytorch_mll
from botorch.acquisition import qExpectedImprovement
from botorch.optim import optimize_acqf
from gpytorch.mlls import ExactMarginalLogLikelihood
from botorch.exceptions import CandidateGenerationError

def generate_lhs_with_constraints(space, num_points, seed):
    """
    Generate LHS points while ensuring at least 2 variables among the first 5 are 0.

    Args:
        space (array-like): Bounds for the variables.
        num_points (int): Number of points to generate.
        seed (int): Seed for reproducibility.

    Returns:
        np.ndarray: Generated LHS points.
        np.ndarray: Random objective values between 0 and 1.
    """
    np.random.seed(seed)
    n_dims = len(space)

    # Generate initial LHS samples
    lhs_samples = lhs(n_dims, samples=num_points)
    scaled_samples = np.zeros_like(lhs_samples)

    for i, bounds in enumerate(space):
        scaled_samples[:, i] = lhs_samples[:, i] * (bounds[1] - bounds[0]) + bounds[0]

    # Apply the constraint: At least 2 variables among the first 5 are 0
    for row in scaled_samples:
        zero_indices = np.random.choice(5, size=2, replace=False)  # Select 2 indices to set to 0
        row[zero_indices] = 0  # Set the selected variables to 0

    # Generate random objective values between 0 and 1
    objective_values = np.random.rand(num_points, 1)

    return scaled_samples, objective_values
    
def get_next_points_EI_with_fixed_features(init_x, init_y, best_init_y, bounds, n_points, fixed_features_list, check_fixed_features_list):
    """
    Performs optimization by testing multiple fixed_features with `optimize_acqf`.
    Removes invalid fixed_features.

    Args:
        init_x (torch.Tensor): Initial input points.
        init_y (torch.Tensor): Initial objective values.
        best_init_y (float): Current best objective.
        bounds (torch.Tensor): Optimization bounds.
        n_points (int): Number of points to generate.
        fixed_features_list (list): List of fixed features to test.
        check_fixed_features_list (bool): If True, verifies fixed features.

    Returns:
        torch.Tensor: The best valid candidate found.
    """
    # Initialize the GP model
    gp = SingleTaskGP(init_x, init_y)
    mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
    fit_gpytorch_mll(mll)

    EI = qExpectedImprovement(model=gp, best_f=best_init_y)

    # Define constraints
    indices_TA = torch.tensor([0, 1, 2], dtype=torch.long)  # TA1, TA2, TA3
    coefficients_TA = torch.tensor([1.0, 1.0, 1.0], dtype=torch.float64)

    indices_S = torch.tensor([3, 4], dtype=torch.long)  # S1, S2
    coefficients_S = torch.tensor([1.0, 1.0], dtype=torch.float64)

    indices_C = torch.tensor([5, 6], dtype=torch.long)  # C1, C2
    coefficients_C = torch.tensor([1 / 0.912, 1 / 0.555], dtype=torch.float64)

    # Reformulate constraints for `optimize_acqf`
    inequality_constraints = [
        (indices_TA, coefficients_TA, 5.4),
        (indices_TA, -coefficients_TA, -10.6),
        (indices_S, coefficients_S, 15.4),
        (indices_S, -coefficients_S, -20.1),
        (indices_C, coefficients_C, 7.4),
        (indices_C, -coefficients_C, -9.6),
    ]

    candidates_list = []
    acq_values_list = []

    # Loop over fixed features
    for fixed_features in fixed_features_list[:]:
        try:
            candidates, acq_value = optimize_acqf(
                acq_function=EI,
                bounds=bounds,
                q=n_points,
                num_restarts=10,
                raw_samples=100,
                options={"batch_limit": 64, "maxiter": 200, "method": "SLSQP"},
                fixed_features=fixed_features,
                inequality_constraints=inequality_constraints,
            )

            print("Fixed feature: ", fixed_features)
            candidate = candidates[0]

            if check_fixed_features_list:
                # Verify constraints
                if check_constraints(candidate, print_bool=True):
                    candidates_list.append(candidate)
                    acq_values_list.append(acq_value)
                else:
                    print(f"❌ Candidate rejected due to constraints after checking: {candidate}")
                    # print(f"Failed to generate a candidate for fixed_features: {fixed_features}")
                    # print(f"❌ Removing fixed_features: {fixed_features}")
                    fixed_features_list.remove(fixed_features)
                    print("Len fixed_features_list: ", len(fixed_features_list))
            else:
                candidates_list.append(candidate)
                acq_values_list.append(acq_value)

        except CandidateGenerationError:
            print(f"CandidateGenerationError: Failed to generate a candidate for fixed_features: {fixed_features}")
            print(f"❌ Removing fixed_features: {fixed_features}")
            fixed_features_list.remove(fixed_features)  # Remove on failure
            print("Len fixed_features_list: ", len(fixed_features_list))

            continue  # Skip to the next fixed feature

    # Stop if no valid candidates were found
    if len(candidates_list) == 0:
        raise CandidateGenerationError("All fixed_features failed to generate a valid candidate.")

    # Return the best candidate found
    best_candidate_index = torch.argmax(torch.tensor(acq_values_list))
    return candidates_list[best_candidate_index].reshape(1, -1)

def check_constraints(candidate, tol=1e-6, print_bool=False):
    """
    Checks if a candidate satisfies the constraints.

    Args:
        candidate (torch.Tensor): A generated candidate.
        tol (float): Tolerance to handle rounding errors.

    Returns:
        bool: True if the candidate satisfies all constraints, otherwise False.
    """
    TA_sum = candidate[0:3].sum().item()  # TA1 + TA2 + TA3
    S_sum = candidate[3:5].sum().item()   # S1 + S2
    C_sum = (candidate[5] / 0.912 + candidate[6] / 0.555).item()  # C1/0.912 + C2/0.555

    # Add tolerance for comparisons
    valid = (
        (5.4 - tol <= TA_sum <= 10.6 + tol) and
        (15.4 - tol <= S_sum <= 20.1 + tol) and
        (7.4 - tol <= C_sum <= 9.6 + tol)
    )

    if print_bool:
        # Print results
        print(f"TA Sum: {TA_sum:.6f} (Constraint: 5.4 <= sum <= 10.6)")
        print(f"S Sum: {S_sum:.6f} (Constraint: 15.4 <= sum <= 20.1)")
        print(f"C Sum: {C_sum:.6f} (Constraint: 7.4 <= sum <= 9.6)")
        print(f"TA Valid: {5.4 - tol <= TA_sum <= 10.6 + tol}")
        print(f"S Valid: {15.4 - tol <= S_sum <= 20.1 + tol}")
        print(f"C Valid: {7.4 - tol <= C_sum <= 9.6 + tol}\n")

    return valid

# Define the search space and constraints
torch.manual_seed(42)

space = np.array([[ 0.   ,  5.5  ],
                  [ 0.   ,  5.5  ],
                  [ 0.   ,  5.5  ],
                  [14.5  , 19.   ],
                  [ 0.   ,  3.5  ],
                  [ 0.   ,  7.5  ],
                  [ 0.   ,  4.7  ],
                  [ 8.795, 16.77 ]])

num_points = len(space) + 1
seed = 42
max_iter = 50

# Generate LHS points
initial_points_EI, initial_responses_EI = generate_lhs_with_constraints(space, num_points, seed)

init_x_EI = torch.tensor(initial_points_EI, dtype=torch.float64)
init_y_EI = torch.tensor(initial_responses_EI, dtype=torch.float64)

bounds_EI = torch.tensor(space.T, dtype=torch.float64)

# Define fixed features
fixed_features_list = [{i: 0.0, j: 0.0} for i in range(5) for j in range(i + 1, 5)]
# Keep a single fixed feature for this repro; fixing dim 3 (S1) to 0 makes the
# S constraint (S1 + S2 >= 15.4) infeasible since S2 is bounded above by 3.5.
fixed_features_list = [{2: 0.0, 3: 0.0}]

best_init_y_EI = init_y_EI.max().item()

# Optimization
responses_over_time_EI = []
objective_values_EI = []
candidates_EI = []
candidates_prediction_EI = []

for i in range(max_iter):
    print(f"Iteration {i + 1}")

    try:
        new_candidate_EI = get_next_points_EI_with_fixed_features(
            init_x_EI,
            init_y_EI,
            best_init_y_EI,
            bounds_EI,
            n_points=1,
            fixed_features_list=fixed_features_list,
            check_fixed_features_list=(i == 0)
        )

        if new_candidate_EI is None:
            print("No valid candidates found. Stopping optimization.")
            break

        check_constraints(new_candidate_EI[0], print_bool=True)

        # Verification and storage
        # NOTE: target_function is not defined in this script; reuse the one from the first message or stub it out.
        new_result_EI, predicted_values_EI = target_function(new_candidate_EI)
        new_result_EI = new_result_EI.unsqueeze(-1)

        candidates_EI.append(new_candidate_EI)
        candidates_prediction_EI.append(predicted_values_EI)
        responses_over_time_EI.append(new_result_EI[0].item())
        objective_values_EI.append(-new_result_EI[0].item())

        init_x_EI = torch.cat([init_x_EI, new_candidate_EI])
        init_y_EI = torch.cat([init_y_EI, new_result_EI])

        best_init_y_EI, best_index_EI = init_y_EI.max(0)
        best_point_EI = init_x_EI[best_index_EI]

        print(f"New Candidate: {new_candidate_EI}")
        print(f"Best Value So Far: {best_init_y_EI.item()}")

    except CandidateGenerationError:
        print("All fixed_features failed. Optimization stopped.")
        break  # Stop the loop

# Final results
print("\nFinal Results:")
print("Best Point:", init_x_EI[init_y_EI.argmax()])
print("Best Value:", init_y_EI.max().item())
print("Best Iteration:", np.argmax(responses_over_time_EI) + 1)

Balandat (Contributor) commented Feb 2, 2025

@ysellai unfortunately your code is still not fully runnable - the lhs function in this block below is not defined:

# Generate initial LHS samples
lhs_samples = lhs(n_dims, samples=num_points)
scaled_samples = np.zeros_like(lhs_samples)

> 👉 If Scipy detects that a fixed feature does not satisfy the constraint, then it should not return any point at all.
> This would prevent me from having to manually check each generated point to ensure that it respects the constraint.
> 🔹 Another concern:
> Even if a fixed feature is compatible with a constraint, I cannot be certain that BoTorch always generates a valid point. I need a solution where BoTorch never returns a point that violates constraints, regardless of whether the fixed feature is feasible or not.

This request generally makes sense to me. I don't think it'll be easy to make any changes on the scipy side here, but we could incorporate the validation on the botorch side. It will be easier to investigate and pinpoint this if you share a fully runnable repro.

ysellai (Author) commented Feb 3, 2025

My bad, I forgot to include the import for lhs in that code. To make it fully runnable, simply add this import at the beginning:

from pyDOE import lhs # Import for Latin Hypercube Sampling

This should resolve the issue with lhs_samples = lhs(n_dims, samples=num_points).

I am working with pyDOE==0.3.8
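
If you'd rather avoid the pyDOE dependency, a minimal stand-in for the call used here (my own sketch of basic Latin hypercube sampling, assuming no criterion argument is needed) would be:

import numpy as np

def lhs(n, samples):
    """Basic Latin hypercube sampling in [0, 1]^n: one point per stratum
    along each dimension, with strata shuffled independently per dimension."""
    out = np.empty((samples, n))
    for j in range(n):
        # One uniform draw inside each of the `samples` equal-width strata,
        # assigned to rows in random order.
        points = (np.random.rand(samples) + np.arange(samples)) / samples
        out[:, j] = np.random.permutation(points)
    return out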

Also, do you think Ax could handle the permutation constraints I simulate with fixed features (e.g. one or two variables set to zero each time) together with the other constraints? I've tried to use non-linear constraints with BoTorch, but the batch initialization makes it too complex.
