Skip to content

Commit

Permalink
Merge pull request #43 from probcomp/061024-emilyaf-breakup-hirm
Browse files Browse the repository at this point in the history
Move Domain and Relation from hirm.hh to separate files.
  • Loading branch information
emilyfertig authored Jun 10, 2024
2 parents 7b9cfe7 + 7c5fe7a commit a03fa0b
Show file tree
Hide file tree
Showing 7 changed files with 568 additions and 515 deletions.
39 changes: 39 additions & 0 deletions cxx/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,14 @@ cc_library(
deps = [],
)

cc_library(
name = "domain",
hdrs = ["domain.hh"],
deps = [
"//distributions",
],
)

cc_binary(
name = "hirm",
srcs = ["hirm.cc"],
Expand All @@ -29,6 +37,17 @@ cc_binary(
],
)

cc_library(
name = "relation",
hdrs = ["relation.hh"],
deps = [
":domain",
":util_hash",
":util_math",
"//distributions:base"
],
)

cc_library(
name = "util_hash",
hdrs = ["util_hash.hh"],
Expand Down Expand Up @@ -57,6 +76,26 @@ cc_library(
deps = [":headers"],
)

cc_test(
name = "domain_test",
srcs = ["domain_test.cc"],
deps = [
":domain",
"@boost//:test",
],
)

cc_test(
name = "relation_test",
srcs = ["relation_test.cc"],
deps = [
":domain",
":relation",
"//distributions",
"@boost//:test",
],
)

cc_test(
name = "util_math_test",
srcs = ["util_math_test.cc"],
Expand Down
62 changes: 62 additions & 0 deletions cxx/domain.hh
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
// Copyright 2020
// See LICENSE.txt

#pragma once
#include <cassert>
#include <string>
#include <unordered_set>

#include "distributions/crp.hh"

typedef int T_item;

class Domain {
public:
const std::string name; // human-readable name
std::unordered_set<T_item> items; // set of items
CRP crp; // clustering model for items
std::mt19937* prng;

Domain(const std::string& name, std::mt19937* prng) : name(name), crp(prng) {
assert(!name.empty());
this->prng = prng;
}
void incorporate(const T_item& item, int table = -1) {
if (items.contains(item)) {
assert(table == -1);
} else {
items.insert(item);
int t = 0 <= table ? table : crp.sample();
crp.incorporate(item, t);
}
}
void unincorporate(const T_item& item) {
printf("Not implemented\n");
exit(EXIT_FAILURE);
// assert(items.count(item) == 1);
// assert(items.at(item).count(relation) == 1);
// items.at(item).erase(relation);
// if (items.at(item).size() == 0) {
// crp.unincorporate(item);
// items.erase(item);
// }
}
int get_cluster_assignment(const T_item& item) const {
assert(items.contains(item));
return crp.assignments.at(item);
}
void set_cluster_assignment_gibbs(const T_item& item, int table) {
assert(items.contains(item));
assert(crp.assignments.at(item) != table);
crp.unincorporate(item);
crp.incorporate(item, table);
}
std::unordered_map<int, double> tables_weights() const {
return crp.tables_weights();
}
std::unordered_map<int, double> tables_weights_gibbs(
const T_item& item) const {
int table = get_cluster_assignment(item);
return crp.tables_weights_gibbs(table);
}
};
30 changes: 30 additions & 0 deletions cxx/domain_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Apache License, Version 2.0, refer to LICENSE.txt

#define BOOST_TEST_MODULE test Domain

#include "domain.hh"

#include <boost/test/included/unit_test.hpp>
namespace tt = boost::test_tools;

BOOST_AUTO_TEST_CASE(test_domain) {
std::mt19937 prng;
Domain d("fruit", &prng);
std::string relation1 = "grows_on_trees";
std::string relation2 = "is_same_color";
T_item banana = 1;
T_item apple = 2;
d.incorporate(banana);
d.set_cluster_assignment_gibbs(banana, 12);
d.incorporate(banana);
d.incorporate(apple, 5);
BOOST_TEST(d.items.contains(banana));
BOOST_TEST(d.items.contains(apple));
BOOST_TEST(d.items.size() == 2);

int cb = d.get_cluster_assignment(banana);
int ca = d.get_cluster_assignment(apple);
BOOST_TEST(ca == 5);
BOOST_TEST(cb == 12);

}
Loading

0 comments on commit a03fa0b

Please sign in to comment.