Skip to content

Commit

Permalink
added initial h2 heuristic..
Browse files Browse the repository at this point in the history
  • Loading branch information
riccardodebenedictis committed Nov 21, 2023
1 parent 786b098 commit 7247e9d
Show file tree
Hide file tree
Showing 9 changed files with 155 additions and 9 deletions.
6 changes: 5 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ set(TEMPORAL_NETWORK_TYPES DL LA)
set(TEMPORAL_NETWORK_TYPE LA CACHE STRING "Temporal network type")
set_property(CACHE TEMPORAL_NETWORK_TYPE PROPERTY STRINGS ${TEMPORAL_NETWORK_TYPES})

set(HEURISTIC_TYPES h_max h_add)
set(HEURISTIC_TYPES h_max h_add h2_max h2_add)
set(HEURISTIC_TYPE h_max CACHE STRING "Heuristic type")
set_property(CACHE HEURISTIC_TYPE PROPERTY STRINGS ${HEURISTIC_TYPES})

Expand Down Expand Up @@ -63,6 +63,10 @@ if(HEURISTIC_TYPE STREQUAL h_max)
target_compile_definitions(${PROJECT_NAME} PRIVATE H_MAX)
elseif(HEURISTIC_TYPE STREQUAL h_add)
target_compile_definitions(${PROJECT_NAME} PRIVATE H_ADD)
elseif(HEURISTIC_TYPE STREQUAL h2_max)
target_compile_definitions(${PROJECT_NAME} PRIVATE H2_MAX)
elseif(HEURISTIC_TYPE STREQUAL h2_add)
target_compile_definitions(${PROJECT_NAME} PRIVATE H2_ADD)
else()
message(FATAL_ERROR "HEURISTIC_TYPE must be one of ${HEURISTIC_TYPES}")
endif()
Expand Down
2 changes: 1 addition & 1 deletion include/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,12 @@ namespace ratio
virtual void push() {}
virtual void pop() {}

protected:
virtual void activated_flaw(flaw &) {}
virtual void negated_flaw(flaw &f) { propagate_costs(f); }
virtual void activated_resolver(resolver &) {}
virtual void negated_resolver(resolver &r);

protected:
void new_flaw(flaw_ptr f, const bool &enqueue = true) const noexcept;
const std::unordered_map<semitone::var, std::vector<flaw_ptr>> &get_flaws() const noexcept;
const std::unordered_set<flaw *> &get_active_flaws() const noexcept;
Expand Down
12 changes: 9 additions & 3 deletions include/heuristics/h_1.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ namespace ratio
class enum_flaw;
class atom_flaw;

