Skip to content

Commit

Permalink
[CP-SAT] polish python layer
Browse files Browse the repository at this point in the history
  • Loading branch information
lperron committed Dec 30, 2024
1 parent 080f445 commit 3b84e02
Show file tree
Hide file tree
Showing 7 changed files with 193 additions and 327 deletions.
1 change: 1 addition & 0 deletions ortools/sat/python/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ pybind_extension(
deps = [
":linear_expr",
"//ortools/sat:cp_model_cc_proto",
"//ortools/sat:cp_model_utils",
"//ortools/sat:sat_parameters_cc_proto",
"//ortools/sat:swig_helper",
"@com_google_absl//absl/strings",
Expand Down
4 changes: 2 additions & 2 deletions ortools/sat/python/cp_model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1167,10 +1167,10 @@ def testRepr(self) -> None:
z = model.new_int_var(0, 3, "z")
self.assertEqual(repr(x), "x(0..4)")
self.assertEqual(repr(x * 2), "IntAffine(expr=x(0..4), coeff=2, offset=0)")
self.assertEqual(repr(x + y), "BinaryAdd(x(0..4), y(0..3))")
self.assertEqual(repr(x + y), "SumArray(x(0..4), y(0..3))")
self.assertEqual(
repr(cp_model.LinearExpr.sum([x, y, z])),
"IntSum(x(0..4), y(0..3), z(0..3), 0)",
"SumArray(x(0..4), y(0..3), z(0..3))",
)
self.assertEqual(
repr(cp_model.LinearExpr.weighted_sum([x, y, 2], [1, 2, 3])),
Expand Down
8 changes: 4 additions & 4 deletions ortools/sat/python/linear_expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ LinearExpr* LinearExpr::Sum(const std::vector<LinearExpr*>& exprs) {
} else if (exprs.size() == 1) {
return exprs[0];
} else {
return new IntSum(exprs, 0);
return new SumArray(exprs);
}
}

Expand Down Expand Up @@ -70,7 +70,7 @@ LinearExpr* LinearExpr::MixedSum(const std::vector<ExprOrValue>& exprs) {
} else if (lin_exprs.size() == 1) {
return new IntAffine(lin_exprs[0], 1, int_offset);
} else {
return new IntSum(lin_exprs, int_offset);
return new SumArray(lin_exprs, int_offset);
}
} else { // General floating point case.
double_offset += static_cast<double>(int_offset);
Expand All @@ -79,7 +79,7 @@ LinearExpr* LinearExpr::MixedSum(const std::vector<ExprOrValue>& exprs) {
} else if (lin_exprs.size() == 1) {
return new FloatAffine(lin_exprs[0], 1.0, double_offset);
} else {
return new FloatSum(lin_exprs, double_offset);
return new SumArray(lin_exprs, 0, double_offset);
}
}
}
Expand Down Expand Up @@ -191,7 +191,7 @@ LinearExpr* LinearExpr::ConstantDouble(double value) {
}

LinearExpr* LinearExpr::Add(LinearExpr* expr) {
return new BinaryAdd(this, expr);
return new SumArray({this, expr});
}

