From 11c14cde6000ac3174b0cce12164440a829480b6 Mon Sep 17 00:00:00 2001 From: Laurent Perron Date: Sat, 28 Dec 2024 11:23:05 +0100 Subject: [PATCH] [CP-SAT] revamp python implementation: introduce proper FloatLinearExpr class, move most of the expressions classes to C++ --- ortools/sat/circuit.cc | 9 +- ortools/sat/circuit.h | 2 +- ortools/sat/python/BUILD.bazel | 16 + ortools/sat/python/cp_model.py | 1198 ++++-------------------- ortools/sat/python/cp_model_test.py | 44 +- ortools/sat/python/linear_expr.cc | 609 ++++++++++++ ortools/sat/python/linear_expr.h | 581 ++++++++++++ ortools/sat/python/swig_helper.cc | 724 +++++++++++++- ortools/sat/python/swig_helper_test.py | 131 ++- ortools/sat/swig_helper.cc | 3 +- ortools/sat/swig_helper.h | 2 +- 11 files changed, 2210 insertions(+), 1109 deletions(-) create mode 100644 ortools/sat/python/linear_expr.cc create mode 100644 ortools/sat/python/linear_expr.h diff --git a/ortools/sat/circuit.cc b/ortools/sat/circuit.cc index 92ba0a8f4c4..f7fdaa9862c 100644 --- a/ortools/sat/circuit.cc +++ b/ortools/sat/circuit.cc @@ -698,9 +698,12 @@ void LoadSubcircuitConstraint(int num_nodes, const std::vector& tails, std::function CircuitCovering( absl::Span> graph, - const std::vector& distinguished_nodes) { - return [=, graph = std::vector>( - graph.begin(), graph.end())](Model* model) { + absl::Span distinguished_nodes) { + return [=, + distinguished_nodes = std::vector(distinguished_nodes.begin(), + distinguished_nodes.end()), + graph = std::vector>( + graph.begin(), graph.end())](Model* model) { CircuitCoveringPropagator* constraint = new CircuitCoveringPropagator(graph, distinguished_nodes, model); constraint->RegisterWith(model->GetOrCreate()); diff --git a/ortools/sat/circuit.h b/ortools/sat/circuit.h index 3c74e70a67a..548c4a33e4a 100644 --- a/ortools/sat/circuit.h +++ b/ortools/sat/circuit.h @@ -255,7 +255,7 @@ std::function ExactlyOnePerRowAndPerColumn( absl::Span> graph); std::function CircuitCovering( absl::Span> graph, - const std::vector& distinguished_nodes); + absl::Span distinguished_nodes); } // namespace sat } // namespace operations_research diff --git a/ortools/sat/python/BUILD.bazel b/ortools/sat/python/BUILD.bazel index e00ca4e418f..72d07918f89 100644 --- a/ortools/sat/python/BUILD.bazel +++ b/ortools/sat/python/BUILD.bazel @@ -17,11 +17,27 @@ load("@pip_deps//:requirements.bzl", "requirement") load("@pybind11_bazel//:build_defs.bzl", "pybind_extension") load("@rules_python//python:defs.bzl", "py_library", "py_test") +cc_library( + name = "linear_expr", + srcs = ["linear_expr.cc"], + hdrs = ["linear_expr.h"], + deps = [ + "//ortools/sat:cp_model_cc_proto", + "//ortools/util:sorted_interval_list", + "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/container:fixed_array", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", + ], +) + pybind_extension( name = "swig_helper", srcs = ["swig_helper.cc"], visibility = ["//visibility:public"], deps = [ + ":linear_expr", "//ortools/sat:cp_model_cc_proto", "//ortools/sat:sat_parameters_cc_proto", "//ortools/sat:swig_helper", diff --git a/ortools/sat/python/cp_model.py b/ortools/sat/python/cp_model.py index d2b7578051c..1efe750d37d 100644 --- a/ortools/sat/python/cp_model.py +++ b/ortools/sat/python/cp_model.py @@ -45,8 +45,6 @@ rather than for solving specific optimization problems. """ -import collections -import itertools import threading import time from typing import ( @@ -54,8 +52,6 @@ Callable, Dict, Iterable, - List, - NoReturn, Optional, Sequence, Tuple, @@ -74,7 +70,12 @@ from ortools.sat.python import swig_helper from ortools.util.python import sorted_interval_list +# Import external types. Domain = sorted_interval_list.Domain +LinearExpr = swig_helper.LinearExpr +FloatLinearExpr = swig_helper.FloatLinearExpr +BoundedLinearExpression = swig_helper.BoundedLinearExpression + # The classes below allow linear expressions to be expressed naturally with the # usual arithmetic operators + - * / and with constant numbers, which makes the @@ -152,13 +153,13 @@ np.double, ) -LiteralT = Union["IntVar", "_NotBooleanVariable", IntegralT, bool] -BoolVarT = Union["IntVar", "_NotBooleanVariable"] +LiteralT = Union[swig_helper.Literal, IntegralT, bool] +BoolVarT = swig_helper.Literal VariableT = Union["IntVar", IntegralT] # We need to add 'IntVar' for pytype. -LinearExprT = Union["LinearExpr", "IntVar", IntegralT] -ObjLinearExprT = Union["LinearExpr", NumberT] +LinearExprT = Union[LinearExpr, "IntVar", IntegralT] +ObjLinearExprT = Union[FloatLinearExpr, NumberT] ArcT = Tuple[IntegralT, IntegralT, LiteralT] _IndexOrSeries = Union[pd.Index, pd.Series] @@ -180,14 +181,14 @@ def display_bounds(bounds: Sequence[int]) -> str: def short_name(model: cp_model_pb2.CpModelProto, i: int) -> str: """Returns a short name of an integer variable, or its negation.""" if i < 0: - return "not(%s)" % short_name(model, -i - 1) + return f"not({short_name(model, -i - 1)})" v = model.variables[i] if v.name: return v.name elif len(v.domain) == 2 and v.domain[0] == v.domain[1]: return str(v.domain[0]) else: - return "[%s]" % display_bounds(v.domain) + return f"[{display_bounds(v.domain)}]" def short_expr_name( @@ -215,631 +216,7 @@ def short_expr_name( return str(e) -class LinearExpr: - """Holds an integer linear expression. - - A linear expression is built from integer constants and variables. - For example, `x + 2 * (y - z + 1)`. - - Linear expressions are used in CP-SAT models in constraints and in the - objective: - - * You can define linear constraints as in: - - ``` - model.add(x + 2 * y <= 5) - model.add(sum(array_of_vars) == 5) - ``` - - * In CP-SAT, the objective is a linear expression: - - ``` - model.minimize(x + 2 * y + z) - ``` - - * For large arrays, using the LinearExpr class is faster that using the python - `sum()` function. You can create constraints and the objective from lists of - linear expressions or coefficients as follows: - - ``` - model.minimize(cp_model.LinearExpr.sum(expressions)) - model.add(cp_model.LinearExpr.weighted_sum(expressions, coefficients) >= 0) - ``` - """ - - @classmethod - def sum(cls, expressions: Sequence[LinearExprT]) -> LinearExprT: - """Creates the expression sum(expressions).""" - if len(expressions) == 1: - return expressions[0] - return _SumArray(expressions) - - @overload - @classmethod - def weighted_sum( - cls, - expressions: Sequence[LinearExprT], - coefficients: Sequence[IntegralT], - ) -> LinearExprT: ... - - @overload - @classmethod - def weighted_sum( - cls, - expressions: Sequence[ObjLinearExprT], - coefficients: Sequence[NumberT], - ) -> ObjLinearExprT: ... - - @classmethod - def weighted_sum(cls, expressions, coefficients): - """Creates the expression sum(expressions[i] * coefficients[i]).""" - if LinearExpr.is_empty_or_all_null(coefficients): - return 0 - elif len(expressions) == 1: - return expressions[0] * coefficients[0] - else: - return _WeightedSum(expressions, coefficients) - - @overload - @classmethod - def term( - cls, - expressions: LinearExprT, - coefficients: IntegralT, - ) -> LinearExprT: ... - - @overload - @classmethod - def term( - cls, - expressions: ObjLinearExprT, - coefficients: NumberT, - ) -> ObjLinearExprT: ... - - @classmethod - def term(cls, expression, coefficient): - """Creates `expression * coefficient`.""" - if cmh.is_zero(coefficient): - return 0 - else: - return expression * coefficient - - @classmethod - def is_empty_or_all_null(cls, coefficients: Sequence[NumberT]) -> bool: - for c in coefficients: - if not cmh.is_zero(c): - return False - return True - - @classmethod - def rebuild_from_linear_expression_proto( - cls, - model: cp_model_pb2.CpModelProto, - proto: cp_model_pb2.LinearExpressionProto, - ) -> LinearExprT: - """Recreate a LinearExpr from a LinearExpressionProto.""" - offset = proto.offset - num_elements = len(proto.vars) - if num_elements == 0: - return offset - elif num_elements == 1: - return ( - IntVar(model, proto.vars[0], None) * proto.coeffs[0] + offset - ) # pytype: disable=bad-return-type - else: - variables = [] - coeffs = [] - all_ones = True - for index, coeff in zip(proto.vars, proto.coeffs): - variables.append(IntVar(model, index, None)) - coeffs.append(coeff) - if not cmh.is_one(coeff): - all_ones = False - if all_ones: - return _SumArray(variables, offset) - else: - return _WeightedSum(variables, coeffs, offset) - - def get_integer_var_value_map(self) -> Tuple[Dict["IntVar", int], int]: - """Scans the expression, and returns (var_coef_map, constant).""" - coeffs: Dict["IntVar", int] = collections.defaultdict(int) - constant = 0 - to_process: List[Tuple[LinearExprT, int]] = [(self, 1)] - while to_process: # Flatten to avoid recursion. - expr: LinearExprT - coeff: int - expr, coeff = to_process.pop() - if isinstance(expr, IntegralTypes): - constant += coeff * int(expr) - elif isinstance(expr, _ProductCst): - to_process.append((expr.expression(), coeff * expr.coefficient())) - elif isinstance(expr, _Sum): - to_process.append((expr.left(), coeff)) - to_process.append((expr.right(), coeff)) - elif isinstance(expr, _SumArray): - for e in expr.expressions(): - to_process.append((e, coeff)) - constant += expr.constant() * coeff - elif isinstance(expr, _WeightedSum): - for e, c in zip(expr.expressions(), expr.coefficients()): - to_process.append((e, coeff * c)) - constant += expr.constant() * coeff - elif isinstance(expr, IntVar): - coeffs[expr] += coeff - elif isinstance(expr, _NotBooleanVariable): - constant += coeff - coeffs[expr.negated()] -= coeff - elif isinstance(expr, NumberTypes): - raise TypeError( - f"Floating point constants are not supported in constraints: {expr}" - ) - else: - raise TypeError("Unrecognized linear expression: " + str(expr)) - - return coeffs, constant - - def get_float_var_value_map( - self, - ) -> Tuple[Dict["IntVar", float], float, bool]: - """Scans the expression. Returns (var_coef_map, constant, is_integer).""" - coeffs: Dict["IntVar", Union[int, float]] = {} - constant: Union[int, float] = 0 - to_process: List[Tuple[LinearExprT, Union[int, float]]] = [(self, 1)] - while to_process: # Flatten to avoid recursion. - expr, coeff = to_process.pop() - if isinstance(expr, IntegralTypes): # Keep integrality. - constant += coeff * int(expr) - elif isinstance(expr, NumberTypes): - constant += coeff * float(expr) - elif isinstance(expr, _ProductCst): - to_process.append((expr.expression(), coeff * expr.coefficient())) - elif isinstance(expr, _Sum): - to_process.append((expr.left(), coeff)) - to_process.append((expr.right(), coeff)) - elif isinstance(expr, _SumArray): - for e in expr.expressions(): - to_process.append((e, coeff)) - constant += expr.constant() * coeff - elif isinstance(expr, _WeightedSum): - for e, c in zip(expr.expressions(), expr.coefficients()): - to_process.append((e, coeff * c)) - constant += expr.constant() * coeff - elif isinstance(expr, IntVar): - if expr in coeffs: - coeffs[expr] += coeff - else: - coeffs[expr] = coeff - elif isinstance(expr, _NotBooleanVariable): - constant += coeff - if expr.negated() in coeffs: - coeffs[expr.negated()] -= coeff - else: - coeffs[expr.negated()] = -coeff - else: - raise TypeError("Unrecognized linear expression: " + str(expr)) - is_integer = isinstance(constant, IntegralTypes) - if is_integer: - for coeff in coeffs.values(): - if not isinstance(coeff, IntegralTypes): - is_integer = False - break - return coeffs, constant, is_integer - - def __hash__(self) -> int: - return object.__hash__(self) - - def __abs__(self) -> NoReturn: - raise NotImplementedError( - "calling abs() on a linear expression is not supported, " - "please use CpModel.add_abs_equality" - ) - - @overload - def __add__(self, arg: "LinearExpr") -> "LinearExpr": ... - - @overload - def __add__(self, arg: NumberT) -> "LinearExpr": ... - - def __add__(self, arg): - if cmh.is_zero(arg): - return self - return _Sum(self, arg) - - @overload - def __radd__(self, arg: "LinearExpr") -> "LinearExpr": ... - - @overload - def __radd__(self, arg: NumberT) -> "LinearExpr": ... - - def __radd__(self, arg): - return self.__add__(arg) - - @overload - def __sub__(self, arg: "LinearExpr") -> "LinearExpr": ... - - @overload - def __sub__(self, arg: NumberT) -> "LinearExpr": ... - - def __sub__(self, arg): - if cmh.is_zero(arg): - return self - if isinstance(arg, NumberTypes): - return _Sum(self, -arg) - else: - return _Sum(self, -arg) - - @overload - def __rsub__(self, arg: "LinearExpr") -> "LinearExpr": ... - - @overload - def __rsub__(self, arg: NumberT) -> "LinearExpr": ... - - def __rsub__(self, arg): - return _Sum(-self, arg) - - @overload - def __mul__(self, arg: IntegralT) -> Union["LinearExpr", IntegralT]: ... - - @overload - def __mul__(self, arg: NumberT) -> Union["LinearExpr", NumberT]: ... - - def __mul__(self, arg): - arg = cmh.assert_is_a_number(arg) - if cmh.is_one(arg): - return self - elif cmh.is_zero(arg): - return 0 - return _ProductCst(self, arg) - - @overload - def __rmul__(self, arg: IntegralT) -> Union["LinearExpr", IntegralT]: ... - - @overload - def __rmul__(self, arg: NumberT) -> Union["LinearExpr", NumberT]: ... - - def __rmul__(self, arg): - return self.__mul__(arg) - - def __div__(self, _) -> NoReturn: - raise NotImplementedError( - "calling / on a linear expression is not supported, " - "please use CpModel.add_division_equality" - ) - - def __truediv__(self, _) -> NoReturn: - raise NotImplementedError( - "calling // on a linear expression is not supported, " - "please use CpModel.add_division_equality" - ) - - def __mod__(self, _) -> NoReturn: - raise NotImplementedError( - "calling %% on a linear expression is not supported, " - "please use CpModel.add_modulo_equality" - ) - - def __pow__(self, _) -> NoReturn: - raise NotImplementedError( - "calling ** on a linear expression is not supported, " - "please use CpModel.add_multiplication_equality" - ) - - def __lshift__(self, _) -> NoReturn: - raise NotImplementedError( - "calling left shift on a linear expression is not supported" - ) - - def __rshift__(self, _) -> NoReturn: - raise NotImplementedError( - "calling right shift on a linear expression is not supported" - ) - - def __and__(self, _) -> NoReturn: - raise NotImplementedError( - "calling and on a linear expression is not supported, " - "please use CpModel.add_bool_and" - ) - - def __or__(self, _) -> NoReturn: - raise NotImplementedError( - "calling or on a linear expression is not supported, " - "please use CpModel.add_bool_or" - ) - - def __xor__(self, _) -> NoReturn: - raise NotImplementedError( - "calling xor on a linear expression is not supported, " - "please use CpModel.add_bool_xor" - ) - - def __neg__(self) -> "LinearExpr": - return _ProductCst(self, -1) - - def __bool__(self) -> NoReturn: - raise NotImplementedError( - "Evaluating a LinearExpr instance as a Boolean is not implemented." - ) - - @overload - def __eq__(self, arg: IntegralT) -> "BoundedLinearExpression": ... - - @overload - def __eq__(self, arg: "LinearExpr") -> "BoundedLinearExpression": ... - - @overload - def __eq__(self, arg: Any) -> bool: ... - - # pytype: disable=bad-return-type - def __eq__(self, arg): - if isinstance(arg, IntegralTypes): - return BoundedLinearExpression(self, [arg, arg]) - if isinstance(arg, LinearExpr): - return BoundedLinearExpression(self - arg, [0, 0]) - return False - - # pytype: enable=bad-return-type - - def __ge__(self, arg: LinearExprT) -> "BoundedLinearExpression": - if isinstance(arg, IntegralTypes): - if arg >= INT_MAX: - raise ArithmeticError(">= INT_MAX is not supported") - return BoundedLinearExpression(self, [arg, INT_MAX]) - else: - return BoundedLinearExpression(self - arg, [0, INT_MAX]) - - def __le__(self, arg: LinearExprT) -> "BoundedLinearExpression": - if isinstance(arg, IntegralTypes): - if arg <= INT_MIN: - raise ArithmeticError("<= INT_MIN is not supported") - return BoundedLinearExpression(self, [INT_MIN, arg]) - else: - return BoundedLinearExpression(self - arg, [INT_MIN, 0]) - - def __lt__(self, arg: LinearExprT) -> "BoundedLinearExpression": - if isinstance(arg, IntegralTypes): - if arg <= INT_MIN: - raise ArithmeticError("< INT_MIN is not supported") - return BoundedLinearExpression(self, [INT_MIN, arg - 1]) - else: - return BoundedLinearExpression(self - arg, [INT_MIN, -1]) - - def __gt__(self, arg: LinearExprT) -> "BoundedLinearExpression": - if isinstance(arg, IntegralTypes): - if arg >= INT_MAX: - raise ArithmeticError("> INT_MAX is not supported") - return BoundedLinearExpression(self, [arg + 1, INT_MAX]) - else: - return BoundedLinearExpression(self - arg, [1, INT_MAX]) - - @overload - def __ne__(self, arg: "LinearExpr") -> "BoundedLinearExpression": ... - - @overload - def __ne__(self, arg: IntegralT) -> "BoundedLinearExpression": ... - - @overload - def __ne__(self, arg: Any) -> bool: ... - - # pytype: disable=bad-return-type - def __ne__(self, arg): - if isinstance(arg, IntegralTypes): - if arg >= INT_MAX: - return BoundedLinearExpression(self, [INT_MIN, INT_MAX - 1]) - elif arg <= INT_MIN: - return BoundedLinearExpression(self, [INT_MIN + 1, INT_MAX]) - else: - return BoundedLinearExpression( - self, [INT_MIN, arg - 1, arg + 1, INT_MAX] - ) - elif isinstance(arg, LinearExpr): - return BoundedLinearExpression(self - arg, [INT_MIN, -1, 1, INT_MAX]) - return True - - # pytype: enable=bad-return-type - - # Compatibility with pre PEP8 - # pylint: disable=invalid-name - @classmethod - def Sum(cls, expressions: Sequence[LinearExprT]) -> LinearExprT: - """Creates the expression sum(expressions).""" - return cls.sum(expressions) - - @overload - @classmethod - def WeightedSum( - cls, - expressions: Sequence[LinearExprT], - coefficients: Sequence[IntegralT], - ) -> LinearExprT: ... - - @overload - @classmethod - def WeightedSum( - cls, - expressions: Sequence[ObjLinearExprT], - coefficients: Sequence[NumberT], - ) -> ObjLinearExprT: ... - - @classmethod - def WeightedSum(cls, expressions, coefficients): - """Creates the expression sum(expressions[i] * coefficients[i]).""" - return cls.weighted_sum(expressions, coefficients) - - @overload - @classmethod - def Term( - cls, - expressions: LinearExprT, - coefficients: IntegralT, - ) -> LinearExprT: ... - - @overload - @classmethod - def Term( - cls, - expressions: ObjLinearExprT, - coefficients: NumberT, - ) -> ObjLinearExprT: ... - - @classmethod - def Term(cls, expression, coefficient): - """Creates `expression * coefficient`.""" - return cls.term(expression, coefficient) - - # pylint: enable=invalid-name - - -class _Sum(LinearExpr): - """Represents the sum of two LinearExprs.""" - - def __init__(self, left, right) -> None: - for x in [left, right]: - if not isinstance(x, (NumberTypes, LinearExpr)): - raise TypeError("not an linear expression: " + str(x)) - self.__left = left - self.__right = right - - def left(self): - return self.__left - - def right(self): - return self.__right - - def __str__(self): - return f"({self.__left} + {self.__right})" - - def __repr__(self): - return f"sum({self.__left!r}, {self.__right!r})" - - -class _ProductCst(LinearExpr): - """Represents the product of a LinearExpr by a constant.""" - - def __init__(self, expr, coeff) -> None: - if isinstance(expr, _ProductCst): - self.__expr = expr.expression() - self.__coef = expr.coefficient() * coeff - else: - self.__expr = expr - self.__coef = coeff - - def __str__(self): - if self.__coef == -1: - return "-" + str(self.__expr) - else: - return "(" + str(self.__coef) + " * " + str(self.__expr) + ")" - - def __repr__(self): - return f"ProductCst({self.__expr!r}, {self.__coef!r})" - - def coefficient(self): - return self.__coef - - def expression(self): - return self.__expr - - -class _SumArray(LinearExpr): - """Represents the sum of a list of LinearExpr and a constant.""" - - def __init__(self, expressions, constant=0) -> None: - self.__expressions = [] - self.__constant = constant - for x in expressions: - if isinstance(x, NumberTypes): - if cmh.is_zero(x): - continue - self.__constant += x - elif isinstance(x, LinearExpr): - self.__expressions.append(x) - else: - raise TypeError("not an linear expression: " + str(x)) - - def __str__(self): - constant_terms = (self.__constant,) if self.__constant != 0 else () - exprs_str = " + ".join( - map(repr, itertools.chain(self.__expressions, constant_terms)) - ) - if not exprs_str: - return "0" - return f"({exprs_str})" - - def __repr__(self): - exprs_str = ", ".join(map(repr, self.__expressions)) - return f"SumArray({exprs_str}, {self.__constant})" - - def expressions(self): - return self.__expressions - - def constant(self): - return self.__constant - - -class _WeightedSum(LinearExpr): - """Represents sum(ai * xi) + b.""" - - def __init__(self, expressions, coefficients, constant=0) -> None: - self.__expressions = [] - self.__coefficients = [] - self.__constant = constant - if len(expressions) != len(coefficients): - raise TypeError( - "In the LinearExpr.weighted_sum method, the expression array and the " - " coefficient array must have the same length." - ) - for e, c in zip(expressions, coefficients): - if cmh.is_zero(c): - continue - if isinstance(e, NumberTypes): - self.__constant += e * c - elif isinstance(e, LinearExpr): - self.__expressions.append(e) - self.__coefficients.append(c) - else: - raise TypeError("not an linear expression: " + str(e)) - - def __str__(self): - output = None - for expr, coeff in zip(self.__expressions, self.__coefficients): - if not output and cmh.is_one(coeff): - output = str(expr) - elif not output and cmh.is_minus_one(coeff): - output = "-" + str(expr) - elif not output: - output = f"{coeff} * {expr}" - elif cmh.is_one(coeff): - output += f" + {expr}" - elif cmh.is_minus_one(coeff): - output += f" - {expr}" - elif coeff > 1: - output += f" + {coeff} * {expr}" - elif coeff < -1: - output += f" - {-coeff} * {expr}" - if output is None: - output = str(self.__constant) - elif self.__constant > 0: - output += f" + {self.__constant}" - elif self.__constant < 0: - output += f" - {-self.__constant}" - return output - - def __repr__(self): - return ( - f"weighted_sum({self.__expressions!r}, {self.__coefficients!r}," - f" {self.__constant})" - ) - - def expressions(self): - return self.__expressions - - def coefficients(self): - return self.__coefficients - - def constant(self): - return self.__constant - - -class IntVar(LinearExpr): +class IntVar(swig_helper.BaseIntVar): """An integer variable. An IntVar is an object that can take on any integer value within defined @@ -857,12 +234,11 @@ def __init__( self, model: cp_model_pb2.CpModelProto, domain: Union[int, sorted_interval_list.Domain], + is_boolean: bool, name: Optional[str], ) -> None: """See CpModel.new_int_var below.""" - self.__index: int self.__var: cp_model_pb2.IntegerVariableProto - self.__negation: Optional[_NotBooleanVariable] = None # Python do not support multiple __init__ methods. # This method is only called from the CpModel class. # We hack the parameter to support the two cases: @@ -871,10 +247,10 @@ def __init__( # case 2: # model is a CpModelProto, domain is an index (int), and name is None. if isinstance(domain, IntegralTypes) and name is None: - self.__index = int(domain) + swig_helper.BaseIntVar.__init__(self, int(domain), is_boolean) self.__var = model.variables[domain] else: - self.__index = len(model.variables) + swig_helper.BaseIntVar.__init__(self, len(model.variables), is_boolean) self.__var = model.variables.add() self.__var.domain.extend( cast(sorted_interval_list.Domain, domain).flattened_intervals() @@ -882,11 +258,6 @@ def __init__( if name is not None: self.__var.name = name - @property - def index(self) -> int: - """Returns the index of the variable in the model.""" - return self.__index - @property def proto(self) -> cp_model_pb2.IntegerVariableProto: """Returns the variable protobuf.""" @@ -906,12 +277,15 @@ def __str__(self) -> str: ): # Special case for constants. return str(self.__var.domain[0]) + elif self.is_boolean: + return f"BooleanVar({self.__index})" else: - return "unnamed_var_%i" % self.__index - return self.__var.name + return f"IntVar({self.__index})" + else: + return self.__var.name def __repr__(self) -> str: - return "%s(%s)" % (self.__var.name, display_bounds(self.__var.domain)) + return f"{self}({display_bounds(self.__var.domain)})" @property def name(self) -> str: @@ -919,155 +293,34 @@ def name(self) -> str: return "" return self.__var.name - def negated(self) -> "_NotBooleanVariable": - """Returns the negation of a Boolean variable. - - This method implements the logical negation of a Boolean variable. - It is only valid if the variable has a Boolean domain (0 or 1). - - Note that this method is nilpotent: `x.negated().negated() == x`. - """ - - for bound in self.__var.domain: - if bound < 0 or bound > 1: - raise TypeError( - f"cannot call negated on a non boolean variable: {self}" - ) - if self.__negation is None: - self.__negation = _NotBooleanVariable(self) - return self.__negation - - def __invert__(self) -> "_NotBooleanVariable": - """Returns the logical negation of a Boolean variable.""" - return self.negated() - # Pre PEP8 compatibility. # pylint: disable=invalid-name - Not = negated - def Name(self) -> str: return self.name def Proto(self) -> cp_model_pb2.IntegerVariableProto: return self.proto - def Index(self) -> int: - return self.index - # pylint: enable=invalid-name -class _NotBooleanVariable(LinearExpr): - """Negation of a boolean variable.""" - - def __init__(self, boolvar: IntVar) -> None: - self.__boolvar: IntVar = boolvar - - @property - def index(self) -> int: - return -self.__boolvar.index - 1 - - def negated(self) -> IntVar: - return self.__boolvar - - def __invert__(self) -> IntVar: - """Returns the logical negation of a Boolean literal.""" - return self.negated() - - def __str__(self) -> str: - return self.name - - @property - def name(self) -> str: - return "not(%s)" % str(self.__boolvar) - - def __bool__(self) -> NoReturn: - raise NotImplementedError( - "Evaluating a literal as a Boolean value is not implemented." - ) - - # Pre PEP8 compatibility. - # pylint: disable=invalid-name - def Not(self) -> "IntVar": - return self.negated() - - def Index(self) -> int: - return self.index - - # pylint: enable=invalid-name - - -class BoundedLinearExpression: - """Represents a linear constraint: `lb <= linear expression <= ub`. - - The only use of this class is to be added to the CpModel through - `CpModel.add(expression)`, as in: - - model.add(x + 2 * y -1 >= z) - """ - - def __init__(self, expr: LinearExprT, bounds: Sequence[int]) -> None: - self.__expr: LinearExprT = expr - self.__bounds: Sequence[int] = bounds - - def __str__(self): - if len(self.__bounds) == 2: - lb, ub = self.__bounds - if lb > INT_MIN and ub < INT_MAX: - if lb == ub: - return str(self.__expr) + " == " + str(lb) - else: - return str(lb) + " <= " + str(self.__expr) + " <= " + str(ub) - elif lb > INT_MIN: - return str(self.__expr) + " >= " + str(lb) - elif ub < INT_MAX: - return str(self.__expr) + " <= " + str(ub) - else: - return "True (unbounded expr " + str(self.__expr) + ")" - elif ( - len(self.__bounds) == 4 - and self.__bounds[0] == INT_MIN - and self.__bounds[1] + 2 == self.__bounds[2] - and self.__bounds[3] == INT_MAX - ): - return str(self.__expr) + " != " + str(self.__bounds[1] + 1) - else: - return str(self.__expr) + " in [" + display_bounds(self.__bounds) + "]" - - def expression(self) -> LinearExprT: - return self.__expr - - def bounds(self) -> Sequence[int]: - return self.__bounds - - def __bool__(self) -> bool: - expr = self.__expr - if isinstance(expr, LinearExpr): - coeffs_map, constant = expr.get_integer_var_value_map() - all_coeffs = set(coeffs_map.values()) - same_var = set([0]) - eq_bounds = [0, 0] - different_vars = set([-1, 1]) - ne_bounds = [INT_MIN, -1, 1, INT_MAX] - if ( - len(coeffs_map) == 1 - and all_coeffs == same_var - and constant == 0 - and (self.__bounds == eq_bounds or self.__bounds == ne_bounds) - ): - return self.__bounds == eq_bounds - if ( - len(coeffs_map) == 2 - and all_coeffs == different_vars - and constant == 0 - and (self.__bounds == eq_bounds or self.__bounds == ne_bounds) - ): - return self.__bounds == ne_bounds - - raise NotImplementedError( - f'Evaluating a BoundedLinearExpression "{self}" as a Boolean value' - + " is not supported." - ) +def rebuild_from_linear_expression_proto( + model: cp_model_pb2.CpModelProto, + proto: cp_model_pb2.LinearExpressionProto, +) -> LinearExprT: + """Recreate a LinearExpr from a LinearExpressionProto.""" + num_elements = len(proto.vars) + if num_elements == 0: + return proto.offset + elif num_elements == 1: + return ( + IntVar(model, proto.vars[0], False, None) * proto.coeffs[0] + proto.offset + ) # pytype: disable=bad-return-type + else: + variables = [] + for index in proto.vars: + variables.append(IntVar(model, index, False, None)) + return LinearExpr.weighted_sum(variables, proto.coeffs, proto.offset) class Constraint: @@ -1132,7 +385,7 @@ def only_enforce_if(self, *boolvar) -> "Constraint": ) else: self.__constraint.enforcement_literal.append( - cast(Union[IntVar, _NotBooleanVariable], lit).index + cast(swig_helper.Literal, lit).index ) return self @@ -1282,17 +535,17 @@ def name(self) -> str: return self.__ct.name def start_expr(self) -> LinearExprT: - return LinearExpr.rebuild_from_linear_expression_proto( + return rebuild_from_linear_expression_proto( self.__model, self.__ct.interval.start ) def size_expr(self) -> LinearExprT: - return LinearExpr.rebuild_from_linear_expression_proto( + return rebuild_from_linear_expression_proto( self.__model, self.__ct.interval.size ) def end_expr(self) -> LinearExprT: - return LinearExpr.rebuild_from_linear_expression_proto( + return rebuild_from_linear_expression_proto( self.__model, self.__ct.interval.end ) @@ -1319,7 +572,7 @@ def object_is_a_true_literal(literal: LiteralT) -> bool: if isinstance(literal, IntVar): proto = literal.proto return len(proto.domain) == 2 and proto.domain[0] == 1 and proto.domain[1] == 1 - if isinstance(literal, _NotBooleanVariable): + if isinstance(literal, swig_helper.NotBooleanVariable): proto = literal.negated().proto return len(proto.domain) == 2 and proto.domain[0] == 0 and proto.domain[1] == 0 if isinstance(literal, IntegralTypes): @@ -1332,7 +585,7 @@ def object_is_a_false_literal(literal: LiteralT) -> bool: if isinstance(literal, IntVar): proto = literal.proto return len(proto.domain) == 2 and proto.domain[0] == 0 and proto.domain[1] == 0 - if isinstance(literal, _NotBooleanVariable): + if isinstance(literal, swig_helper.NotBooleanVariable): proto = literal.negated().proto return len(proto.domain) == 2 and proto.domain[0] == 1 and proto.domain[1] == 1 if isinstance(literal, IntegralTypes): @@ -1383,8 +636,13 @@ def new_int_var(self, lb: IntegralT, ub: IntegralT, name: str) -> IntVar: Returns: a variable whose domain is [lb, ub]. """ - - return IntVar(self.__model, sorted_interval_list.Domain(lb, ub), name) + domain_is_boolean = lb >= 0 and ub <= 1 + return IntVar( + self.__model, + sorted_interval_list.Domain(lb, ub), + domain_is_boolean, + name, + ) def new_int_var_from_domain( self, domain: sorted_interval_list.Domain, name: str @@ -1402,15 +660,22 @@ def new_int_var_from_domain( Returns: a variable whose domain is the given domain. """ - return IntVar(self.__model, domain, name) + domain_is_boolean = domain.min() >= 0 and domain.max() <= 1 + return IntVar(self.__model, domain, domain_is_boolean, name) def new_bool_var(self, name: str) -> IntVar: """Creates a 0-1 variable with the given name.""" - return IntVar(self.__model, sorted_interval_list.Domain(0, 1), name) + return IntVar(self.__model, sorted_interval_list.Domain(0, 1), True, name) def new_constant(self, value: IntegralT) -> IntVar: """Declares a constant integer.""" - return IntVar(self.__model, self.get_or_make_index_from_constant(value), None) + domain_is_boolean = value == 0 or value == 1 + return IntVar( + self.__model, + self.get_or_make_index_from_constant(value), + domain_is_boolean, + None, + ) def new_int_var_series( self, @@ -1471,6 +736,7 @@ def new_int_var_series( domain=sorted_interval_list.Domain( lower_bounds[i], upper_bounds[i] ), + is_boolean=lower_bounds[i] >= 0 and upper_bounds[i] <= 1, ) for i in index ], @@ -1494,8 +760,22 @@ def new_bool_var_series( TypeError: if the `index` is invalid (e.g. a `DataFrame`). ValueError: if the `name` is not a valid identifier or already exists. """ - return self.new_int_var_series( - name=name, index=index, lower_bounds=0, upper_bounds=1 + if not isinstance(index, pd.Index): + raise TypeError("Non-index object is used as index") + if not name.isidentifier(): + raise ValueError(f"name={name} is not a valid identifier") + return pd.Series( + index=index, + data=[ + # pylint: disable=g-complex-comprehension + IntVar( + model=self.__model, + name=f"{name}[{i}]", + domain=sorted_interval_list.Domain(0, 1), + is_boolean=True, + ) + for i in index + ], ) # Linear constraints. @@ -1509,25 +789,13 @@ def add_linear_constraint( ) def add_linear_expression_in_domain( - self, linear_expr: LinearExprT, domain: sorted_interval_list.Domain + self, + linear_expr: LinearExprT, + domain: sorted_interval_list.Domain, ) -> Constraint: """Adds the constraint: `linear_expr` in `domain`.""" if isinstance(linear_expr, LinearExpr): - ct = Constraint(self) - model_ct = self.__model.constraints[ct.index] - coeffs_map, constant = linear_expr.get_integer_var_value_map() - for t in coeffs_map.items(): - if not isinstance(t[0], IntVar): - raise TypeError("Wrong argument" + str(t)) - model_ct.linear.vars.append(t[0].index) - model_ct.linear.coeffs.append(t[1]) - model_ct.linear.domain.extend( - [ - cmh.capped_subtraction(x, constant) - for x in domain.flattened_intervals() - ] - ) - return ct + return self.add(BoundedLinearExpression(linear_expr, domain)) if isinstance(linear_expr, IntegralTypes): if not domain.contains(int(linear_expr)): return self.add_bool_or([]) # Evaluate to false. @@ -1537,17 +805,17 @@ def add_linear_expression_in_domain( "not supported: CpModel.add_linear_expression_in_domain(" + str(linear_expr) + " " + + str(type(linear_expr)) + + " " + + str(linear_expr.is_integer()) + + " " + str(domain) + + " " + + str(type(domain)) + ")" ) - @overload - def add(self, ct: BoundedLinearExpression) -> Constraint: ... - - @overload - def add(self, ct: Union[bool, np.bool_]) -> Constraint: ... - - def add(self, ct): + def add(self, ct: Union[BoundedLinearExpression, bool, np.bool_]) -> Constraint: """Adds a `BoundedLinearExpression` to the model. Args: @@ -1555,12 +823,23 @@ def add(self, ct): Returns: An instance of the `Constraint` class. + + Raises: + TypeError: If the `ct` is not a `BoundedLinearExpression` or a Boolean. """ if isinstance(ct, BoundedLinearExpression): - return self.add_linear_expression_in_domain( - ct.expression(), - sorted_interval_list.Domain.from_flat_intervals(ct.bounds()), + result = Constraint(self) + model_ct = self.__model.constraints[result.index] + for var in ct.vars: + model_ct.linear.vars.append(var.index) + model_ct.linear.coeffs.extend(ct.coeffs) + model_ct.linear.domain.extend( + [ + cmh.capped_subtraction(x, ct.offset) + for x in ct.bounds.flattened_intervals() + ] ) + return result if ct and cmh.is_boolean(ct): return self.add_bool_or([True]) if not ct and cmh.is_boolean(ct): @@ -2741,7 +2020,7 @@ def get_bool_var_from_proto_index(self, index: int) -> IntVar: + " a Boolean variable" ) - return IntVar(self.__model, index, None) + return IntVar(self.__model, index, True, None) def get_int_var_from_proto_index(self, index: int) -> IntVar: """Returns an already created integer variable from its index.""" @@ -2749,7 +2028,7 @@ def get_int_var_from_proto_index(self, index: int) -> IntVar: raise ValueError( f"get_int_var_from_proto_index: out of bound index {index}" ) - return IntVar(self.__model, index, None) + return IntVar(self.__model, index, False, None) def get_interval_var_from_proto_index(self, index: int) -> IntervalVar: """Returns an already created interval variable from its index.""" @@ -2783,12 +2062,6 @@ def get_or_make_index(self, arg: VariableT) -> int: """Returns the index of a variable, its negation, or a number.""" if isinstance(arg, IntVar): return arg.index - if ( - isinstance(arg, _ProductCst) - and isinstance(arg.expression(), IntVar) - and arg.coefficient() == -1 - ): - return -arg.expression().index - 1 if isinstance(arg, IntegralTypes): return self.get_or_make_index_from_constant(arg) raise TypeError("NotSupported: model.get_or_make_index(" + str(arg) + ")") @@ -2798,7 +2071,7 @@ def get_or_make_boolean_index(self, arg: LiteralT) -> int: if isinstance(arg, IntVar): self.assert_is_boolean_variable(arg) return arg.index - if isinstance(arg, _NotBooleanVariable): + if isinstance(arg, swig_helper.NotBooleanVariable): self.assert_is_boolean_variable(arg.negated()) return arg.index if isinstance(arg, IntegralTypes): @@ -2814,7 +2087,7 @@ def get_or_make_boolean_index(self, arg: LiteralT) -> int: def get_interval_index(self, arg: IntervalVar) -> int: if not isinstance(arg, IntervalVar): - raise TypeError("NotSupported: model.get_interval_index(%s)" % arg) + raise TypeError(f"NotSupported: model.get_interval_index({arg})") return arg.index def get_or_make_index_from_constant(self, value: IntegralT) -> int: @@ -2850,52 +2123,41 @@ def parse_linear_expression( result.coeffs.append(mult) return result - coeffs_map, constant = cast(LinearExpr, linear_expr).get_integer_var_value_map() - result.offset = constant * mult - for t in coeffs_map.items(): - if not isinstance(t[0], IntVar): - raise TypeError("Wrong argument" + str(t)) - result.vars.append(t[0].index) - result.coeffs.append(t[1] * mult) + ble = BoundedLinearExpression(linear_expr, Domain.all_values()) + result.offset = ble.offset + for var in ble.vars: + result.vars.append(var.index) + for coeff in ble.coeffs: + result.coeffs.append(coeff * mult) return result def _set_objective(self, obj: ObjLinearExprT, minimize: bool): """Sets the objective of the model.""" self.clear_objective() - if isinstance(obj, IntVar): - self.__model.objective.vars.append(obj.index) - self.__model.objective.offset = 0 - if minimize: - self.__model.objective.coeffs.append(1) - self.__model.objective.scaling_factor = 1 - else: - self.__model.objective.coeffs.append(-1) - self.__model.objective.scaling_factor = -1 - elif isinstance(obj, LinearExpr): - coeffs_map, constant, is_integer = obj.get_float_var_value_map() - if is_integer: + if isinstance(obj, IntegralTypes): + self.__model.objective.offset = int(obj) + self.__model.objective.scaling_factor = 1.0 + elif isinstance(obj, FloatLinearExpr): + if obj.is_integer(): + ble = BoundedLinearExpression(obj, Domain.all_values()) + for var in ble.vars: + self.__model.objective.vars.append(var.index) if minimize: - self.__model.objective.scaling_factor = 1 - self.__model.objective.offset = constant + self.__model.objective.scaling_factor = 1.0 + self.__model.objective.offset = ble.offset + self.__model.objective.coeffs.extend(ble.coeffs) else: - self.__model.objective.scaling_factor = -1 - self.__model.objective.offset = -constant - for v, c in coeffs_map.items(): - c_as_int = int(c) - self.__model.objective.vars.append(v.index) - if minimize: - self.__model.objective.coeffs.append(c_as_int) - else: - self.__model.objective.coeffs.append(-c_as_int) + self.__model.objective.scaling_factor = -1.0 + self.__model.objective.offset = -ble.offset + for c in ble.coeffs: + self.__model.objective.coeffs.append(-c) else: + flat_obj = swig_helper.CanonicalFloatExpression(obj) + for var in flat_obj.vars: + self.__model.floating_point_objective.vars.append(var.index) + self.__model.floating_point_objective.coeffs.extend(flat_obj.coeffs) self.__model.floating_point_objective.maximize = not minimize - self.__model.floating_point_objective.offset = constant - for v, c in coeffs_map.items(): - self.__model.floating_point_objective.coeffs.append(c) - self.__model.floating_point_objective.vars.append(v.index) - elif isinstance(obj, IntegralTypes): - self.__model.objective.offset = int(obj) - self.__model.objective.scaling_factor = 1 + self.__model.floating_point_objective.offset = flat_obj.offset else: raise TypeError("TypeError: " + str(obj) + " is not a valid objective") @@ -3008,7 +2270,7 @@ def assert_is_boolean_variable(self, x: LiteralT) -> None: var = self.__model.variables[x.index] if len(var.domain) != 2 or var.domain[0] < 0 or var.domain[1] > 1: raise TypeError("TypeError: " + str(x) + " is not a boolean variable") - elif not isinstance(x, _NotBooleanVariable): + elif not isinstance(x, swig_helper.NotBooleanVariable): raise TypeError("TypeError: " + str(x) + " is not a boolean variable") # Compatibility with pre PEP8 @@ -3108,60 +2370,6 @@ def expand_generator_or_tuple(args): return args[0] -def evaluate_linear_expr( - expression: LinearExprT, solution: cp_model_pb2.CpSolverResponse -) -> int: - """Evaluate a linear expression against a solution.""" - if isinstance(expression, IntegralTypes): - return int(expression) - if not isinstance(expression, LinearExpr): - raise TypeError("Cannot interpret %s as a linear expression." % expression) - - value = 0 - to_process = [(expression, 1)] - while to_process: - expr, coeff = to_process.pop() - if isinstance(expr, IntegralTypes): - value += int(expr) * coeff - elif isinstance(expr, _ProductCst): - to_process.append((expr.expression(), coeff * expr.coefficient())) - elif isinstance(expr, _Sum): - to_process.append((expr.left(), coeff)) - to_process.append((expr.right(), coeff)) - elif isinstance(expr, _SumArray): - for e in expr.expressions(): - to_process.append((e, coeff)) - value += expr.constant() * coeff - elif isinstance(expr, _WeightedSum): - for e, c in zip(expr.expressions(), expr.coefficients()): - to_process.append((e, coeff * c)) - value += expr.constant() * coeff - elif isinstance(expr, IntVar): - value += coeff * solution.solution[expr.index] - elif isinstance(expr, _NotBooleanVariable): - value += coeff * (1 - solution.solution[expr.negated().index]) - else: - raise TypeError(f"Cannot interpret {expr} as a linear expression.") - - return value - - -def evaluate_boolean_expression( - literal: LiteralT, solution: cp_model_pb2.CpSolverResponse -) -> bool: - """Evaluate a boolean expression against a solution.""" - if isinstance(literal, IntegralTypes): - return bool(literal) - elif isinstance(literal, IntVar) or isinstance(literal, _NotBooleanVariable): - index: int = cast(Union[IntVar, _NotBooleanVariable], literal).index - if index >= 0: - return bool(solution.solution[index]) - else: - return not solution.solution[-index - 1] - else: - raise TypeError(f"Cannot interpret {literal} as a boolean expression.") - - class CpSolver: """Main solver class. @@ -3174,7 +2382,7 @@ class CpSolver: """ def __init__(self) -> None: - self.__solution: Optional[cp_model_pb2.CpSolverResponse] = None + self.__response_wrapper: Optional[swig_helper.ResponseWrapper] = None self.parameters: sat_parameters_pb2.SatParameters = ( sat_parameters_pb2.SatParameters() ) @@ -3202,10 +2410,7 @@ def solve( if self.best_bound_callback is not None: self.__solve_wrapper.add_best_bound_callback(self.best_bound_callback) - solution: cp_model_pb2.CpSolverResponse = self.__solve_wrapper.solve( - model.proto - ) - self.__solution = solution + self.__response_wrapper = self.__solve_wrapper.solve(model.proto) if solution_callback is not None: self.__solve_wrapper.clear_solution_callback(solution_callback) @@ -3213,7 +2418,7 @@ def solve( with self.__lock: self.__solve_wrapper = None - return solution.status + return self.__response_wrapper.status() def stop_search(self) -> None: """Stops the current search asynchronously.""" @@ -3223,7 +2428,7 @@ def stop_search(self) -> None: def value(self, expression: LinearExprT) -> int: """Returns the value of a linear expression after solve.""" - return evaluate_linear_expr(expression, self._solution) + return self._checked_response.value(expression) def values(self, variables: _IndexOrSeries) -> pd.Series: """Returns the values of the input variables. @@ -3239,16 +2444,20 @@ def values(self, variables: _IndexOrSeries) -> pd.Series: Returns: pd.Series: The values of all variables in the set. + + Raises: + RuntimeError: if solve() has not been called. """ - solution = self._solution - return _attribute_series( - func=lambda v: solution.solution[v.index], - values=variables, + if self.__response_wrapper is None: + raise RuntimeError("solve() has not been called.") + return pd.Series( + data=[self.__response_wrapper.value(var) for var in variables], + index=_get_index(variables), ) def boolean_value(self, literal: LiteralT) -> bool: """Returns the boolean value of a literal after solve.""" - return evaluate_boolean_expression(literal, self._solution) + return self._checked_response.boolean_value(literal) def boolean_values(self, variables: _IndexOrSeries) -> pd.Series: """Returns the values of the input variables. @@ -3264,65 +2473,71 @@ def boolean_values(self, variables: _IndexOrSeries) -> pd.Series: Returns: pd.Series: The values of all variables in the set. + + Raises: + RuntimeError: if solve() has not been called. """ - solution = self._solution - return _attribute_series( - func=lambda literal: evaluate_boolean_expression(literal, solution), - values=variables, + if self.__response_wrapper is None: + raise RuntimeError("solve() has not been called.") + return pd.Series( + data=[ + self.__response_wrapper.boolean_value(literal) for literal in variables + ], + index=_get_index(variables), ) @property def objective_value(self) -> float: """Returns the value of the objective after solve.""" - return self._solution.objective_value + return self._checked_response.objective_value() @property def best_objective_bound(self) -> float: """Returns the best lower (upper) bound found when min(max)imizing.""" - return self._solution.best_objective_bound + return self._checked_response.best_objective_bound() @property def num_booleans(self) -> int: """Returns the number of boolean variables managed by the SAT solver.""" - return self._solution.num_booleans + return self._checked_response.num_booleans() @property def num_conflicts(self) -> int: """Returns the number of conflicts since the creation of the solver.""" - return self._solution.num_conflicts + return self._checked_response.num_conflicts() @property def num_branches(self) -> int: """Returns the number of search branches explored by the solver.""" - return self._solution.num_branches + return self._checked_response.num_branches() @property def wall_time(self) -> float: """Returns the wall time in seconds since the creation of the solver.""" - return self._solution.wall_time + return self._checked_response.wall_time() @property def user_time(self) -> float: """Returns the user time in seconds since the creation of the solver.""" - return self._solution.user_time + return self._checked_response.user_time() @property def response_proto(self) -> cp_model_pb2.CpSolverResponse: """Returns the response object.""" - return self._solution + return self._checked_response.response() def response_stats(self) -> str: """Returns some statistics on the solution found as a string.""" - return swig_helper.CpSatHelper.solver_response_stats(self._solution) + return self._checked_response.response_stats() def sufficient_assumptions_for_infeasibility(self) -> Sequence[int]: """Returns the indices of the infeasible assumptions.""" - return self._solution.sufficient_assumptions_for_infeasibility + return self._checked_response.sufficient_assumptions_for_infeasibility() def status_name(self, status: Optional[Any] = None) -> str: """Returns the name of the status returned by solve().""" if status is None: - status = self._solution.status + status = self._checked_response.status() return cp_model_pb2.CpSolverStatus.Name(status) def solution_info(self) -> str: @@ -3334,14 +2549,14 @@ def solution_info(self) -> str: Raises: RuntimeError: if solve() has not been called. """ - return self._solution.solution_info + return self._checked_response.solution_info() @property - def _solution(self) -> cp_model_pb2.CpSolverResponse: - """Checks solve() has been called, and returns the solution.""" - if self.__solution is None: + def _checked_response(self) -> swig_helper.ResponseWrapper: + """Checks solve() has been called, and returns a response wrapper.""" + if self.__response_wrapper is None: raise RuntimeError("solve() has not been called.") - return self.__solution + return self.__response_wrapper # Compatibility with pre PEP8 # pylint: disable=invalid-name @@ -3496,15 +2711,7 @@ def boolean_value(self, lit: LiteralT) -> bool: """ if not self.has_response(): raise RuntimeError("solve() has not been called.") - if isinstance(lit, IntegralTypes): - return bool(lit) - if isinstance(lit, IntVar) or isinstance(lit, _NotBooleanVariable): - return self.SolutionBooleanValue( - cast(Union[IntVar, _NotBooleanVariable], lit).index - ) - if cmh.is_boolean(lit): - return bool(lit) - raise TypeError(f"Cannot interpret {lit} as a boolean expression.") + return self.BooleanValue(lit) def value(self, expression: LinearExprT) -> int: """Evaluates an linear expression in the current solution. @@ -3521,36 +2728,7 @@ def value(self, expression: LinearExprT) -> int: """ if not self.has_response(): raise RuntimeError("solve() has not been called.") - - value: int = 0 - to_process: list[tuple[LinearExprT, int]] = [(expression, 1)] - while to_process: - expr, coeff = to_process.pop() - if isinstance(expr, IntegralTypes): - value += int(expr) * coeff - elif isinstance(expr, _ProductCst): - to_process.append((expr.expression(), coeff * expr.coefficient())) - elif isinstance(expr, _Sum): - to_process.append((expr.left(), coeff)) - to_process.append((expr.right(), coeff)) - elif isinstance(expr, _SumArray): - for e in expr.expressions(): - to_process.append((e, coeff)) - value += expr.constant() * coeff - elif isinstance(expr, _WeightedSum): - for e, c in zip(expr.expressions(), expr.coefficients()): - to_process.append((e, coeff * c)) - value += expr.constant() * coeff - elif isinstance(expr, IntVar): - value += coeff * self.SolutionIntegerValue(expr.index) - elif isinstance(expr, _NotBooleanVariable): - value += coeff * (1 - self.SolutionIntegerValue(expr.negated().index)) - else: - raise TypeError( - f"cannot interpret {expression} as a linear expression." - ) - - return value + return self.Value(expression) def has_response(self) -> bool: return self.HasResponse() @@ -3638,12 +2816,6 @@ def response_proto(self) -> cp_model_pb2.CpSolverResponse: raise RuntimeError("solve() has not been called.") return self.Response() - # Compatibility with pre PEP8 - # pylint: disable=invalid-name - Value = value - BooleanValue = boolean_value - # pylint: enable=invalid-name - class ObjectiveSolutionPrinter(CpSolverSolutionCallback): """Display the objective value and time of intermediate solutions.""" @@ -3730,26 +2902,6 @@ def _get_index(obj: _IndexOrSeries) -> pd.Index: return obj -def _attribute_series( - *, - func: Callable[[IntVar], IntegralT], - values: _IndexOrSeries, -) -> pd.Series: - """Returns the attributes of `values`. - - Args: - func: The function to call for getting the attribute data. - values: The values that the function will be applied (element-wise) to. - - Returns: - pd.Series: The attribute values. - """ - return pd.Series( - data=[func(v) for v in values], - index=_get_index(values), - ) - - def _convert_to_integral_series_and_validate_index( value_or_series: Union[IntegralT, pd.Series], index: pd.Index ) -> pd.Series: diff --git a/ortools/sat/python/cp_model_test.py b/ortools/sat/python/cp_model_test.py index a99b7a58ac6..ff9027efd77 100644 --- a/ortools/sat/python/cp_model_test.py +++ b/ortools/sat/python/cp_model_test.py @@ -366,30 +366,6 @@ def testSimplification1(self) -> None: self.assertEqual(x, prod.expression()) self.assertEqual(4, prod.coefficient()) - def testSimplification2(self) -> None: - print("testSimplification2") - model = cp_model.CpModel() - x = model.new_int_var(-10, 10, "x") - prod = 2 * (x * 2) - self.assertEqual(x, prod.expression()) - self.assertEqual(4, prod.coefficient()) - - def testSimplification3(self) -> None: - print("testSimplification3") - model = cp_model.CpModel() - x = model.new_int_var(-10, 10, "x") - prod = (2 * x) * 2 - self.assertEqual(x, prod.expression()) - self.assertEqual(4, prod.coefficient()) - - def testSimplification4(self) -> None: - print("testSimplification4") - model = cp_model.CpModel() - x = model.new_int_var(-10, 10, "x") - prod = 2 * (2 * x) - self.assertEqual(x, prod.expression()) - self.assertEqual(4, prod.coefficient()) - def testLinearNonEqualWithConstant(self) -> None: print("testLinearNonEqualWithConstant") model = cp_model.CpModel() @@ -459,7 +435,7 @@ def testNaturalApiMaximizeFloat(self) -> None: self.assertEqual(16.1, solver.objective_value) def testNaturalApiMaximizeComplex(self) -> None: - print("testNaturalApiMaximizeFloat") + print("testNaturalApiMaximizeComplex") model = cp_model.CpModel() x1 = model.new_bool_var("x1") x2 = model.new_bool_var("x1") @@ -1169,7 +1145,7 @@ def testStr(self) -> None: self.assertEqual(str(x < 2), "x <= 1") self.assertEqual(str(x != 2), "x != 2") self.assertEqual(str(x * 3), "(3 * x)") - self.assertEqual(str(-x), "-x") + self.assertEqual(str(-x), "(-x)") self.assertEqual(str(x + 3), "(x + 3)") self.assertEqual(str(x <= cp_model.INT_MAX), "True (unbounded expr x)") self.assertEqual(str(x != 9223372036854775807), "x <= 9223372036854775806") @@ -1177,14 +1153,14 @@ def testStr(self) -> None: y = model.new_int_var(0, 4, "y") self.assertEqual( str(cp_model.LinearExpr.weighted_sum([x, y + 1, 2], [1, -2, 3])), - "x - 2 * (y + 1) + 6", + "(x - 2 * (y + 1) + 6)", ) self.assertEqual(str(cp_model.LinearExpr.term(x, 3)), "(3 * x)") - self.assertEqual(str(x != y), "(x + -y) != 0") + self.assertEqual(str(x != y), "(x - y) != 0") self.assertEqual( - "0 <= x <= 10", str(cp_model.BoundedLinearExpression(x, [0, 10])) + "0 <= x <= 10", + str(cp_model.BoundedLinearExpression(x, cp_model.Domain(0, 10))), ) - print(str(model)) b = model.new_bool_var("b") self.assertEqual(str(cp_model.LinearExpr.term(b.negated(), 3)), "(3 * not(b))") @@ -1198,15 +1174,15 @@ def testRepr(self) -> None: y = model.new_int_var(0, 3, "y") z = model.new_int_var(0, 3, "z") self.assertEqual(repr(x), "x(0..4)") - self.assertEqual(repr(x * 2), "ProductCst(x(0..4), 2)") - self.assertEqual(repr(x + y), "sum(x(0..4), y(0..3))") + self.assertEqual(repr(x * 2), "IntAffine(expr=x(0..4), coeff=2, offset=0)") + self.assertEqual(repr(x + y), "IntSum(x(0..4), y(0..3), 0)") self.assertEqual( repr(cp_model.LinearExpr.sum([x, y, z])), - "SumArray(x(0..4), y(0..3), z(0..3), 0)", + "IntSum(x(0..4), y(0..3), z(0..3), 0)", ) self.assertEqual( repr(cp_model.LinearExpr.weighted_sum([x, y, 2], [1, 2, 3])), - "weighted_sum([x(0..4), y(0..3)], [1, 2], 6)", + "IntWeightedSum([x(0..4), y(0..3)], [1, 2], 6)", ) i = model.new_interval_var(x, 2, y, "i") self.assertEqual(repr(i), "i(start = x, size = 2, end = y)") diff --git a/ortools/sat/python/linear_expr.cc b/ortools/sat/python/linear_expr.cc new file mode 100644 index 00000000000..526258ca804 --- /dev/null +++ b/ortools/sat/python/linear_expr.cc @@ -0,0 +1,609 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/sat/python/linear_expr.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "ortools/util/sorted_interval_list.h" + +namespace operations_research { +namespace sat { +namespace python { + +FloatLinearExpr* FloatLinearExpr::Sum( + const std::vector& exprs) { + return Sum(exprs, 0.0); +} + +FloatLinearExpr* FloatLinearExpr::Sum( + const std::vector& exprs, double cst) { + std::vector lin_exprs; + for (const FloatExprOrValue& choice : exprs) { + if (choice.expr != nullptr) { + lin_exprs.push_back(choice.expr); + } else { + cst += choice.value; + } + } + if (lin_exprs.empty()) return new FloatConstant(cst); + if (lin_exprs.size() == 1) return Affine(lin_exprs[0], 1.0, cst); + return new FloatWeightedSum(lin_exprs, cst); +} + +FloatLinearExpr* FloatLinearExpr::WeightedSum( + const std::vector& exprs, + const std::vector& coeffs) { + return WeightedSum(exprs, coeffs, 0.0); +} + +FloatLinearExpr* FloatLinearExpr::WeightedSum( + const std::vector& exprs, + const std::vector& coeffs, double cst) { + std::vector lin_exprs; + std::vector lin_coeffs; + for (int i = 0; i < exprs.size(); ++i) { + if (exprs[i].expr != nullptr) { + lin_exprs.push_back(exprs[i].expr); + lin_coeffs.push_back(coeffs[i]); + } else { + cst += exprs[i].value * coeffs[i]; + } + } + + if (lin_exprs.empty()) return new FloatConstant(cst); + if (lin_exprs.size() == 1) { + return Affine(lin_exprs[0], lin_coeffs[0], cst); + } + return new FloatWeightedSum(lin_exprs, lin_coeffs, cst); +} + +FloatLinearExpr* FloatLinearExpr::Term(FloatLinearExpr* expr, double coeff) { + return new FloatAffine(expr, coeff, 0.0); +} + +FloatLinearExpr* FloatLinearExpr::Affine(FloatLinearExpr* expr, double coeff, + double offset) { + return new FloatAffine(expr, coeff, offset); +} + +FloatLinearExpr* FloatLinearExpr::Constant(double value) { + return new FloatConstant(value); +} + +FloatLinearExpr* FloatLinearExpr::FloatAddCst(double cst) { + if (cst == 0.0) return this; + return new FloatAffine(this, 1.0, cst); +} + +FloatLinearExpr* FloatLinearExpr::FloatAdd(FloatLinearExpr* other) { + std::vector exprs; + exprs.push_back(this); + exprs.push_back(other); + return new FloatWeightedSum(exprs, 0); +} + +FloatLinearExpr* FloatLinearExpr::FloatSubCst(double cst) { + if (cst == 0.0) return this; + return new FloatAffine(this, 1.0, -cst); +} + +FloatLinearExpr* FloatLinearExpr::FloatSub(FloatLinearExpr* other) { + std::vector exprs; + exprs.push_back(this); + exprs.push_back(other); + return new FloatWeightedSum(exprs, {1, -1}, 0); +} + +FloatLinearExpr* FloatLinearExpr::FloatRSub(FloatLinearExpr* other) { + std::vector exprs; + exprs.push_back(this); + exprs.push_back(other); + return new FloatWeightedSum(exprs, {-1, 1}, 0); +} + +FloatLinearExpr* FloatLinearExpr::FloatRSubCst(double cst) { + return new FloatAffine(this, -1.0, cst); +} + +FloatLinearExpr* FloatLinearExpr::FloatMulCst(double cst) { + if (cst == 0.0) return Sum({}); + if (cst == 1.0) return this; + return new FloatAffine(this, cst, 0.0); +} + +FloatLinearExpr* FloatLinearExpr::FloatNeg() { + return new FloatAffine(this, -1.0, 0.0); +} + +void FloatExprVisitor::AddToProcess(FloatLinearExpr* expr, double coeff) { + to_process_.push_back(std::make_pair(expr, coeff)); +} +void FloatExprVisitor::AddConstant(double constant) { offset_ += constant; } +void FloatExprVisitor::AddVarCoeff(BaseIntVar* var, double coeff) { + canonical_terms_[var] += coeff; +} +double FloatExprVisitor::Process(FloatLinearExpr* expr, + std::vector* vars, + std::vector* coeffs) { + AddToProcess(expr, 1.0); + while (!to_process_.empty()) { + const auto [expr, coeff] = to_process_.back(); + to_process_.pop_back(); + expr->VisitAsFloat(this, coeff); + } + + vars->clear(); + coeffs->clear(); + for (const auto& [var, coeff] : canonical_terms_) { + if (coeff == 0) continue; + vars->push_back(var); + coeffs->push_back(coeff); + } + + return offset_; +} + +CanonicalFloatExpression::CanonicalFloatExpression(FloatLinearExpr* expr) { + FloatExprVisitor lin; + offset_ = lin.Process(expr, &vars_, &coeffs_); +} + +void FloatConstant::VisitAsFloat(FloatExprVisitor* lin, double c) { + lin->AddConstant(value_ * c); +} + +std::string FloatConstant::ToString() const { return absl::StrCat(value_); } + +std::string FloatConstant::DebugString() const { + return absl::StrCat("FloatConstant(", value_, ")"); +} + +FloatWeightedSum::FloatWeightedSum(const std::vector& exprs, + double offset) + : exprs_(exprs.begin(), exprs.end()), + coeffs_(exprs.size(), 1), + offset_(offset) {} + +FloatWeightedSum::FloatWeightedSum(const std::vector& exprs, + const std::vector& coeffs, + double offset) + : exprs_(exprs.begin(), exprs.end()), + coeffs_(coeffs.begin(), coeffs.end()), + offset_(offset) {} + +void FloatWeightedSum::VisitAsFloat(FloatExprVisitor* lin, double c) { + for (int i = 0; i < exprs_.size(); ++i) { + lin->AddToProcess(exprs_[i], coeffs_[i] * c); + } + lin->AddConstant(offset_ * c); +} + +std::string FloatWeightedSum::ToString() const { + if (exprs_.empty()) { + return absl::StrCat(offset_); + } + std::string s = "("; + bool first_printed = true; + for (int i = 0; i < exprs_.size(); ++i) { + if (coeffs_[i] == 0.0) continue; + if (first_printed) { + first_printed = false; + if (coeffs_[i] == 1.0) { + absl::StrAppend(&s, exprs_[i]->ToString()); + } else if (coeffs_[i] == -1.0) { + absl::StrAppend(&s, "-", exprs_[i]->ToString()); + } else { + absl::StrAppend(&s, coeffs_[i], " * ", exprs_[i]->ToString()); + } + } else { + if (coeffs_[i] == 1.0) { + absl::StrAppend(&s, " + ", exprs_[i]->ToString()); + } else if (coeffs_[i] == -1.0) { + absl::StrAppend(&s, " - ", exprs_[i]->ToString()); + } else if (coeffs_[i] > 0.0) { + absl::StrAppend(&s, " + ", coeffs_[i], " * ", exprs_[i]->ToString()); + } else { + absl::StrAppend(&s, " - ", -coeffs_[i], " * ", exprs_[i]->ToString()); + } + } + } + // If there are no terms, just print the offset. + if (first_printed) { + return absl::StrCat(offset_); + } + + // If there is an offset, print it. + if (offset_ != 0.0) { + if (offset_ > 0.0) { + absl::StrAppend(&s, " + ", offset_); + } else { + absl::StrAppend(&s, " - ", -offset_); + } + } + absl::StrAppend(&s, ")"); + return s; +} + +FloatAffine::FloatAffine(FloatLinearExpr* expr, double coeff, double offset) + : expr_(expr), coeff_(coeff), offset_(offset) {} + +void FloatAffine::VisitAsFloat(FloatExprVisitor* lin, double c) { + lin->AddToProcess(expr_, c * coeff_); + lin->AddConstant(offset_ * c); +} + +std::string FloatAffine::ToString() const { + std::string s = "("; + if (coeff_ == 1.0) { + absl::StrAppend(&s, expr_->ToString()); + } else if (coeff_ == -1.0) { + absl::StrAppend(&s, "-", expr_->ToString()); + } else { + absl::StrAppend(&s, coeff_, " * ", expr_->ToString()); + } + if (offset_ > 0.0) { + absl::StrAppend(&s, " + ", offset_); + } else if (offset_ < 0.0) { + absl::StrAppend(&s, " - ", -offset_); + } + absl::StrAppend(&s, ")"); + return s; +} + +std::string FloatAffine::DebugString() const { + return absl::StrCat("FloatAffine(expr=", expr_->DebugString(), + ", coeff=", coeff_, ", offset=", offset_, ")"); +} + +IntLinExpr* IntLinExpr::Sum(const std::vector& exprs) { + return Sum(exprs, 0); +} + +IntLinExpr* IntLinExpr::Sum(const std::vector& exprs, + int64_t cst) { + std::vector lin_exprs; + for (const IntExprOrValue& choice : exprs) { + if (choice.expr != nullptr) { + lin_exprs.push_back(choice.expr); + } else { + cst += choice.value; + } + } + if (lin_exprs.empty()) return new IntConstant(cst); + if (lin_exprs.size() == 1) return Affine(lin_exprs[0], 1, cst); + return new IntSum(lin_exprs, cst); +} + +IntLinExpr* IntLinExpr::WeightedSum(const std::vector& exprs, + const std::vector& coeffs) { + return WeightedSum(exprs, coeffs, 0); +} + +IntLinExpr* IntLinExpr::WeightedSum(const std::vector& exprs, + const std::vector& coeffs, + int64_t cst) { + std::vector lin_exprs; + std::vector lin_coeffs; + for (int i = 0; i < exprs.size(); ++i) { + if (exprs[i].expr != nullptr) { + lin_exprs.push_back(exprs[i].expr); + lin_coeffs.push_back(coeffs[i]); + } else { + cst += exprs[i].value * coeffs[i]; + } + } + if (lin_exprs.empty()) return new IntConstant(cst); + if (lin_exprs.size() == 1) { + return IntLinExpr::Affine(lin_exprs[0], lin_coeffs[0], cst); + } + return new IntWeightedSum(lin_exprs, lin_coeffs, cst); +} + +IntLinExpr* IntLinExpr::Term(IntLinExpr* expr, int64_t coeff) { + return Affine(expr, coeff, 0); +} + +IntLinExpr* IntLinExpr::Affine(IntLinExpr* expr, int64_t coeff, + int64_t offset) { + if (coeff == 1 && offset == 0) return expr; + if (coeff == 0) return new IntConstant(offset); + return new IntAffine(expr, coeff, offset); +} + +IntLinExpr* IntLinExpr::Constant(int64_t value) { + return new IntConstant(value); +} + +IntLinExpr* IntLinExpr::IntAddCst(int64_t cst) { + if (cst == 0) return this; + return new IntAffine(this, 1, cst); +} + +IntLinExpr* IntLinExpr::IntAdd(IntLinExpr* other) { + std::vector exprs; + exprs.push_back(this); + exprs.push_back(other); + return new IntSum(exprs, 0); +} + +IntLinExpr* IntLinExpr::IntSubCst(int64_t cst) { + if (cst == 0) return this; + return new IntAffine(this, 1, -cst); +} + +IntLinExpr* IntLinExpr::IntSub(IntLinExpr* other) { + std::vector exprs; + exprs.push_back(this); + exprs.push_back(other); + return new IntWeightedSum(exprs, {1, -1}, 0); +} + +IntLinExpr* IntLinExpr::IntRSubCst(int64_t cst) { + return new IntAffine(this, -1, cst); +} + +IntLinExpr* IntLinExpr::IntMulCst(int64_t cst) { + if (cst == 0) return new IntConstant(0); + if (cst == 1) return this; + return new IntAffine(this, cst, 0); +} + +IntLinExpr* IntLinExpr::IntNeg() { return new IntAffine(this, -1, 0); } + +BoundedLinearExpression* IntLinExpr::Eq(IntLinExpr* other) { + return new BoundedLinearExpression(this, other, Domain(0)); +} + +BoundedLinearExpression* IntLinExpr::EqCst(int64_t cst) { + return new BoundedLinearExpression(this, Domain(cst)); +} + +BoundedLinearExpression* IntLinExpr::Ne(IntLinExpr* other) { + return new BoundedLinearExpression(this, other, Domain(0).Complement()); +} + +BoundedLinearExpression* IntLinExpr::NeCst(int64_t cst) { + return new BoundedLinearExpression(this, Domain(cst).Complement()); +} + +BoundedLinearExpression* IntLinExpr::Le(IntLinExpr* other) { + return new BoundedLinearExpression( + this, other, Domain(std::numeric_limits::min(), 0)); +} + +BoundedLinearExpression* IntLinExpr::LeCst(int64_t cst) { + return new BoundedLinearExpression( + this, Domain(std::numeric_limits::min(), cst)); +} + +BoundedLinearExpression* IntLinExpr::Lt(IntLinExpr* other) { + return new BoundedLinearExpression( + this, other, Domain(std::numeric_limits::min(), -1)); +} + +BoundedLinearExpression* IntLinExpr::LtCst(int64_t cst) { + return new BoundedLinearExpression( + this, Domain(std::numeric_limits::min(), cst - 1)); +} + +BoundedLinearExpression* IntLinExpr::Ge(IntLinExpr* other) { + return new BoundedLinearExpression( + this, other, Domain(0, std::numeric_limits::max())); +} + +BoundedLinearExpression* IntLinExpr::GeCst(int64_t cst) { + return new BoundedLinearExpression( + this, Domain(cst, std::numeric_limits::max())); +} + +BoundedLinearExpression* IntLinExpr::Gt(IntLinExpr* other) { + return new BoundedLinearExpression( + this, other, Domain(1, std::numeric_limits::max())); +} + +BoundedLinearExpression* IntLinExpr::GtCst(int64_t cst) { + return new BoundedLinearExpression( + this, Domain(cst + 1, std::numeric_limits::max())); +} + +void IntExprVisitor::AddToProcess(IntLinExpr* expr, int64_t coeff) { + to_process_.push_back(std::make_pair(expr, coeff)); +} + +void IntExprVisitor::AddConstant(int64_t constant) { offset_ += constant; } + +void IntExprVisitor::AddVarCoeff(BaseIntVar* var, int64_t coeff) { + canonical_terms_[var] += coeff; +} + +void IntExprVisitor::ProcessAll() { + while (!to_process_.empty()) { + const auto [expr, coeff] = to_process_.back(); + to_process_.pop_back(); + expr->VisitAsInt(this, coeff); + } +} + +int64_t IntExprVisitor::Process(std::vector* vars, + std::vector* coeffs) { + ProcessAll(); + vars->clear(); + coeffs->clear(); + for (const auto& [var, coeff] : canonical_terms_) { + if (coeff == 0) continue; + vars->push_back(var); + coeffs->push_back(coeff); + } + + return offset_; +} + +int64_t IntExprVisitor::Evaluate(IntLinExpr* expr, + const CpSolverResponse& solution) { + AddToProcess(expr, 1); + ProcessAll(); + int64_t value = offset_; + for (const auto& [var, coeff] : canonical_terms_) { + if (coeff == 0) continue; + value += coeff * solution.solution(var->index()); + } + return value; +} + +bool BaseIntVarComparator::operator()(const BaseIntVar* lhs, + const BaseIntVar* rhs) const { + return lhs->index() < rhs->index(); +} + +BoundedLinearExpression::BoundedLinearExpression(IntLinExpr* expr, + const Domain& bounds) + : bounds_(bounds) { + IntExprVisitor lin; + lin.AddToProcess(expr, 1); + offset_ = lin.Process(&vars_, &coeffs_); +} + +BoundedLinearExpression::BoundedLinearExpression(IntLinExpr* pos, + IntLinExpr* neg, + const Domain& bounds) + : bounds_(bounds) { + IntExprVisitor lin; + lin.AddToProcess(pos, 1); + lin.AddToProcess(neg, -1); + offset_ = lin.Process(&vars_, &coeffs_); +} + +BoundedLinearExpression::BoundedLinearExpression(int64_t offset, + const Domain& bounds) + : bounds_(bounds), offset_(offset) {} + +const Domain& BoundedLinearExpression::bounds() const { return bounds_; } +const std::vector& BoundedLinearExpression::vars() const { + return vars_; +} +const std::vector& BoundedLinearExpression::coeffs() const { + return coeffs_; +} +int64_t BoundedLinearExpression::offset() const { return offset_; } + +std::string BoundedLinearExpression::ToString() const { + std::string s; + if (vars_.empty()) { + s = absl::StrCat(offset_); + } else if (vars_.size() == 1) { + const std::string var_name = vars_[0]->ToString(); + if (coeffs_[0] == 1) { + s = var_name; + } else if (coeffs_[0] == -1) { + s = absl::StrCat("-", var_name); + } else { + s = absl::StrCat(coeffs_[0], " * ", var_name); + } + if (offset_ > 0) { + absl::StrAppend(&s, " + ", offset_); + } else if (offset_ < 0) { + absl::StrAppend(&s, " - ", -offset_); + } + } else { + s = "("; + for (int i = 0; i < vars_.size(); ++i) { + const std::string var_name = vars_[i]->ToString(); + if (i == 0) { + if (coeffs_[i] == 1) { + absl::StrAppend(&s, var_name); + } else if (coeffs_[i] == -1) { + absl::StrAppend(&s, "-", var_name); + } else { + absl::StrAppend(&s, coeffs_[i], " * ", var_name); + } + } else { + if (coeffs_[i] == 1) { + absl::StrAppend(&s, " + ", var_name); + } else if (coeffs_[i] == -1) { + absl::StrAppend(&s, " - ", var_name); + } else if (coeffs_[i] > 1) { + absl::StrAppend(&s, " + ", coeffs_[i], " * ", var_name); + } else { + absl::StrAppend(&s, " - ", -coeffs_[i], " * ", var_name); + } + } + } + if (offset_ > 0) { + absl::StrAppend(&s, " + ", offset_); + } else if (offset_ < 0) { + absl::StrAppend(&s, " - ", -offset_); + } + absl::StrAppend(&s, ")"); + } + if (bounds_.IsFixed()) { + absl::StrAppend(&s, " == ", bounds_.Min()); + } else if (bounds_.NumIntervals() == 1) { + if (bounds_.Min() == std::numeric_limits::min()) { + if (bounds_.Max() == std::numeric_limits::max()) { + return absl::StrCat("True (unbounded expr ", s, ")"); + } else { + absl::StrAppend(&s, " <= ", bounds_.Max()); + } + } else if (bounds_.Max() == std::numeric_limits::max()) { + absl::StrAppend(&s, " >= ", bounds_.Min()); + } else { + return absl::StrCat(bounds_.Min(), " <= ", s, " <= ", bounds_.Max()); + } + } else if (bounds_.Complement().IsFixed()) { + absl::StrAppend(&s, " != ", bounds_.Complement().Min()); + } else { + absl::StrAppend(&s, " in ", bounds_.ToString()); + } + return s; +} + +std::string BoundedLinearExpression::DebugString() const { + return absl::StrCat("BoundedLinearExpression(vars=[", + absl::StrJoin(vars_, ", ", + [](std::string* out, BaseIntVar* var) { + absl::StrAppend(out, var->DebugString()); + }), + "], coeffs=[", absl::StrJoin(coeffs_, ", "), + "], offset=", offset_, ", bounds=", bounds_.ToString(), + ")"); +} + +bool BoundedLinearExpression::CastToBool(bool* result) const { + const bool is_zero = bounds_.IsFixed() && bounds_.FixedValue() == 0; + const Domain complement = bounds_.Complement(); + const bool is_all_but_zero = + complement.IsFixed() && complement.FixedValue() == 0; + if (is_zero || is_all_but_zero) { + if (vars_.empty()) { + *result = is_zero; + return true; + } else if (vars_.size() == 2 && coeffs_[0] + coeffs_[1] == 0 && + std::abs(coeffs_[0]) == 1) { + *result = is_all_but_zero; + return true; + } + } + return false; +} + +} // namespace python +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/python/linear_expr.h b/ortools/sat/python/linear_expr.h new file mode 100644 index 00000000000..e15919bc654 --- /dev/null +++ b/ortools/sat/python/linear_expr.h @@ -0,0 +1,581 @@ +// Copyright 2010-2024 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef OR_TOOLS_SAT_PYTHON_LINEAR_EXPR_H_ +#define OR_TOOLS_SAT_PYTHON_LINEAR_EXPR_H_ + +#include +#include +#include +#include + +#include "absl/container/btree_map.h" +#include "absl/container/fixed_array.h" +#include "absl/log/check.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "ortools/sat/cp_model.pb.h" +#include "ortools/util/sorted_interval_list.h" + +namespace operations_research { +namespace sat { +namespace python { + +class BoundedLinearExpression; +class CanonicalFloatExpression; +class FloatExprVisitor; +class FloatLinearExpr; +class IntExprVisitor; +class IntLinExpr; +class BaseIntVar; +class NotBooleanVariable; + +// A class to hold an floating point linear expression or a double constant. +struct FloatExprOrValue { + explicit FloatExprOrValue(FloatLinearExpr* e) : expr(e) {} + explicit FloatExprOrValue(double v) : value(v) {} + + FloatLinearExpr* expr = nullptr; + double value = 0; +}; + +// A linear expression that can be either integer or floating point. +class FloatLinearExpr { + public: + virtual ~FloatLinearExpr() = default; + virtual void VisitAsFloat(FloatExprVisitor* /*lin*/, double /*c*/) {} + virtual bool is_integer() const { return false; } + virtual std::string ToString() const { return "FloatLinearExpr"; } + virtual std::string DebugString() const { return ToString(); } + + static FloatLinearExpr* Sum(const std::vector& exprs); + static FloatLinearExpr* Sum(const std::vector& exprs, + double cst); + static FloatLinearExpr* WeightedSum( + const std::vector& exprs, + const std::vector& coeffs); + static FloatLinearExpr* WeightedSum( + const std::vector& exprs, + const std::vector& coeffs, double cst); + static FloatLinearExpr* Term(FloatLinearExpr* expr, double coeff); + static FloatLinearExpr* Affine(FloatLinearExpr* expr, double coeff, + double offset); + static FloatLinearExpr* Constant(double value); + + FloatLinearExpr* FloatAddCst(double cst); + FloatLinearExpr* FloatAdd(FloatLinearExpr* other); + FloatLinearExpr* FloatSubCst(double cst); + FloatLinearExpr* FloatSub(FloatLinearExpr* other); + FloatLinearExpr* FloatRSub(FloatLinearExpr* other); + FloatLinearExpr* FloatRSubCst(double cst); + FloatLinearExpr* FloatMulCst(double cst); + FloatLinearExpr* FloatNeg(); +}; + +// Compare the indices of variables. +struct BaseIntVarComparator { + bool operator()(const BaseIntVar* lhs, const BaseIntVar* rhs) const; +}; + +// A visitor class to process a floating point linear expression. +class FloatExprVisitor { + public: + void AddToProcess(FloatLinearExpr* expr, double coeff); + void AddConstant(double constant); + void AddVarCoeff(BaseIntVar* var, double coeff); + double Process(FloatLinearExpr* expr, std::vector* vars, + std::vector* coeffs); + + private: + std::vector> to_process_; + absl::btree_map canonical_terms_; + double offset_ = 0; +}; + +// A class to build a canonical floating point linear expression. +class CanonicalFloatExpression { + public: + explicit CanonicalFloatExpression(FloatLinearExpr* expr); + const std::vector& vars() const { return vars_; } + const std::vector& coeffs() const { return coeffs_; } + double offset() const { return offset_; } + + private: + double offset_; + std::vector vars_; + std::vector coeffs_; +}; + +// A class to hold a constant. +class FloatConstant : public FloatLinearExpr { + public: + explicit FloatConstant(double value) : value_(value) {} + ~FloatConstant() override = default; + + void VisitAsFloat(FloatExprVisitor* lin, double c) override; + std::string ToString() const override; + std::string DebugString() const override; + + private: + double value_; +}; + +// A class to hold a weighted sum of floating point linear expressions. +class FloatWeightedSum : public FloatLinearExpr { + public: + FloatWeightedSum(const std::vector& exprs, double offset); + FloatWeightedSum(const std::vector& exprs, + const std::vector& coeffs, double offset); + ~FloatWeightedSum() override = default; + + void VisitAsFloat(FloatExprVisitor* lin, double c) override; + std::string ToString() const override; + + private: + const absl::FixedArray exprs_; + const absl::FixedArray coeffs_; + double offset_; +}; + +// A class to hold float_exr * a = b. +class FloatAffine : public FloatLinearExpr { + public: + FloatAffine(FloatLinearExpr* expr, double coeff, double offset); + ~FloatAffine() override = default; + + void VisitAsFloat(FloatExprVisitor* lin, double c) override; + std::string ToString() const override; + std::string DebugString() const override; + + FloatLinearExpr* expression() const { return expr_; } + double coefficient() const { return coeff_; } + double offset() const { return offset_; } + + private: + FloatLinearExpr* expr_; + double coeff_; + double offset_; +}; + +// A struct to hold an integer linear expression or an integer constant. +struct IntExprOrValue { + explicit IntExprOrValue(IntLinExpr* e) : expr(e) {} + explicit IntExprOrValue(int64_t v) : value(v) {} + + IntLinExpr* expr = nullptr; + int64_t value = 0; +}; + +class IntLinExpr : public FloatLinearExpr { + public: + ~IntLinExpr() override = default; + virtual void VisitAsInt(IntExprVisitor* /*lin*/, int64_t /*c*/) {} + bool is_integer() const override { return true; } + std::string ToString() const override { return "IntLinExpr"; } + + static IntLinExpr* Sum(const std::vector& exprs); + static IntLinExpr* Sum(const std::vector& exprs, int64_t cst); + static IntLinExpr* Sum(const std::vector& exprs, int64_t cst); + static IntLinExpr* Sum(const std::vector& exprs); + static IntLinExpr* WeightedSum(const std::vector& exprs, + const std::vector& coeffs); + static IntLinExpr* WeightedSum(const std::vector& exprs, + const std::vector& coeffs, + int64_t cst); + static IntLinExpr* Term(IntLinExpr* expr, int64_t coeff); + static IntLinExpr* Affine(IntLinExpr* expr, int64_t coeff, int64_t offset); + static IntLinExpr* Constant(int64_t value); + + IntLinExpr* IntAddCst(int64_t cst); + IntLinExpr* IntAdd(IntLinExpr* other); + IntLinExpr* IntSubCst(int64_t cst); + IntLinExpr* IntSub(IntLinExpr* other); + IntLinExpr* IntRSubCst(int64_t cst); + IntLinExpr* IntMulCst(int64_t cst); + IntLinExpr* IntNeg(); + + BoundedLinearExpression* EqCst(int64_t cst); + BoundedLinearExpression* NeCst(int64_t cst); + BoundedLinearExpression* GeCst(int64_t cst); + BoundedLinearExpression* LeCst(int64_t cst); + BoundedLinearExpression* LtCst(int64_t cst); + BoundedLinearExpression* GtCst(int64_t cst); + BoundedLinearExpression* Eq(IntLinExpr* other); + BoundedLinearExpression* Ne(IntLinExpr* other); + BoundedLinearExpression* Ge(IntLinExpr* other); + BoundedLinearExpression* Le(IntLinExpr* other); + BoundedLinearExpression* Lt(IntLinExpr* other); + BoundedLinearExpression* Gt(IntLinExpr* other); +}; + +// A visitor class to process an integer linear expression. +class IntExprVisitor { + public: + void AddToProcess(IntLinExpr* expr, int64_t coeff); + void AddConstant(int64_t constant); + void AddVarCoeff(BaseIntVar* var, int64_t coeff); + void ProcessAll(); + int64_t Process(std::vector* vars, std::vector* coeffs); + int64_t Evaluate(IntLinExpr* expr, const CpSolverResponse& solution); + + private: + std::vector> to_process_; + absl::btree_map canonical_terms_; + int64_t offset_ = 0; +}; + +// A class to hold a linear expression with bounds. +class BoundedLinearExpression { + public: + BoundedLinearExpression(IntLinExpr* expr, const Domain& bounds); + + BoundedLinearExpression(IntLinExpr* pos, IntLinExpr* neg, + const Domain& bounds); + BoundedLinearExpression(int64_t offset, const Domain& bounds); + ~BoundedLinearExpression() = default; + + const Domain& bounds() const; + const std::vector& vars() const; + const std::vector& coeffs() const; + int64_t offset() const; + std::string ToString() const; + std::string DebugString() const; + bool CastToBool(bool* result) const; + + private: + Domain bounds_; + int64_t offset_; + std::vector vars_; + std::vector coeffs_; +}; + +// A class to hold a constant. +class IntConstant : public IntLinExpr { + public: + explicit IntConstant(int64_t value) : value_(value) {} + ~IntConstant() override = default; + void VisitAsInt(IntExprVisitor* lin, int64_t c) override { + lin->AddConstant(value_ * c); + } + + void VisitAsFloat(FloatExprVisitor* lin, double c) override { + lin->AddConstant(value_ * c); + } + + std::string ToString() const override { return absl::StrCat(value_); } + + std::string DebugString() const override { + return absl::StrCat("IntConstant(", value_, ")"); + } + + private: + int64_t value_; +}; + +// A class to hold a sum of integer linear expressions. +class IntSum : public IntLinExpr { + public: + IntSum(const std::vector& exprs, int64_t offset) + : exprs_(exprs.begin(), exprs.end()), offset_(offset) {} + ~IntSum() override = default; + + void VisitAsInt(IntExprVisitor* lin, int64_t c) override { + for (int i = 0; i < exprs_.size(); ++i) { + lin->AddToProcess(exprs_[i], c); + } + lin->AddConstant(offset_ * c); + } + + void VisitAsFloat(FloatExprVisitor* lin, double c) override { + for (int i = 0; i < exprs_.size(); ++i) { + lin->AddToProcess(exprs_[i], c); + } + lin->AddConstant(offset_ * c); + } + + std::string ToString() const override { + if (exprs_.empty()) { + return absl::StrCat(offset_); + } + std::string s = "("; + for (int i = 0; i < exprs_.size(); ++i) { + if (i > 0) { + absl::StrAppend(&s, " + "); + } + absl::StrAppend(&s, exprs_[i]->ToString()); + } + if (offset_ != 0) { + if (offset_ > 0) { + absl::StrAppend(&s, " + ", offset_); + } else { + absl::StrAppend(&s, " - ", -offset_); + } + } + absl::StrAppend(&s, ")"); + return s; + } + + std::string DebugString() const override { + return absl::StrCat("IntSum(", + absl::StrJoin(exprs_, ", ", + [](std::string* out, IntLinExpr* expr) { + absl::StrAppend(out, + expr->DebugString()); + }), + ", ", offset_, ")"); + } + + private: + const absl::FixedArray exprs_; + int64_t offset_; +}; + +// A class to hold a weighted sum of integer linear expressions. +class IntWeightedSum : public IntLinExpr { + public: + IntWeightedSum(const std::vector& exprs, + const std::vector& coeffs, int64_t offset) + : exprs_(exprs.begin(), exprs.end()), + coeffs_(coeffs.begin(), coeffs.end()), + offset_(offset) {} + ~IntWeightedSum() override = default; + + void VisitAsInt(IntExprVisitor* lin, int64_t c) override { + for (int i = 0; i < exprs_.size(); ++i) { + lin->AddToProcess(exprs_[i], coeffs_[i] * c); + } + lin->AddConstant(offset_ * c); + } + + void VisitAsFloat(FloatExprVisitor* lin, double c) override { + for (int i = 0; i < exprs_.size(); ++i) { + lin->AddToProcess(exprs_[i], coeffs_[i] * c); + } + lin->AddConstant(offset_ * c); + } + + std::string ToString() const override { + if (exprs_.empty()) { + return absl::StrCat(offset_); + } + std::string s = "("; + bool first_printed = true; + for (int i = 0; i < exprs_.size(); ++i) { + if (coeffs_[i] == 0) continue; + if (first_printed) { + first_printed = false; + if (coeffs_[i] == 1) { + absl::StrAppend(&s, exprs_[i]->ToString()); + } else if (coeffs_[i] == -1) { + absl::StrAppend(&s, "-", exprs_[i]->ToString()); + } else { + absl::StrAppend(&s, coeffs_[i], " * ", exprs_[i]->ToString()); + } + } else { + if (coeffs_[i] == 1) { + absl::StrAppend(&s, " + ", exprs_[i]->ToString()); + } else if (coeffs_[i] == -1) { + absl::StrAppend(&s, " - ", exprs_[i]->ToString()); + } else if (coeffs_[i] > 1) { + absl::StrAppend(&s, " + ", coeffs_[i], " * ", exprs_[i]->ToString()); + } else { + absl::StrAppend(&s, " - ", -coeffs_[i], " * ", exprs_[i]->ToString()); + } + } + } + // If there are no terms, just print the offset. + if (first_printed) { + return absl::StrCat(offset_); + } + + // If there is an offset, print it. + if (offset_ != 0) { + if (offset_ > 0) { + absl::StrAppend(&s, " + ", offset_); + } else { + absl::StrAppend(&s, " - ", -offset_); + } + } + absl::StrAppend(&s, ")"); + return s; + } + + std::string DebugString() const override { + return absl::StrCat( + "IntWeightedSum([", + absl::StrJoin(exprs_, ", ", + [](std::string* out, IntLinExpr* expr) { + absl::StrAppend(out, expr->DebugString()); + }), + "], [", absl::StrJoin(coeffs_, ", "), "], ", offset_, ")"); + } + + private: + const absl::FixedArray exprs_; + const absl::FixedArray coeffs_; + int64_t offset_; +}; + +// A class to hold int_exr * a = b. +class IntAffine : public IntLinExpr { + public: + IntAffine(IntLinExpr* expr, int64_t coeff, int64_t offset) + : expr_(expr), coeff_(coeff), offset_(offset) {} + ~IntAffine() override = default; + + void VisitAsInt(IntExprVisitor* lin, int64_t c) override { + lin->AddToProcess(expr_, c * coeff_); + lin->AddConstant(offset_ * c); + } + + void VisitAsFloat(FloatExprVisitor* lin, double c) override { + lin->AddToProcess(expr_, c * coeff_); + lin->AddConstant(offset_ * c); + } + + std::string ToString() const override { + std::string s = "("; + if (coeff_ == 1) { + absl::StrAppend(&s, expr_->ToString()); + } else if (coeff_ == -1) { + absl::StrAppend(&s, "-", expr_->ToString()); + } else { + absl::StrAppend(&s, coeff_, " * ", expr_->ToString()); + } + if (offset_ > 0) { + absl::StrAppend(&s, " + ", offset_); + } else if (offset_ < 0) { + absl::StrAppend(&s, " - ", -offset_); + } + absl::StrAppend(&s, ")"); + return s; + } + + std::string DebugString() const override { + return absl::StrCat("IntAffine(expr=", expr_->DebugString(), + ", coeff=", coeff_, ", offset=", offset_, ")"); + } + + IntLinExpr* expression() const { return expr_; } + int64_t coefficient() const { return coeff_; } + int64_t offset() const { return offset_; } + + private: + IntLinExpr* expr_; + int64_t coeff_; + int64_t offset_; +}; + +// A Boolean literal (a Boolean variable or its negation). +class Literal { + public: + virtual ~Literal() = default; + virtual int index() const = 0; + virtual Literal* negated() = 0; +}; + +// A class to hold a variable index. +class BaseIntVar : public IntLinExpr, public Literal { + public: + explicit BaseIntVar(int index) + : index_(index), is_boolean_(false), negated_(nullptr) { + DCHECK_GE(index, 0); + } + BaseIntVar(int index, bool is_boolean) + : index_(index), is_boolean_(is_boolean), negated_(nullptr) { + DCHECK_GE(index, 0); + } + + ~BaseIntVar() override = default; + + int index() const override { return index_; } + + void VisitAsInt(IntExprVisitor* lin, int64_t c) override { + lin->AddVarCoeff(this, c); + } + + void VisitAsFloat(FloatExprVisitor* lin, double c) override { + lin->AddVarCoeff(this, c); + } + + std::string ToString() const override { + if (is_boolean_) { + return absl::StrCat("BooleanBaseIntVar(", index_, ")"); + } else { + return absl::StrCat("BaseIntVar(", index_, ")"); + } + } + + std::string DebugString() const override { + return absl::StrCat("BaseIntVar(index=", index_, + ", is_boolean=", is_boolean_, ")"); + } + + Literal* negated() override; + + bool is_boolean() const { return is_boolean_; } + + bool operator<(const BaseIntVar& other) const { + return index_ < other.index_; + } + + protected: + const int index_; + bool is_boolean_; + Literal* negated_; +}; + +// A class to hold a negated variable index. +class NotBooleanVariable : public IntLinExpr, public Literal { + public: + explicit NotBooleanVariable(BaseIntVar* var) : var_(var) {} + ~NotBooleanVariable() override = default; + + int index() const override { return -var_->index() - 1; } + + void VisitAsInt(IntExprVisitor* lin, int64_t c) override { + lin->AddVarCoeff(var_, -c); + lin->AddConstant(c); + } + + void VisitAsFloat(FloatExprVisitor* lin, double c) override { + lin->AddVarCoeff(var_, -c); + lin->AddConstant(c); + } + + std::string ToString() const override { + return absl::StrCat("not(", var_->ToString(), ")"); + } + + Literal* negated() override { return var_; } + + std::string DebugString() const override { + return absl::StrCat("NotBooleanVariable(index=", var_->index(), ")"); + } + + private: + BaseIntVar* var_; +}; + +inline Literal* BaseIntVar::negated() { + if (negated_ == nullptr) { + negated_ = new NotBooleanVariable(this); + } + return negated_; +} + +} // namespace python +} // namespace sat +} // namespace operations_research + +#endif // OR_TOOLS_SAT_PYTHON_LINEAR_EXPR_H_ diff --git a/ortools/sat/python/swig_helper.cc b/ortools/sat/python/swig_helper.cc index 5f225cb9543..d0f784b6254 100644 --- a/ortools/sat/python/swig_helper.cc +++ b/ortools/sat/python/swig_helper.cc @@ -12,21 +12,18 @@ // limitations under the License. // This file wraps the swig_helper.h classes in python using pybind11. -// Because pybind11_protobuf does not support building with CMake for OR-Tools, -// the API has been transformed to use serialized protos from Python to C++ and -// from C++ to python: -// from Python to C++: use proto.SerializeToString(). This creates a python -// string that is passed to C++ and parsed back to proto. -// from C++ to Python, we cast the result of proto.SerializeAsString() to -// pybind11::bytes. This is passed back to python, which will reconstruct -// the proto using PythonProto.FromString(byte[]). - #include "ortools/sat/swig_helper.h" +#include + +#include +#include #include +#include -#include "absl/strings/string_view.h" +#include "absl/strings/str_cat.h" #include "ortools/sat/cp_model.pb.h" +#include "ortools/sat/python/linear_expr.h" #include "ortools/util/sorted_interval_list.h" #include "pybind11/cast.h" #include "pybind11/functional.h" @@ -35,22 +32,20 @@ #include "pybind11/stl.h" #include "pybind11_protobuf/native_proto_caster.h" -using ::operations_research::Domain; -using ::operations_research::sat::CpModelProto; -using ::operations_research::sat::CpSatHelper; -using ::operations_research::sat::CpSolverResponse; -using ::operations_research::sat::IntegerVariableProto; -using ::operations_research::sat::SatParameters; -using ::operations_research::sat::SolutionCallback; -using ::operations_research::sat::SolveWrapper; -using ::pybind11::arg; +namespace py = pybind11; + +namespace operations_research { +namespace sat { +namespace python { + +using ::py::arg; class PySolutionCallback : public SolutionCallback { public: using SolutionCallback::SolutionCallback; /* Inherit constructors */ void OnSolutionCallback() const override { - ::pybind11::gil_scoped_acquire acquire; + ::py::gil_scoped_acquire acquire; PYBIND11_OVERRIDE_PURE( void, /* Return type */ SolutionCallback, /* Parent class */ @@ -61,12 +56,141 @@ class PySolutionCallback : public SolutionCallback { } }; +// A trampoline class to override the __str__ and __repr__ methods. +class PyBaseIntVar : public BaseIntVar { + public: + using BaseIntVar::BaseIntVar; /* Inherit constructors */ + + std::string ToString() const override { + PYBIND11_OVERRIDE_NAME(std::string, // Return type (ret_type) + BaseIntVar, // Parent class (cname) + "__str__", // Name of method in Python (name) + ToString, // Name of function in C++ (fn) + ); + } + + std::string DebugString() const override { + PYBIND11_OVERRIDE_NAME(std::string, // Return type (ret_type) + BaseIntVar, // Parent class (cname) + "__repr__", // Name of method in Python (name) + DebugString, // Name of function in C++ (fn) + ); + } +}; + +// A class to wrap a C++ CpSolverResponse in a Python object, avoid the proto +// conversion back to python. +class ResponseWrapper { + public: + explicit ResponseWrapper(const CpSolverResponse& response) + : response_(response) {} + + double BestObjectiveBound() const { return response_.best_objective_bound(); } + + bool BooleanValue(Literal* lit) const { + const int index = lit->index(); + if (index >= 0) { + return response_.solution(index) != 0; + } else { + return response_.solution(-index - 1) == 0; + } + } + + bool FixedBooleanValue(bool lit) const { return lit; } + + double DeterministicTime() const { return response_.deterministic_time(); } + + int64_t NumBinaryPropagations() const { + return response_.num_binary_propagations(); + } + + int64_t NumBooleans() const { return response_.num_booleans(); } + + int64_t NumBranches() const { return response_.num_branches(); } + + int64_t NumConflicts() const { return response_.num_conflicts(); } + + int64_t NumIntegerPropagations() const { + return response_.num_integer_propagations(); + } + + int64_t NumRestarts() const { return response_.num_restarts(); } + + double ObjectiveValue() const { return response_.objective_value(); } + + const CpSolverResponse& Response() const { return response_; } + + std::string ResponseStats() const { + return CpSatHelper::SolverResponseStats(response_); + } + + std::string SolutionInfo() const { return response_.solution_info(); } + + std::vector SufficientAssumptionsForInfeasibility() const { + return std::vector( + response_.sufficient_assumptions_for_infeasibility().begin(), + response_.sufficient_assumptions_for_infeasibility().end()); + } + + CpSolverStatus Status() const { return response_.status(); } + + double UserTime() const { return response_.user_time(); } + + int64_t Value(IntLinExpr* expr) const { + IntExprVisitor visitor; + return visitor.Evaluate(expr, response_); + } + + int64_t FixedValue(int64_t value) const { return value; } + + double WallTime() const { return response_.wall_time(); } + + private: + const CpSolverResponse response_; +}; + +void throw_error(PyObject* py_exception, const std::string& message) { + PyErr_SetString(py_exception, message.c_str()); + throw py::error_already_set(); +} + +const char* kIntLinExprClassDoc = R"doc( + Holds an integer linear expression. + + A linear expression is built from integer constants and variables. + For example, `x + 2 * (y - z + 1)`. + + Linear expressions are used in CP-SAT models in constraints and in the + objective: + + * You can define linear constraints as in: + + ``` + model.add(x + 2 * y <= 5) + model.add(sum(array_of_vars) == 5) + ``` + + * In CP-SAT, the objective is a linear expression: + + ``` + model.minimize(x + 2 * y + z) + ``` + + * For large arrays, using the LinearExpr class is faster that using the python + `sum()` function. You can create constraints and the objective from lists of + linear expressions or coefficients as follows: + + ``` + model.minimize(cp_model.LinearExpr.sum(expressions)) + model.add(cp_model.LinearExpr.weighted_sum(expressions, coefficients) >= 0) + ```)doc"; + PYBIND11_MODULE(swig_helper, m) { pybind11_protobuf::ImportNativeProtoCasters(); - pybind11::module::import("ortools.util.python.sorted_interval_list"); + py::module::import("ortools.util.python.sorted_interval_list"); - pybind11::class_(m, "SolutionCallback") - .def(pybind11::init<>()) + py::class_(m, "SolutionCallback") + .def(py::init<>()) .def("OnSolutionCallback", &SolutionCallback::OnSolutionCallback) .def("BestObjectiveBound", &SolutionCallback::BestObjectiveBound) .def("DeterministicTime", &SolutionCallback::DeterministicTime) @@ -84,10 +208,58 @@ PYBIND11_MODULE(swig_helper, m) { arg("index")) .def("StopSearch", &SolutionCallback::StopSearch) .def("UserTime", &SolutionCallback::UserTime) - .def("WallTime", &SolutionCallback::WallTime); + .def("WallTime", &SolutionCallback::WallTime) + .def( + "Value", + [](const SolutionCallback& callback, IntLinExpr* expr) { + IntExprVisitor visitor; + return visitor.Evaluate(expr, callback.Response()); + }, + "Returns the value of a linear expression after solve.") + .def( + "Value", [](const SolutionCallback&, int64_t value) { return value; }, + "Returns the value of a linear expression after solve.") + .def( + "BooleanValue", + [](const SolutionCallback& callback, Literal* lit) { + const int index = lit->index(); + if (index >= 0) { + return callback.Response().solution(index) != 0; + } else { + return callback.Response().solution(-index - 1) == 0; + } + }, + "Returns the boolean value of a literal after solve.") + .def( + "BooleanValue", [](const SolutionCallback&, bool lit) { return lit; }, + "Returns the boolean value of a literal after solve."); - pybind11::class_(m, "SolveWrapper") - .def(pybind11::init<>()) + py::class_(m, "ResponseWrapper") + .def(py::init()) + .def("best_objective_bound", &ResponseWrapper::BestObjectiveBound) + .def("boolean_value", &ResponseWrapper::BooleanValue, arg("lit")) + .def("boolean_value", &ResponseWrapper::FixedBooleanValue, arg("lit")) + .def("deterministic_time", &ResponseWrapper::DeterministicTime) + .def("num_binary_propagations", &ResponseWrapper::NumBinaryPropagations) + .def("num_booleans", &ResponseWrapper::NumBooleans) + .def("num_branches", &ResponseWrapper::NumBranches) + .def("num_conflicts", &ResponseWrapper::NumConflicts) + .def("num_integer_propagations", &ResponseWrapper::NumIntegerPropagations) + .def("num_restarts", &ResponseWrapper::NumRestarts) + .def("objective_value", &ResponseWrapper::ObjectiveValue) + .def("response", &ResponseWrapper::Response) + .def("response_stats", &ResponseWrapper::ResponseStats) + .def("solution_info", &ResponseWrapper::SolutionInfo) + .def("status", &ResponseWrapper::Status) + .def("sufficient_assumptions_for_infeasibility", + &ResponseWrapper::SufficientAssumptionsForInfeasibility) + .def("user_time", &ResponseWrapper::UserTime) + .def("value", &ResponseWrapper::Value, arg("expr")) + .def("value", &ResponseWrapper::FixedValue, arg("value")) + .def("wall_time", &ResponseWrapper::WallTime); + + py::class_(m, "SolveWrapper") + .def(py::init<>()) .def("add_log_callback", &SolveWrapper::AddLogCallback, arg("log_callback")) .def("add_solution_callback", &SolveWrapper::AddSolutionCallback, @@ -98,13 +270,13 @@ PYBIND11_MODULE(swig_helper, m) { .def("set_parameters", &SolveWrapper::SetParameters, arg("parameters")) .def("solve", [](SolveWrapper* solve_wrapper, - const CpModelProto& model_proto) -> CpSolverResponse { - ::pybind11::gil_scoped_release release; - return solve_wrapper->Solve(model_proto); + const CpModelProto& model_proto) -> ResponseWrapper { + ::py::gil_scoped_release release; + return ResponseWrapper(solve_wrapper->Solve(model_proto)); }) .def("stop_search", &SolveWrapper::StopSearch); - pybind11::class_(m, "CpSatHelper") + py::class_(m, "CpSatHelper") .def_static("model_stats", &CpSatHelper::ModelStats, arg("model_proto")) .def_static("solver_response_stats", &CpSatHelper::SolverResponseStats, arg("response")) @@ -114,4 +286,494 @@ PYBIND11_MODULE(swig_helper, m) { arg("variable_proto")) .def_static("write_model_to_file", &CpSatHelper::WriteModelToFile, arg("model_proto"), arg("filename")); -} + + py::class_(m, "FloatExprOrValue") + .def(py::init()) + .def(py::init()) + .def_readonly("value", &FloatExprOrValue::value) + .def_readonly("expr", &FloatExprOrValue::expr); + + py::implicitly_convertible(); + py::implicitly_convertible(); + + py::class_(m, "FloatLinearExpr") + .def(py::init<>()) + .def_static("sum", + py::overload_cast&>( + &FloatLinearExpr::Sum), + py::return_value_policy::automatic, py::keep_alive<0, 1>()) + .def_static( + "sum", + py::overload_cast&, double>( + &FloatLinearExpr::Sum), + py::return_value_policy::automatic, py::keep_alive<0, 1>()) + .def_static("weighted_sum", + py::overload_cast&, + const std::vector&>( + &FloatLinearExpr::WeightedSum), + py::return_value_policy::automatic, py::keep_alive<0, 1>()) + .def_static("weighted_sum", + py::overload_cast&, + const std::vector&, double>( + &FloatLinearExpr::WeightedSum), + py::return_value_policy::automatic, py::keep_alive<0, 1>()) + .def_static("WeightedSum", + py::overload_cast&, + const std::vector&>( + &FloatLinearExpr::WeightedSum), + py::return_value_policy::automatic, py::keep_alive<0, 1>()) + .def_static("term", &FloatLinearExpr::Term, arg("expr"), arg("coeff"), + "Returns expr * coeff.", py::return_value_policy::automatic, + py::keep_alive<0, 1>()) + .def_static("affine", &FloatLinearExpr::Affine, arg("expr"), arg("coeff"), + arg("offset"), "Returns expr * coeff + offset.", + py::return_value_policy::automatic, py::keep_alive<0, 1>()) + .def_static("constant", FloatLinearExpr::Constant, arg("value"), + "Returns a constant linear expression.", + py::return_value_policy::automatic) + .def("__str__", &FloatLinearExpr::ToString) + .def("__repr__", &FloatLinearExpr::DebugString) + .def("is_integer", &FloatLinearExpr::is_integer) + .def("__add__", &FloatLinearExpr::FloatAddCst, + py::return_value_policy::automatic, py::keep_alive<0, 1>()) + .def("__add__", &FloatLinearExpr::FloatAdd, + py::return_value_policy::automatic, py::keep_alive<0, 1>(), + py::keep_alive<0, 2>()) + .def("__radd__", &FloatLinearExpr::FloatAddCst, + py::return_value_policy::automatic, py::keep_alive<0, 1>()) + .def("__radd__", &FloatLinearExpr::FloatAdd, + py::return_value_policy::automatic, py::keep_alive<0, 1>(), + py::keep_alive<0, 2>()) + .def("__sub__", &FloatLinearExpr::FloatSub, + py::return_value_policy::automatic, py::keep_alive<0, 1>(), + py::keep_alive<0, 2>()) + .def("__sub__", &FloatLinearExpr::FloatSubCst, + py::return_value_policy::automatic, py::keep_alive<0, 1>()) + .def("__rsub__", &FloatLinearExpr::FloatRSub, + py::return_value_policy::automatic, py::keep_alive<0, 1>(), + py::keep_alive<0, 2>()) + .def("__rsub__", &FloatLinearExpr::FloatRSubCst, + py::return_value_policy::automatic, py::keep_alive<0, 1>()) + .def("__mul__", &FloatLinearExpr::FloatMulCst, + py::return_value_policy::automatic, py::keep_alive<0, 1>()) + .def("__rmul__", &FloatLinearExpr::FloatMulCst, + py::return_value_policy::automatic, py::keep_alive<0, 1>()) + .def("__neg__", &FloatLinearExpr::FloatNeg, + py::return_value_policy::automatic, py::keep_alive<0, 1>()); + + py::class_(m, "FloatAffine") + .def(py::init()) + .def_property_readonly("expression", &FloatAffine::expression) + .def_property_readonly("coefficient", &FloatAffine::coefficient) + .def_property_readonly("offset", &FloatAffine::offset); + + py::class_(m, "CanonicalFloatExpression") + .def(py::init()) + .def_property_readonly("vars", &CanonicalFloatExpression::vars) + .def_property_readonly("coeffs", &CanonicalFloatExpression::coeffs) + .def_property_readonly("offset", &CanonicalFloatExpression::offset); + + py::class_(m, "IntExprOrValue") + .def(py::init()) + .def(py::init()) + .def_readonly("value", &IntExprOrValue::value) + .def_readonly("expr", &IntExprOrValue::expr); + + py::implicitly_convertible(); + py::implicitly_convertible(); + + py::class_(m, "LinearExpr", kIntLinExprClassDoc) + .def(py::init<>()) + .def_static("sum", + py::overload_cast&>( + &IntLinExpr::Sum), + "Returns sum(exprs)", py::return_value_policy::automatic, + py::keep_alive<0, 1>()) + .def_static( + "sum", + py::overload_cast&, int64_t>( + &IntLinExpr::Sum), + "Returns sum(exprs) + cst", py::return_value_policy::automatic, + py::keep_alive<0, 1>()) + .def_static("weighted_sum", + py::overload_cast&, + const std::vector&>( + &IntLinExpr::WeightedSum), + "Returns sum(exprs[i] * coeffs[i]", + py::return_value_policy::automatic, py::keep_alive<0, 1>()) + .def_static("weighted_sum", + py::overload_cast&, + const std::vector&, int64_t>( + &IntLinExpr::WeightedSum), + "Returns sum(exprs[i] * coeffs[i]) + cst", + py::return_value_policy::automatic, py::keep_alive<0, 1>()) + .def_static("weighted_sum", + py::overload_cast&, + const std::vector&>( + &FloatLinearExpr::WeightedSum), + py::return_value_policy::automatic, py::keep_alive<0, 1>()) + .def_static("term", &IntLinExpr::Term, arg("expr"), arg("coeff"), + "Returns expr * coeff.", py::return_value_policy::automatic, + py::keep_alive<0, 1>()) + .def_static("affine", &IntLinExpr::Affine, arg("expr"), arg("coeff"), + arg("offset"), "Returns expr * coeff.", + py::return_value_policy::automatic, py::keep_alive<0, 1>()) + .def_static("constant", IntLinExpr::Constant, arg("value"), + "Returns a constant linear expression.", + py::return_value_policy::automatic) + .def("is_integer", &IntLinExpr::is_integer) + .def("__add__", &IntLinExpr::IntAddCst, + py::return_value_policy::automatic, py::keep_alive<0, 1>()) + .def("__add__", &FloatLinearExpr::FloatAddCst, + py::return_value_policy::automatic, py::keep_alive<0, 1>()) + .def("__add__", &IntLinExpr::IntAdd, py::return_value_policy::automatic, + py::keep_alive<0, 1>(), py::keep_alive<0, 2>()) + .def("__add__", &FloatLinearExpr::FloatAdd, + py::return_value_policy::automatic, py::keep_alive<0, 1>(), + py::keep_alive<0, 2>()) + .def("__radd__", &IntLinExpr::IntAddCst, + py::return_value_policy::automatic, py::keep_alive<0, 1>()) + .def("__radd__", &FloatLinearExpr::FloatAddCst, + py::return_value_policy::automatic, py::keep_alive<0, 1>()) + .def("__radd__", &FloatLinearExpr::FloatAdd, + py::return_value_policy::automatic, py::keep_alive<0, 1>(), + py::keep_alive<0, 2>()) + .def("__sub__", &IntLinExpr::IntSubCst, + py::return_value_policy::automatic, py::keep_alive<0, 1>()) + .def("__sub__", &FloatLinearExpr::FloatSubCst, + py::return_value_policy::automatic, py::keep_alive<0, 1>()) + .def("__sub__", &FloatLinearExpr::FloatSubCst, + py::return_value_policy::automatic, py::keep_alive<0, 1>()) + .def("__sub__", &IntLinExpr::IntSub, py::return_value_policy::automatic, + py::keep_alive<0, 1>(), py::keep_alive<0, 2>()) + .def("__sub__", &FloatLinearExpr::FloatSub, + py::return_value_policy::automatic, py::keep_alive<0, 1>(), + py::keep_alive<0, 2>()) + .def("__rsub__", &IntLinExpr::IntRSubCst, + py::return_value_policy::automatic, py::keep_alive<0, 1>()) + .def("__rsub__", &FloatLinearExpr::FloatRSubCst, + py::return_value_policy::automatic, py::keep_alive<0, 1>()) + .def("__rsub__", &FloatLinearExpr::FloatRSub, + py::return_value_policy::automatic, py::keep_alive<0, 1>(), + py::keep_alive<0, 2>()) + .def("__mul__", &IntLinExpr::IntMulCst, + py::return_value_policy::automatic, py::keep_alive<0, 1>()) + .def("__rmul__", &IntLinExpr::IntMulCst, + py::return_value_policy::automatic, py::keep_alive<0, 1>()) + .def("__neg__", &IntLinExpr::IntNeg, py::return_value_policy::automatic, + py::keep_alive<0, 1>()) + .def("__mul__", &FloatLinearExpr::FloatMulCst, + py::return_value_policy::automatic, py::keep_alive<0, 1>()) + .def("__rmul__", &FloatLinearExpr::FloatMulCst, + py::return_value_policy::automatic, py::keep_alive<0, 1>()) + .def("__eq__", &IntLinExpr::Eq, py::return_value_policy::automatic, + py::keep_alive<0, 1>(), py::keep_alive<0, 2>()) + .def("__eq__", &IntLinExpr::EqCst, py::return_value_policy::automatic, + py::keep_alive<0, 1>()) + .def("__ne__", &IntLinExpr::Ne, py::return_value_policy::automatic, + py::keep_alive<0, 1>(), py::keep_alive<0, 2>()) + .def("__ne__", &IntLinExpr::NeCst, py::return_value_policy::automatic, + py::keep_alive<0, 1>()) + .def("__lt__", &IntLinExpr::Lt, py::return_value_policy::automatic, + py::keep_alive<0, 1>(), py::keep_alive<0, 2>()) + .def( + "__lt__", + [](IntLinExpr* expr, int64_t bound) { + if (bound == std::numeric_limits::min()) { + throw_error(PyExc_ArithmeticError, "< INT_MIN is not supported"); + } + return expr->LtCst(bound); + }, + py::return_value_policy::automatic, py::keep_alive<0, 1>()) + .def("__le__", &IntLinExpr::Le, py::return_value_policy::automatic, + py::keep_alive<0, 1>(), py::keep_alive<0, 2>()) + .def( + "__le__", + [](IntLinExpr* expr, int64_t bound) { + if (bound == std::numeric_limits::min()) { + throw_error(PyExc_ArithmeticError, "<= INT_MIN is not supported"); + } + return expr->LeCst(bound); + }, + py::return_value_policy::automatic, + + py::keep_alive<0, 1>()) + .def("__gt__", &IntLinExpr::Gt, py::return_value_policy::automatic, + py::keep_alive<0, 1>(), py::keep_alive<0, 2>()) + .def( + "__gt__", + [](IntLinExpr* expr, int64_t bound) { + if (bound == std::numeric_limits::max()) { + throw_error(PyExc_ArithmeticError, "> INT_MAX is not supported"); + } + return expr->GtCst(bound); + }, + py::return_value_policy::automatic, py::keep_alive<0, 1>()) + .def("__ge__", &IntLinExpr::Ge, py::return_value_policy::automatic, + py::keep_alive<0, 1>(), py::keep_alive<0, 2>()) + .def( + "__ge__", + [](IntLinExpr* expr, int64_t bound) { + if (bound == std::numeric_limits::max()) { + throw_error(PyExc_ArithmeticError, ">= INT_MAX is not supported"); + } + return expr->GeCst(bound); + }, + py::return_value_policy::automatic, py::keep_alive<0, 1>()) + + .def("__div__", + [](IntLinExpr* /*self*/, IntLinExpr* /*other*/) { + throw_error(PyExc_NotImplementedError, + "calling / on a linear expression is not supported, " + "please use CpModel.add_division_equality"); + }) + .def("__div__", + [](IntLinExpr* /*self*/, int64_t /*cst*/) { + throw_error(PyExc_NotImplementedError, + "calling / on a linear expression is not supported, " + "please use CpModel.add_division_equality"); + }) + .def("__truediv__", + [](IntLinExpr* /*self*/, IntLinExpr* /*other*/) { + throw_error(PyExc_NotImplementedError, + "calling // on a linear expression is not supported, " + "please use CpModel.add_division_equality"); + }) + .def("__truediv__", + [](IntLinExpr* /*self*/, int64_t /*cst*/) { + throw_error(PyExc_NotImplementedError, + "calling // on a linear expression is not supported, " + "please use CpModel.add_division_equality"); + }) + .def("__mod__", + [](IntLinExpr* /*self*/, IntLinExpr* /*other*/) { + throw_error(PyExc_NotImplementedError, + "calling %% on a linear expression is not supported, " + "please use CpModel.add_modulo_equality"); + }) + .def("__mod__", + [](IntLinExpr* /*self*/, int64_t /*cst*/) { + throw_error(PyExc_NotImplementedError, + "calling %% on a linear expression is not supported, " + "please use CpModel.add_modulo_equality"); + }) + .def("__pow__", + [](IntLinExpr* /*self*/, IntLinExpr* /*other*/) { + throw_error(PyExc_NotImplementedError, + "calling ** on a linear expression is not supported, " + "please use CpModel.add_multiplication_equality"); + }) + .def("__pow__", + [](IntLinExpr* /*self*/, int64_t /*cst*/) { + throw_error(PyExc_NotImplementedError, + "calling ** on a linear expression is not supported, " + "please use CpModel.add_multiplication_equality"); + }) + .def("__lshift__", + [](IntLinExpr* /*self*/, IntLinExpr* /*other*/) { + throw_error( + PyExc_NotImplementedError, + "calling left shift on a linear expression is not supported"); + }) + .def("__lshift__", + [](IntLinExpr* /*self*/, int64_t /*cst*/) { + throw_error( + PyExc_NotImplementedError, + "calling left shift on a linear expression is not supported"); + }) + .def("__rshift__", + [](IntLinExpr* /*self*/, IntLinExpr* /*other*/) { + throw_error( + PyExc_NotImplementedError, + "calling right shift on a linear expression is not supported"); + }) + .def("__rshift__", + [](IntLinExpr* /*self*/, int64_t /*cst*/) { + throw_error( + PyExc_NotImplementedError, + "calling right shift on a linear expression is not supported"); + }) + .def("__and__", + [](IntLinExpr* /*self*/, IntLinExpr* /*other*/) { + throw_error(PyExc_NotImplementedError, + "calling and on a linear expression is not supported"); + }) + .def("__and__", + [](IntLinExpr* /*self*/, int64_t /*cst*/) { + throw_error(PyExc_NotImplementedError, + "calling and on a linear expression is not supported"); + }) + .def("__or__", + [](IntLinExpr* /*self*/, IntLinExpr* /*other*/) { + throw_error(PyExc_NotImplementedError, + "calling or on a linear expression is not supported"); + }) + .def("__or__", + [](IntLinExpr* /*self*/, int64_t /*cst*/) { + throw_error(PyExc_NotImplementedError, + "calling or on a linear expression is not supported"); + }) + .def("__xor__", + [](IntLinExpr* /*self*/, IntLinExpr* /*other*/) { + throw_error(PyExc_NotImplementedError, + "calling xor on a linear expression is not supported"); + }) + .def("__xor__", + [](IntLinExpr* /*self*/, int64_t /*cst*/) { + throw_error(PyExc_NotImplementedError, + "calling xor on a linear expression is not supported"); + }) + .def("__abs__", + [](IntLinExpr* /*self*/) { + throw_error( + PyExc_NotImplementedError, + "calling abs() on a linear expression is not supported, " + "please use CpModel.add_abs_equality"); + }) + .def("__bool__", + [](IntLinExpr* /*self*/) { + throw_error(PyExc_NotImplementedError, + "Evaluating a LinearExpr instance as a Boolean is " + "not implemented."); + }) + .def_static("Sum", + py::overload_cast&>( + &IntLinExpr::Sum), + "Returns sum(exprs)", py::return_value_policy::automatic, + py::keep_alive<0, 1>()) + .def_static( + "Sum", + py::overload_cast&, int64_t>( + &IntLinExpr::Sum), + "Returns sum(exprs) + cst", py::return_value_policy::automatic, + py::keep_alive<0, 1>()) + .def_static("WeightedSum", + py::overload_cast&, + const std::vector&>( + &IntLinExpr::WeightedSum), + "Returns sum(exprs[i] * coeffs[i]", + py::return_value_policy::automatic, py::keep_alive<0, 1>()) + .def_static("WeightedSum", + py::overload_cast&, + const std::vector&>( + &FloatLinearExpr::WeightedSum), + py::return_value_policy::automatic, py::keep_alive<0, 1>()) + .def_static("Term", &IntLinExpr::Term, arg("expr"), arg("coeff"), + "Returns expr * coeff.", py::return_value_policy::automatic); + + py::class_(m, "IntAffine") + .def(py::init()) + .def_property_readonly("expression", &IntAffine::expression) + .def_property_readonly("coefficient", &IntAffine::coefficient) + .def_property_readonly("offset", &IntAffine::offset); + + py::class_(m, "Literal") + .def_property_readonly("index", &Literal::index, + "The index of the variable in the model.") + .def("negated", &Literal::negated, + R"doc( + Returns the negation of a literal (a Boolean variable or its negation). + + This method implements the logical negation of a Boolean variable. + It is only valid if the variable has a Boolean domain (0 or 1). + + Note that this method is nilpotent: `x.negated().negated() == x`. + )doc", + py::return_value_policy::automatic, py::keep_alive<1, 0>()) + .def("__invert__", &Literal::negated, + "Returns the negation of the current literal.", + py::return_value_policy::automatic) + .def("__bool__", + [](Literal* /*self*/) { + throw_error(PyExc_NotImplementedError, + "Evaluating a Literal instance as a Boolean is " + "not implemented."); + }) + // PEP8 Compatibility. + .def("Not", &Literal::negated, py::return_value_policy::automatic) + .def("Index", &Literal::index); + + py::class_(m, "BaseIntVar") + .def(py::init()) + .def(py::init()) + .def_property_readonly("index", &BaseIntVar::index, + "The index of the variable in the model.") + .def_property_readonly("is_boolean", &BaseIntVar::is_boolean, + "Whether the variable is boolean.") + .def("__str__", &BaseIntVar::ToString) + .def("__repr__", &BaseIntVar::DebugString) + .def( + "negated", + [](BaseIntVar* self) { + if (!self->is_boolean()) { + throw_error(PyExc_TypeError, + "negated() is only supported for boolean variables."); + } + return self->negated(); + }, + "Returns the negation of the current Boolean variable.", + py::return_value_policy::automatic, py::keep_alive<1, 0>()) + .def( + "__invert__", + [](BaseIntVar* self) { + if (!self->is_boolean()) { + throw_error(PyExc_ValueError, + "negated() is only supported for boolean variables."); + } + return self->negated(); + }, + "Returns the negation of the current Boolean variable.", + py::return_value_policy::automatic, py::keep_alive<1, 0>()) + // PEP8 Compatibility. + .def( + "Not", + [](BaseIntVar* self) { + if (!self->is_boolean()) { + throw_error(PyExc_ValueError, + "negated() is only supported for boolean variables."); + } + return self->negated(); + }, + py::return_value_policy::automatic, py::keep_alive<1, 0>()); + + py::class_(m, "NotBooleanVariable") + .def(py::init()) + .def_property_readonly("index", &NotBooleanVariable::index, + "The index of the variable in the model.") + .def("__str__", &NotBooleanVariable::ToString) + .def("__repr__", &NotBooleanVariable::DebugString) + .def("negated", &NotBooleanVariable::negated, + "Returns the negation of the current Boolean variable.", + py::return_value_policy::automatic) + .def("__invert__", &NotBooleanVariable::negated, + "Returns the negation of the current Boolean variable.", + py::return_value_policy::automatic) + .def("Not", &NotBooleanVariable::negated, + "Returns the negation of the current Boolean variable.", + py::return_value_policy::automatic); + + py::class_(m, "BoundedLinearExpression") + .def(py::init()) + .def(py::init()) + .def_property_readonly("bounds", &BoundedLinearExpression::bounds) + .def_property_readonly("vars", &BoundedLinearExpression::vars) + .def_property_readonly("coeffs", &BoundedLinearExpression::coeffs) + .def_property_readonly("offset", &BoundedLinearExpression::offset) + .def("__str__", &BoundedLinearExpression::ToString) + .def("__repr__", &BoundedLinearExpression::DebugString) + .def("__bool__", [](const BoundedLinearExpression& self) { + bool result; + if (self.CastToBool(&result)) return result; + throw_error(PyExc_NotImplementedError, + absl::StrCat("Evaluating a BoundedLinearExpression '", + self.ToString(), + "'instance as a Boolean is " + "not implemented.") + .c_str()); + return false; + }); +} // NOLINT(readability/fn_size) + +} // namespace python +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/python/swig_helper_test.py b/ortools/sat/python/swig_helper_test.py index a332ecb53a0..bc06ce98096 100644 --- a/ortools/sat/python/swig_helper_test.py +++ b/ortools/sat/python/swig_helper_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for ortools.sat.python.swig_helper.""" +"""Unit tests for ortools.sat.python.swig_helper.""" from absl.testing import absltest from google.protobuf import text_format @@ -44,6 +44,19 @@ def new_best_bound(self, bb: float): self.best_bound = bb +class TestIntVar(swig_helper.BaseIntVar): + + def __init__(self, index: int, name: str, is_boolean: bool = False) -> None: + swig_helper.BaseIntVar.__init__(self, index, is_boolean) + self._name = name + + def __str__(self) -> str: + return self._name + + def __repr__(self) -> str: + return self._name + + class SwigHelperTest(absltest.TestCase): def testVariableDomain(self): @@ -96,10 +109,10 @@ def testSimpleSolve(self): self.assertTrue(text_format.Parse(model_string, model)) solve_wrapper = swig_helper.SolveWrapper() - solution = solve_wrapper.solve(model) + response_wrapper = solve_wrapper.solve(model) - self.assertEqual(cp_model_pb2.OPTIMAL, solution.status) - self.assertEqual(30.0, solution.objective_value) + self.assertEqual(cp_model_pb2.OPTIMAL, response_wrapper.status()) + self.assertEqual(30.0, response_wrapper.objective_value()) def testSimpleSolveWithCore(self): model_string = """ @@ -140,10 +153,10 @@ def testSimpleSolveWithCore(self): solve_wrapper = swig_helper.SolveWrapper() solve_wrapper.set_parameters(parameters) - solution = solve_wrapper.solve(model) + response_wrapper = solve_wrapper.solve(model) - self.assertEqual(cp_model_pb2.OPTIMAL, solution.status) - self.assertEqual(30.0, solution.objective_value) + self.assertEqual(cp_model_pb2.OPTIMAL, response_wrapper.status()) + self.assertEqual(30.0, response_wrapper.objective_value()) def testSimpleSolveWithProtoApi(self): model = cp_model_pb2.CpModelProto() @@ -162,11 +175,11 @@ def testSimpleSolveWithProtoApi(self): model.objective.scaling_factor = -1 solve_wrapper = swig_helper.SolveWrapper() - solution = solve_wrapper.solve(model) + response_wrapper = solve_wrapper.solve(model) - self.assertEqual(cp_model_pb2.OPTIMAL, solution.status) - self.assertEqual(30.0, solution.objective_value) - self.assertEqual(30.0, solution.best_objective_bound) + self.assertEqual(cp_model_pb2.OPTIMAL, response_wrapper.status()) + self.assertEqual(30.0, response_wrapper.objective_value()) + self.assertEqual(30.0, response_wrapper.best_objective_bound()) def testSolutionCallback(self): model_string = """ @@ -184,10 +197,10 @@ def testSolutionCallback(self): params = sat_parameters_pb2.SatParameters() params.enumerate_all_solutions = True solve_wrapper.set_parameters(params) - solution = solve_wrapper.solve(model) + response_wrapper = solve_wrapper.solve(model) self.assertEqual(5, callback.solution_count()) - self.assertEqual(cp_model_pb2.OPTIMAL, solution.status) + self.assertEqual(cp_model_pb2.OPTIMAL, response_wrapper.status()) def testBestBoundCallback(self): model_string = """ @@ -213,10 +226,10 @@ def testBestBoundCallback(self): params.linearization_level = 2 params.log_search_progress = True solve_wrapper.set_parameters(params) - solution = solve_wrapper.solve(model) + response_wrapper = solve_wrapper.solve(model) self.assertEqual(2.6, best_bound_callback.best_bound) - self.assertEqual(cp_model_pb2.OPTIMAL, solution.status) + self.assertEqual(cp_model_pb2.OPTIMAL, response_wrapper.status()) def testModelStats(self): model_string = """ @@ -257,6 +270,94 @@ def testModelStats(self): stats = swig_helper.CpSatHelper.model_stats(model) self.assertTrue(stats) + def testIntLinExpr(self): + x = TestIntVar(0, "x") + self.assertTrue(x.is_integer()) + self.assertIsInstance(x, swig_helper.BaseIntVar) + self.assertIsInstance(x, swig_helper.LinearExpr) + e1 = x + 2 + self.assertTrue(e1.is_integer()) + self.assertEqual(str(e1), "(x + 2)") + e2 = 3 + x + self.assertTrue(e2.is_integer()) + self.assertEqual(str(e2), "(x + 3)") + y = TestIntVar(1, "y") + e3 = y * 5 + self.assertTrue(e3.is_integer()) + self.assertEqual(str(e3), "(5 * y)") + e4 = -2 * y + self.assertTrue(e4.is_integer()) + self.assertEqual(str(e4), "(-2 * y)") + e5 = x - 1 + self.assertTrue(e5.is_integer()) + self.assertEqual(str(e5), "(x - 1)") + e6 = x - 2 * y + self.assertTrue(e6.is_integer()) + self.assertEqual(str(e6), "(x - (2 * y))") + z = TestIntVar(2, "z", True) + e7 = -z + self.assertTrue(e7.is_integer()) + self.assertEqual(str(e7), "(-z)") + not_z = ~z + self.assertTrue(not_z.is_integer()) + self.assertEqual(str(not_z), "not(z)") + self.assertEqual(not_z.index, -3) + + e8 = swig_helper.LinearExpr.sum([x, y, z]) + self.assertEqual(str(e8), "(x + y + z)") + e9 = swig_helper.LinearExpr.sum([x, y, z], 11) + self.assertEqual(str(e9), "(x + y + z + 11)") + e10 = swig_helper.LinearExpr.weighted_sum([x, y, z], [1, 2, 3]) + self.assertEqual(str(e10), "(x + 2 * y + 3 * z)") + e11 = swig_helper.LinearExpr.weighted_sum([x, y, z], [1, 2, 3], -5) + self.assertEqual(str(e11), "(x + 2 * y + 3 * z - 5)") + + def testFloatLinExpr(self): + x = TestIntVar(0, "x") + self.assertTrue(x.is_integer()) + self.assertIsInstance(x, TestIntVar) + self.assertIsInstance(x, swig_helper.LinearExpr) + self.assertIsInstance(x, swig_helper.FloatLinearExpr) + e1 = x + 2.5 + self.assertFalse(e1.is_integer()) + self.assertEqual(str(e1), "(x + 2.5)") + e2 = 3.1 + x + self.assertFalse(e2.is_integer()) + self.assertEqual(str(e2), "(x + 3.1)") + y = TestIntVar(1, "y") + e3 = y * 5.2 + self.assertFalse(e3.is_integer()) + self.assertEqual(str(e3), "(5.2 * y)") + e4 = -2.2 * y + self.assertFalse(e4.is_integer()) + self.assertEqual(str(e4), "(-2.2 * y)") + e5 = x - 1.1 + self.assertFalse(e5.is_integer()) + self.assertEqual(str(e5), "(x - 1.1)") + e6 = x + 2.4 * y + self.assertFalse(e6.is_integer()) + self.assertEqual(str(e6), "(x + (2.4 * y))") + e7 = x - 2.4 * y + self.assertFalse(e7.is_integer()) + self.assertEqual(str(e7), "(x - (2.4 * y))") + + z = TestIntVar(2, "z") + e8 = swig_helper.FloatLinearExpr.sum([x, y, z]) + self.assertFalse(e8.is_integer()) + self.assertEqual(str(e8), "(x + y + z)") + e9 = swig_helper.FloatLinearExpr.sum([x, y, z], 1.5) + self.assertFalse(e9.is_integer()) + self.assertEqual(str(e9), "(x + y + z + 1.5)") + e10 = swig_helper.FloatLinearExpr.weighted_sum([x, y, z], [1.0, 2.2, 3.3]) + self.assertFalse(e10.is_integer()) + self.assertEqual(str(e10), "(x + 2.2 * y + 3.3 * z)") + e11 = swig_helper.FloatLinearExpr.weighted_sum([x, y, z], [1.0, 2.2, 3.3], 1.5) + self.assertFalse(e11.is_integer()) + self.assertEqual(str(e11), "(x + 2.2 * y + 3.3 * z + 1.5)") + e12 = (x + 2) * 3.1 + self.assertFalse(e12.is_integer()) + self.assertEqual(str(e12), "(3.1 * (x + 2))") + if __name__ == "__main__": absltest.main() diff --git a/ortools/sat/swig_helper.cc b/ortools/sat/swig_helper.cc index 0d9b045c4e6..8476b834909 100644 --- a/ortools/sat/swig_helper.cc +++ b/ortools/sat/swig_helper.cc @@ -92,7 +92,8 @@ void SolutionCallback::StopSearch() { if (wrapper_ != nullptr) wrapper_->StopSearch(); } -operations_research::sat::CpSolverResponse SolutionCallback::Response() const { +const operations_research::sat::CpSolverResponse& SolutionCallback::Response() + const { return response_; } diff --git a/ortools/sat/swig_helper.h b/ortools/sat/swig_helper.h index 3a9cfeec691..6d32e51b7db 100644 --- a/ortools/sat/swig_helper.h +++ b/ortools/sat/swig_helper.h @@ -65,7 +65,7 @@ class SolutionCallback { // Stops the search. void StopSearch(); - operations_research::sat::CpSolverResponse Response() const; + const operations_research::sat::CpSolverResponse& Response() const; // We use mutable and non const methods to overcome SWIG difficulties. void SetWrapperClass(SolveWrapper* wrapper) const;