class h_1 final : public graph
class h_1 : public graph
{
public:
h_1(solver &s);
Expand All @@ -27,19 +27,25 @@ namespace ratio
#endif

#ifdef GRAPH_REFINING
void refine() override;
protected:
virtual void refine() override;

private:
void prune_enums();
#endif

bool is_deferrable(flaw &f); // checks whether the given flaw is deferrable..

private:
std::deque<flaw *> flaw_q; // the flaw queue (for the graph building procedure)..
std::deque<flaw *> flaw_q; // the flaw queue (for the graph building procedure)..
protected:
std::unordered_set<flaw *> visited; // the visited flaws, for graph cost propagation (and deferrable flaws check)..
#ifdef GRAPH_PRUNING
private:
std::unordered_set<flaw *> already_closed; // already closed flaws (for avoiding duplicating graph pruning constraints)..
#endif
#ifdef GRAPH_REFINING
private:
std::vector<enum_flaw *> enum_flaws; // the enum flaws..
std::unordered_set<atom_flaw *> landmarks; // the possible landmarks..
#endif
Expand Down
52 changes: 52 additions & 0 deletions include/heuristics/h_2.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
#pragma once

#include "h_1.h"

namespace ratio
{
class h_2 : public h_1
{
public:
h_2(solver &s);

#ifdef GRAPH_REFINING
void refine() override;

void visit(flaw &f);

void negated_resolver(resolver &r) override;

class h_2_flaw : public flaw
{
public:
h_2_flaw(flaw &sub_f, resolver &r, resolver &mtx_r);

private:
void compute_resolvers() override;

json::json get_data() const noexcept override;

class h_2_resolver : public resolver
{
public:
h_2_resolver(h_2_flaw &f, resolver &sub_r);

void apply() override;

json::json get_data() const noexcept override;

private:
resolver &sub_r;
};

private:
flaw &sub_f;
resolver &mtx_r;
};

private:
resolver *c_res = nullptr;
std::vector<flaw *> h_2_flaws;
#endif
};
} // namespace ratio
3 changes: 3 additions & 0 deletions include/resolver.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ namespace ratio

friend json::json to_json(const resolver &r) noexcept;

protected:
void new_causal_link(flaw &f);

private:
flaw &f; // the flaw solved by this resolver..
const semitone::lit rho; // the propositional literal indicating whether the resolver is active or not..
Expand Down
1 change: 0 additions & 1 deletion include/solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,6 @@ namespace ratio
{
friend class flaw;
friend class resolver;
friend class atom_flaw;
friend class graph;
friend class smart_type;
#ifdef BUILD_LISTENERS
Expand Down
4 changes: 3 additions & 1 deletion src/flaws/atom_flaw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ namespace ratio
auto u_res = new unify_atom(*this, i, eq_lit);
assert(get_solver().get_sat_core().value(u_res->get_rho()) != utils::False);
add_resolver(u_res);
get_solver().new_causal_link(*t_atm.reason, *u_res);
}

if (c_atm.is_fact())
Expand Down Expand Up @@ -102,6 +101,9 @@ namespace ratio
assert(get_solver().get_sat_core().value(t_atm.sigma) != utils::False); // the target atom must be activable..
assert(get_solver().get_sat_core().value(get_rho()) != utils::False); // this resolver must be activable..

// we add a causal link from the target atom's reason to this resolver..
new_causal_link(t_atm.get_reason());

assert(t_atm.reason->is_expanded());
for (auto &r : t_atm.reason->get_resolvers())
if (dynamic_cast<activate_fact *>(&r.get()) || dynamic_cast<activate_goal *>(&r.get())) // we disable this unification if the target atom is not activable..
Expand Down
78 changes: 78 additions & 0 deletions src/heuristics/h_2.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
#include "h_2.h"
#include "solver.h"
#include <cassert>

namespace ratio
{
h_2::h_2(solver &s) : h_1(s) {}

#ifdef GRAPH_REFINING
void h_2::refine()
{
// we refine the graph..
h_1::refine();

// we visit the flaws..
for (auto &f : std::vector<flaw *>(get_active_flaws().begin(), get_active_flaws().end()))
visit(*f);

assert(s.get_sat_core().root_level());

// we add the h_2 flaws..
for (auto &f : h_2_flaws)
new_flaw(f, false);
}

void h_2::visit(flaw &f)
{
// we visit the flaw..
visited.insert(&f);

// we check whether the best resolver is actually solvable..
c_res = &f.get_best_resolver();
s.take_decision(f.get_best_resolver().get_rho());

// we visit the subflaws..
for (auto &p : c_res->get_preconditions())
if (!visited.count(&p.get()))
visit(p.get());

if (s.get_sat_core().value(c_res->get_rho()) != utils::True)
s.get_sat_core().pop();

// we unvisit the flaw..
visited.erase(&f);
}

void h_2::negated_resolver(resolver &r)
{
// resolver c_res is mutex with r!

// we refine the graph..
h_1::negated_resolver(r);
}

h_2::h_2_flaw::h_2_flaw(flaw &sub_f, resolver &r, resolver &mtx_r) : flaw(sub_f.get_solver(), {r}), sub_f(sub_f), mtx_r(mtx_r) {}

void h_2::h_2_flaw::compute_resolvers()
{
// we compute the resolvers..
for (auto &r : sub_f.get_resolvers())
if (&r.get() != &mtx_r)
add_resolver(new h_2_resolver(*this, r.get()));
}

json::json h_2::h_2_flaw::get_data() const noexcept { return {{"type", "h_2"}, {"flaw", variable(sub_f.get_phi())}}; }

h_2::h_2_flaw::h_2_resolver::h_2_resolver(h_2_flaw &f, resolver &sub_r) : resolver(f, sub_r.get_rho(), sub_r.get_intrinsic_cost()), sub_r(sub_r) {}

void h_2::h_2_flaw::h_2_resolver::apply()
{
// we apply the resolver..
for (auto &p : sub_r.get_preconditions())
new_causal_link(p.get());
}

json::json h_2::h_2_flaw::h_2_resolver::get_data() const noexcept { return {{"type", "h_2"}, {"resolver", variable(sub_r.get_rho())}}; }
#endif
} // namespace ratio
6 changes: 4 additions & 2 deletions src/resolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@ namespace ratio
return intrinsic_cost;

utils::rational est_cost;
#ifdef H_MAX
#if defined(H_MAX) || defined(H2_MAX)
est_cost = utils::rational::NEGATIVE_INFINITY;
for (const auto &p : preconditions)
if (!p.get().is_expanded())
return utils::rational::POSITIVE_INFINITY;
else // we compute the max of the flaws' estimated costs..
est_cost = std::max(est_cost, p.get().get_estimated_cost());
#endif
#ifdef H_ADD
#if defined(H_ADD) || defined(H2_ADD)
est_cost = utils::rational::ZERO;
for (const auto &p : preconditions)
if (!p.get().is_expanded())
Expand All @@ -35,6 +35,8 @@ namespace ratio
return est_cost + intrinsic_cost;
}

void resolver::new_causal_link(flaw &f) { f.s.new_causal_link(f, *this); }

std::string to_string(const resolver &r) noexcept
{
std::string state;
Expand Down

0 comments on commit 7247e9d

Please sign in to comment.