Skip to content

Commit

Permalink
[CP-SAT] polish python linear expr code
Browse files Browse the repository at this point in the history
  • Loading branch information
lperron committed Dec 29, 2024
1 parent 2b22356 commit ea2e1b6
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 143 deletions.
89 changes: 45 additions & 44 deletions ortools/sat/python/linear_expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ LinearExpr* LinearExpr::Sum(const std::vector<LinearExpr*>& exprs) {
}
}

LinearExpr* LinearExpr::Sum(const std::vector<ExprOrValue>& exprs) {
LinearExpr* LinearExpr::MixedSum(const std::vector<ExprOrValue>& exprs) {
std::vector<LinearExpr*> lin_exprs;
int64_t int_offset = 0;
double double_offset = 0.0;
Expand Down Expand Up @@ -84,48 +84,26 @@ LinearExpr* LinearExpr::Sum(const std::vector<ExprOrValue>& exprs) {
}
}

LinearExpr* LinearExpr::WeightedSum(const std::vector<LinearExpr*>& exprs,
const std::vector<double>& coeffs) {
if (exprs.empty()) return new FloatConstant(0.0);
if (exprs.size() == 1) {
return new FloatAffine(exprs[0], coeffs[0], 0.0);
}
return new FloatWeightedSum(exprs, coeffs, 0.0);
}

LinearExpr* LinearExpr::WeightedSum(const std::vector<LinearExpr*>& exprs,
const std::vector<int64_t>& coeffs) {
LinearExpr* LinearExpr::WeightedSumInt(const std::vector<LinearExpr*>& exprs,
const std::vector<int64_t>& coeffs) {
if (exprs.empty()) return new IntConstant(0);
if (exprs.size() == 1) {
return new IntAffine(exprs[0], coeffs[0], 0);
}
return new IntWeightedSum(exprs, coeffs, 0);
}

LinearExpr* LinearExpr::WeightedSum(const std::vector<ExprOrValue>& exprs,
const std::vector<double>& coeffs) {
std::vector<LinearExpr*> lin_exprs;
std::vector<double> lin_coeffs;
double cst = 0.0;
for (int i = 0; i < exprs.size(); ++i) {
if (exprs[i].expr != nullptr) {
lin_exprs.push_back(exprs[i].expr);
lin_coeffs.push_back(coeffs[i]);
} else {
cst += coeffs[i] *
(exprs[i].double_value + static_cast<double>(exprs[i].int_value));
}
}

if (lin_exprs.empty()) return new FloatConstant(cst);
if (lin_exprs.size() == 1) {
return new FloatAffine(lin_exprs[0], lin_coeffs[0], cst);
LinearExpr* LinearExpr::WeightedSumDouble(const std::vector<LinearExpr*>& exprs,
const std::vector<double>& coeffs) {
if (exprs.empty()) return new FloatConstant(0.0);
if (exprs.size() == 1) {
return new FloatAffine(exprs[0], coeffs[0], 0.0);
}
return new FloatWeightedSum(lin_exprs, lin_coeffs, cst);
return new FloatWeightedSum(exprs, coeffs, 0.0);
}

LinearExpr* LinearExpr::WeightedSum(const std::vector<ExprOrValue>& exprs,
const std::vector<int64_t>& coeffs) {
LinearExpr* LinearExpr::MixedWeightedSumInt(
const std::vector<ExprOrValue>& exprs, const std::vector<int64_t>& coeffs) {
std::vector<LinearExpr*> lin_exprs;
std::vector<int64_t> lin_coeffs;
int64_t int_cst = 0;
Expand Down Expand Up @@ -162,33 +140,56 @@ LinearExpr* LinearExpr::WeightedSum(const std::vector<ExprOrValue>& exprs,
return new IntWeightedSum(lin_exprs, lin_coeffs, int_cst);
}

LinearExpr* LinearExpr::Term(LinearExpr* expr, double coeff) {
return new FloatAffine(expr, coeff, 0.0);
LinearExpr* LinearExpr::MixedWeightedSumDouble(
const std::vector<ExprOrValue>& exprs, const std::vector<double>& coeffs) {
std::vector<LinearExpr*> lin_exprs;
std::vector<double> lin_coeffs;
double cst = 0.0;
for (int i = 0; i < exprs.size(); ++i) {
if (exprs[i].expr != nullptr) {
lin_exprs.push_back(exprs[i].expr);
lin_coeffs.push_back(coeffs[i]);
} else {
cst += coeffs[i] *
(exprs[i].double_value + static_cast<double>(exprs[i].int_value));
}
}

if (lin_exprs.empty()) return new FloatConstant(cst);
if (lin_exprs.size() == 1) {
return new FloatAffine(lin_exprs[0], lin_coeffs[0], cst);
}
return new FloatWeightedSum(lin_exprs, lin_coeffs, cst);
}

LinearExpr* LinearExpr::Term(LinearExpr* expr, int64_t coeff) {
LinearExpr* LinearExpr::TermInt(LinearExpr* expr, int64_t coeff) {
return new IntAffine(expr, coeff, 0);
}

LinearExpr* LinearExpr::Affine(LinearExpr* expr, double coeff, double offset) {
if (coeff == 1.0 && offset == 0.0) return expr;
return new FloatAffine(expr, coeff, offset);
LinearExpr* LinearExpr::TermDouble(LinearExpr* expr, double coeff) {
return new FloatAffine(expr, coeff, 0.0);
}

LinearExpr* LinearExpr::Affine(LinearExpr* expr, int64_t coeff,
int64_t offset) {
LinearExpr* LinearExpr::AffineInt(LinearExpr* expr, int64_t coeff,
int64_t offset) {
if (coeff == 1 && offset == 0) return expr;
return new IntAffine(expr, coeff, offset);
}

LinearExpr* LinearExpr::Constant(double value) {
return new FloatConstant(value);
LinearExpr* LinearExpr::AffineDouble(LinearExpr* expr, double coeff,
double offset) {
if (coeff == 1.0 && offset == 0.0) return expr;
return new FloatAffine(expr, coeff, offset);
}

LinearExpr* LinearExpr::Constant(int64_t value) {
LinearExpr* LinearExpr::ConstantInt(int64_t value) {
return new IntConstant(value);
}

LinearExpr* LinearExpr::ConstantDouble(double value) {
return new FloatConstant(value);
}

LinearExpr* LinearExpr::Add(LinearExpr* expr) {
return new BinaryAdd(this, expr);
}
Expand Down
31 changes: 16 additions & 15 deletions ortools/sat/python/linear_expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,21 +62,22 @@ class LinearExpr {
virtual std::string DebugString() const { return "LinearExpr()"; }

static LinearExpr* Sum(const std::vector<LinearExpr*>& exprs);
static LinearExpr* Sum(const std::vector<ExprOrValue>& exprs);
static LinearExpr* WeightedSum(const std::vector<LinearExpr*>& exprs,
const std::vector<int64_t>& coeffs);
static LinearExpr* WeightedSum(const std::vector<LinearExpr*>& exprs,
const std::vector<double>& coeffs);
static LinearExpr* WeightedSum(const std::vector<ExprOrValue>& exprs,
const std::vector<int64_t>& coeffs);
static LinearExpr* WeightedSum(const std::vector<ExprOrValue>& exprs,
const std::vector<double>& coeffs);
static LinearExpr* Term(LinearExpr* expr, int64_t coeff);
static LinearExpr* Term(LinearExpr* expr, double coeff);
static LinearExpr* Affine(LinearExpr* expr, int64_t coeff, int64_t offset);
static LinearExpr* Affine(LinearExpr* expr, double coeff, double offset);
static LinearExpr* Constant(int64_t value);
static LinearExpr* Constant(double value);
static LinearExpr* MixedSum(const std::vector<ExprOrValue>& exprs);
static LinearExpr* WeightedSumInt(const std::vector<LinearExpr*>& exprs,
const std::vector<int64_t>& coeffs);
static LinearExpr* WeightedSumDouble(const std::vector<LinearExpr*>& exprs,
const std::vector<double>& coeffs);
static LinearExpr* MixedWeightedSumInt(const std::vector<ExprOrValue>& exprs,
const std::vector<int64_t>& coeffs);
static LinearExpr* MixedWeightedSumDouble(
const std::vector<ExprOrValue>& exprs, const std::vector<double>& coeffs);
static LinearExpr* TermInt(LinearExpr* expr, int64_t coeff);
static LinearExpr* TermDouble(LinearExpr* expr, double coeff);
static LinearExpr* AffineInt(LinearExpr* expr, int64_t coeff, int64_t offset);
static LinearExpr* AffineDouble(LinearExpr* expr, double coeff,
double offset);
static LinearExpr* ConstantInt(int64_t value);
static LinearExpr* ConstantDouble(double value);

LinearExpr* Add(LinearExpr* expr);
LinearExpr* AddInt(int64_t cst);
Expand Down
134 changes: 50 additions & 84 deletions ortools/sat/python/swig_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -313,95 +313,65 @@ PYBIND11_MODULE(swig_helper, m) {
.def_readonly("expr", &ExprOrValue::expr)
.def_readonly("int_value", &ExprOrValue::int_value);

py::implicitly_convertible<LinearExpr*, ExprOrValue>();
py::implicitly_convertible<double, ExprOrValue>();
py::implicitly_convertible<int64_t, ExprOrValue>();
py::implicitly_convertible<double, ExprOrValue>();
py::implicitly_convertible<LinearExpr*, ExprOrValue>();

py::class_<LinearExpr>(m, "LinearExpr", kLinearExprClassDoc)
// We make sure to keep the order of the overloads: LinearExpr* before
// ExprOrValue as this is faster to parse and type check.
.def_static(
"sum",
py::overload_cast<const std::vector<LinearExpr*>&>(&LinearExpr::Sum),
py::return_value_policy::automatic, py::keep_alive<0, 1>())
.def_static(
"sum",
py::overload_cast<const std::vector<ExprOrValue>&>(&LinearExpr::Sum),
py::return_value_policy::automatic, py::keep_alive<0, 1>())
.def_static("weighted_sum",
py::overload_cast<const std::vector<LinearExpr*>&,
const std::vector<int64_t>&>(
&LinearExpr::WeightedSum),
.def_static("sum", (&LinearExpr::Sum), arg("exprs"),
py::return_value_policy::automatic, py::keep_alive<0, 1>())
.def_static("weighted_sum",
py::overload_cast<const std::vector<LinearExpr*>&,
const std::vector<double>&>(
&LinearExpr::WeightedSum),
.def_static("sum", &LinearExpr::MixedSum, arg("exprs"),
py::return_value_policy::automatic, py::keep_alive<0, 1>())
.def_static("weighted_sum",
py::overload_cast<const std::vector<ExprOrValue>&,
const std::vector<int64_t>&>(
&LinearExpr::WeightedSum),
.def_static("weighted_sum", &LinearExpr::WeightedSumInt, arg("exprs"),
arg("coeffs"), py::return_value_policy::automatic,
py::keep_alive<0, 1>())
.def_static("weighted_sum", &LinearExpr::WeightedSumDouble, arg("exprs"),
arg("coeffs"), py::return_value_policy::automatic,
py::keep_alive<0, 1>())
.def_static("weighted_sum", &LinearExpr::MixedWeightedSumInt,
arg("exprs"), arg("coeffs"),
py::return_value_policy::automatic, py::keep_alive<0, 1>())
.def_static("weighted_sum",
py::overload_cast<const std::vector<ExprOrValue>&,
const std::vector<double>&>(
&LinearExpr::WeightedSum),
.def_static("weighted_sum", &LinearExpr::MixedWeightedSumDouble,
arg("exprs"), arg("coeffs"),
py::return_value_policy::automatic, py::keep_alive<0, 1>())
// Make sure to keep the order of the overloads: int before float as an
// an integer value will be silently converted to a float.
.def_static("term",
py::overload_cast<LinearExpr*, int64_t>(&LinearExpr::Term),
arg("expr"), arg("coeff"), "Returns expr * coeff.",
.def_static("term", &LinearExpr::TermInt, arg("expr").none(false),
arg("coeff"), "Returns expr * coeff.",
py::return_value_policy::automatic, py::keep_alive<0, 1>())
.def_static("term",
py::overload_cast<LinearExpr*, double>(&LinearExpr::Term),
arg("expr"), arg("coeff"), "Returns expr * coeff.",
.def_static("term", &LinearExpr::TermDouble, arg("expr").none(false),
arg("coeff"), "Returns expr * coeff.",
py::return_value_policy::automatic, py::keep_alive<0, 1>())
.def_static(
"affine",
py::overload_cast<LinearExpr*, int64_t, int64_t>(&LinearExpr::Affine),
arg("expr"), arg("coeff"), arg("offset"),
"Returns expr * coeff + offset.", py::return_value_policy::automatic,
py::keep_alive<0, 1>())
.def_static(
"affine",
py::overload_cast<LinearExpr*, double, double>(&LinearExpr::Affine),
arg("expr"), arg("coeff"), arg("offset"),
"Returns expr * coeff + offset.", py::return_value_policy::automatic,
py::keep_alive<0, 1>())
.def_static("constant", py::overload_cast<int64_t>(&LinearExpr::Constant),
arg("value"), "Returns a constant linear expression.",
.def_static("affine", &LinearExpr::AffineInt, arg("expr").none(false),
arg("coeff"), arg("offset"), "Returns expr * coeff + offset.",
py::return_value_policy::automatic, py::keep_alive<0, 1>())
.def_static("affine", &LinearExpr::AffineDouble, arg("expr").none(false),
arg("coeff"), arg("offset"), "Returns expr * coeff + offset.",
py::return_value_policy::automatic, py::keep_alive<0, 1>())
.def_static("constant", &LinearExpr::ConstantInt, arg("value"),
"Returns a constant linear expression.",
py::return_value_policy::automatic)
.def_static("constant", py::overload_cast<double>(&LinearExpr::Constant),
arg("value"), "Returns a constant linear expression.",
.def_static("constant", &LinearExpr::ConstantDouble, arg("value"),
"Returns a constant linear expression.",
py::return_value_policy::automatic)
.def_static(
"Sum",
py::overload_cast<const std::vector<LinearExpr*>&>(&LinearExpr::Sum),
py::return_value_policy::automatic, py::keep_alive<0, 1>())
.def_static(
"Sum",
py::overload_cast<const std::vector<ExprOrValue>&>(&LinearExpr::Sum),
py::return_value_policy::automatic, py::keep_alive<0, 1>())
.def_static("WeightedSum",
py::overload_cast<const std::vector<ExprOrValue>&,
const std::vector<int64_t>&>(
&LinearExpr::WeightedSum),
.def_static("Sum", &LinearExpr::Sum, arg("exprs"),
py::return_value_policy::automatic, py::keep_alive<0, 1>())
.def_static("WeightedSum",
py::overload_cast<const std::vector<ExprOrValue>&,
const std::vector<double>&>(
&LinearExpr::WeightedSum),
.def_static("Sum", &LinearExpr::MixedSum, arg("exprs"),
py::return_value_policy::automatic, py::keep_alive<0, 1>())
.def_static("WeightedSum", &LinearExpr::MixedWeightedSumInt, arg("exprs"),
arg("coeffs"), py::return_value_policy::automatic,
py::keep_alive<0, 1>())
.def_static("WeightedSum", &LinearExpr::MixedWeightedSumDouble,
arg("exprs"), arg("coeffs"),
py::return_value_policy::automatic, py::keep_alive<0, 1>())
.def_static("Term", &LinearExpr::TermInt, arg("expr").none(false),
arg("coeff"), "Returns expr * coeff.",
py::return_value_policy::automatic, py::keep_alive<0, 1>())
.def_static("Term", &LinearExpr::TermDouble, arg("expr").none(false),
arg("coeff"), "Returns expr * coeff.",
py::return_value_policy::automatic, py::keep_alive<0, 1>())
.def_static(
"Term", py::overload_cast<LinearExpr*, int64_t>(&LinearExpr::Term),
arg("expr").none(false), arg("coeff"), "Returns expr * coeff.",
py::return_value_policy::automatic, py::keep_alive<0, 1>())
.def_static(
"Term", py::overload_cast<LinearExpr*, double>(&LinearExpr::Term),
arg("expr").none(false), arg("coeff"), "Returns expr * coeff.",
py::return_value_policy::automatic, py::keep_alive<0, 1>())
.def("__str__", &LinearExpr::ToString)
.def("__repr__", &LinearExpr::DebugString)
.def("is_integer", &LinearExpr::IsInteger)
Expand All @@ -427,18 +397,14 @@ PYBIND11_MODULE(swig_helper, m) {
py::return_value_policy::automatic, py::keep_alive<0, 1>())
.def("__rsub__", &LinearExpr::RSubDouble, arg("cst"),
py::return_value_policy::automatic, py::keep_alive<0, 1>())
.def("__mul__", py::overload_cast<int64_t>(&LinearExpr::MulInt),
arg("cst"), py::return_value_policy::automatic,
py::keep_alive<0, 1>())
.def("__mul__", py::overload_cast<double>(&LinearExpr::MulDouble),
arg("cst"), py::return_value_policy::automatic,
py::keep_alive<0, 1>())
.def("__rmul__", py::overload_cast<int64_t>(&LinearExpr::MulInt),
arg("cst"), py::return_value_policy::automatic,
py::keep_alive<0, 1>())
.def("__rmul__", py::overload_cast<double>(&LinearExpr::MulDouble),
arg("cst"), py::return_value_policy::automatic,
py::keep_alive<0, 1>())
.def("__mul__", &LinearExpr::MulInt, arg("cst"),
py::return_value_policy::automatic, py::keep_alive<0, 1>())
.def("__mul__", &LinearExpr::MulDouble, arg("cst"),
py::return_value_policy::automatic, py::keep_alive<0, 1>())
.def("__rmul__", &LinearExpr::MulInt, arg("cst"),
py::return_value_policy::automatic, py::keep_alive<0, 1>())
.def("__rmul__", &LinearExpr::MulDouble, arg("cst"),
py::return_value_policy::automatic, py::keep_alive<0, 1>())
.def("__neg__", &LinearExpr::Neg, py::return_value_policy::automatic,
py::keep_alive<0, 1>())
.def(
Expand Down

0 comments on commit ea2e1b6

Please sign in to comment.