LinearExpr* LinearExpr::AddInt(int64_t cst) {
Expand Down
150 changes: 46 additions & 104 deletions ortools/sat/python/linear_expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -173,57 +173,44 @@ class CanonicalIntExpression {
bool ok_;
};

class BinaryAdd : public LinearExpr {
// A class to hold a sum of linear expressions, and optional integer and
// double offsets.
class SumArray : public LinearExpr {
public:
BinaryAdd(LinearExpr* lhs, LinearExpr* rhs) : lhs_(lhs), rhs_(rhs) {}
~BinaryAdd() override = default;

void VisitAsFloat(FloatExprVisitor* lin, double c) override {
lin->AddToProcess(lhs_, c);
lin->AddToProcess(rhs_, c);
}
explicit SumArray(const std::vector<LinearExpr*>& exprs,
int64_t int_offset = 0, double double_offset = 0.0)
: exprs_(exprs.begin(), exprs.end()),
int_offset_(int_offset),
double_offset_(double_offset) {}
~SumArray() override = default;

bool VisitAsInt(IntExprVisitor* lin, int64_t c) override {
lin->AddToProcess(lhs_, c);
lin->AddToProcess(rhs_, c);
if (double_offset_ != 0.0) return false;
for (int i = 0; i < exprs_.size(); ++i) {
lin->AddToProcess(exprs_[i], c);
}
lin->AddConstant(int_offset_ * c);
return true;
}

std::string ToString() const override {
return absl::StrCat("(", lhs_->ToString(), " + ", rhs_->ToString(), ")");
}

std::string DebugString() const override {
return absl::StrCat("BinaryAdd(", lhs_->DebugString(), ", ",
rhs_->DebugString(), ")");
}

private:
LinearExpr* lhs_;
LinearExpr* rhs_;
};

// A class to hold a sum of floating point linear expressions.
class FloatSum : public LinearExpr {
public:
FloatSum(const std::vector<LinearExpr*>& exprs, double offset)
: exprs_(exprs.begin(), exprs.end()), offset_(offset) {}
~FloatSum() override = default;

bool VisitAsInt(IntExprVisitor* /*lin*/, int64_t /*c*/) override {
return false;
}

void VisitAsFloat(FloatExprVisitor* lin, double c) override {
for (int i = 0; i < exprs_.size(); ++i) {
lin->AddToProcess(exprs_[i], c);
}
lin->AddConstant(offset_ * c);
if (int_offset_ != 0) {
lin->AddConstant(int_offset_ * c);
} else if (double_offset_ != 0.0) {
lin->AddConstant(double_offset_ * c);
}
}

std::string ToString() const override {
if (exprs_.empty()) {
return absl::StrCat(offset_);
if (double_offset_ != 0.0) {
return absl::StrCat(double_offset_);
} else {
return absl::StrCat(int_offset_);
}
}
std::string s = "(";
for (int i = 0; i < exprs_.size(); ++i) {
Expand All @@ -232,89 +219,44 @@ class FloatSum : public LinearExpr {
}
absl::StrAppend(&s, exprs_[i]->ToString());
}
if (offset_ != 0.0) {
if (offset_ > 0.0) {
absl::StrAppend(&s, " + ", offset_);
if (double_offset_ != 0.0) {
if (double_offset_ > 0.0) {
absl::StrAppend(&s, " + ", double_offset_);
} else {
absl::StrAppend(&s, " - ", -offset_);
absl::StrAppend(&s, " - ", -double_offset_);
}
}
if (int_offset_ != 0) {
if (int_offset_ > 0) {
absl::StrAppend(&s, " + ", int_offset_);
} else {
absl::StrAppend(&s, " - ", -int_offset_);
}
}
absl::StrAppend(&s, ")");
return s;
}

std::string DebugString() const override {
return absl::StrCat("FloatSum(",
absl::StrJoin(exprs_, ", ",
[](std::string* out, LinearExpr* expr) {
absl::StrAppend(out,
expr->DebugString());
}),
", ", offset_, ")");
}

private:
const absl::FixedArray<LinearExpr*, 2> exprs_;
double offset_;
};

// A class to hold a sum of integer linear expressions.
class IntSum : public LinearExpr {
public:
IntSum(const std::vector<LinearExpr*>& exprs, int64_t offset)
: exprs_(exprs.begin(), exprs.end()), offset_(offset) {}
~IntSum() override = default;

bool VisitAsInt(IntExprVisitor* lin, int64_t c) override {
for (int i = 0; i < exprs_.size(); ++i) {
lin->AddToProcess(exprs_[i], c);
}
lin->AddConstant(offset_ * c);
return true;
}

void VisitAsFloat(FloatExprVisitor* lin, double c) override {
for (int i = 0; i < exprs_.size(); ++i) {
lin->AddToProcess(exprs_[i], c);
}
lin->AddConstant(offset_ * c);
}

std::string ToString() const override {
if (exprs_.empty()) {
return absl::StrCat(offset_);
std::string s = absl::StrCat(
"SumArray(",
absl::StrJoin(exprs_, ", ", [](std::string* out, LinearExpr* expr) {
absl::StrAppend(out, expr->DebugString());
}));
if (int_offset_ != 0) {
absl::StrAppend(&s, ", int_offset=", int_offset_);
}
std::string s = "(";
for (int i = 0; i < exprs_.size(); ++i) {
if (i > 0) {
absl::StrAppend(&s, " + ");
}
absl::StrAppend(&s, exprs_[i]->ToString());
}
if (offset_ != 0) {
if (offset_ > 0) {
absl::StrAppend(&s, " + ", offset_);
} else {
absl::StrAppend(&s, " - ", -offset_);
}
if (double_offset_ != 0.0) {
absl::StrAppend(&s, ", double_offset=", double_offset_);
}
absl::StrAppend(&s, ")");
return s;
}

std::string DebugString() const override {
return absl::StrCat("IntSum(",
absl::StrJoin(exprs_, ", ",
[](std::string* out, LinearExpr* expr) {
absl::StrAppend(out,
expr->DebugString());
}),
", ", offset_, ")");
}

private:
const absl::FixedArray<LinearExpr*, 2> exprs_;
int64_t offset_;
const int64_t int_offset_;
const double double_offset_;
};

// A class to hold a weighted sum of floating point linear expressions.
Expand Down
Loading

0 comments on commit 3b84e02

Please sign in to comment.