Skip to content

Commit

Permalink
Merge branch 'master' into 061024-thomaswc-ssize
Browse files Browse the repository at this point in the history
  • Loading branch information
ThomasColthurst committed Jun 12, 2024
2 parents e20d77d + cf1039a commit 3a9490d
Show file tree
Hide file tree
Showing 12 changed files with 789 additions and 515 deletions.
39 changes: 39 additions & 0 deletions cxx/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,14 @@ cc_library(
deps = [],
)

cc_library(
name = "domain",
hdrs = ["domain.hh"],
deps = [
"//distributions",
],
)

cc_binary(
name = "hirm",
srcs = ["hirm.cc"],
Expand All @@ -29,6 +37,17 @@ cc_binary(
],
)

cc_library(
name = "relation",
hdrs = ["relation.hh"],
deps = [
":domain",
":util_hash",
":util_math",
"//distributions:base"
],
)

cc_library(
name = "util_hash",
hdrs = ["util_hash.hh"],
Expand Down Expand Up @@ -57,6 +76,26 @@ cc_library(
deps = [":headers"],
)

cc_test(
name = "domain_test",
srcs = ["domain_test.cc"],
deps = [
":domain",
"@boost//:test",
],
)

cc_test(
name = "relation_test",
srcs = ["relation_test.cc"],
deps = [
":domain",
":relation",
"//distributions",
"@boost//:test",
],
)

cc_test(
name = "util_math_test",
srcs = ["util_math_test.cc"],
Expand Down
62 changes: 62 additions & 0 deletions cxx/domain.hh
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
// Copyright 2020
// See LICENSE.txt

#pragma once
#include <cassert>
#include <string>
#include <unordered_set>

#include "distributions/crp.hh"

typedef int T_item;

class Domain {
public:
const std::string name; // human-readable name
std::unordered_set<T_item> items; // set of items
CRP crp; // clustering model for items
std::mt19937* prng;

Domain(const std::string& name, std::mt19937* prng) : name(name), crp(prng) {
assert(!name.empty());
this->prng = prng;
}
void incorporate(const T_item& item, int table = -1) {
if (items.contains(item)) {
assert(table == -1);
} else {
items.insert(item);
int t = 0 <= table ? table : crp.sample();
crp.incorporate(item, t);
}
}
void unincorporate(const T_item& item) {
printf("Not implemented\n");
exit(EXIT_FAILURE);
// assert(items.count(item) == 1);
// assert(items.at(item).count(relation) == 1);
// items.at(item).erase(relation);
// if (items.at(item).size() == 0) {
// crp.unincorporate(item);
// items.erase(item);
// }
}
int get_cluster_assignment(const T_item& item) const {
assert(items.contains(item));
return crp.assignments.at(item);
}
void set_cluster_assignment_gibbs(const T_item& item, int table) {
assert(items.contains(item));
assert(crp.assignments.at(item) != table);
crp.unincorporate(item);
crp.incorporate(item, table);
}
std::unordered_map<int, double> tables_weights() const {
return crp.tables_weights();
}
std::unordered_map<int, double> tables_weights_gibbs(
const T_item& item) const {
int table = get_cluster_assignment(item);
return crp.tables_weights_gibbs(table);
}
};
30 changes: 30 additions & 0 deletions cxx/domain_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Apache License, Version 2.0, refer to LICENSE.txt

#define BOOST_TEST_MODULE test Domain

#include "domain.hh"

#include <boost/test/included/unit_test.hpp>
namespace tt = boost::test_tools;

BOOST_AUTO_TEST_CASE(test_domain) {
std::mt19937 prng;
Domain d("fruit", &prng);
std::string relation1 = "grows_on_trees";
std::string relation2 = "is_same_color";
T_item banana = 1;
T_item apple = 2;
d.incorporate(banana);
d.set_cluster_assignment_gibbs(banana, 12);
d.incorporate(banana);
d.incorporate(apple, 5);
BOOST_TEST(d.items.contains(banana));
BOOST_TEST(d.items.contains(apple));
BOOST_TEST(d.items.size() == 2);

int cb = d.get_cluster_assignment(banana);
int ca = d.get_cluster_assignment(apple);
BOOST_TEST(ca == 5);
BOOST_TEST(cb == 12);

}
38 changes: 38 additions & 0 deletions cxx/emissions/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,13 @@ cc_library(
],
)

cc_library(
name = "bitflip",
srcs = ["bitflip.hh"],
visibility = ["//:__subpackages__"],
deps = [":base"],
)

cc_library(
name = "gaussian",
srcs = ["gaussian.hh"],
Expand All @@ -19,6 +26,37 @@ cc_library(
],
)

cc_library(
name = "sometimes",
srcs = ["sometimes.hh"],
visibility = ["//:__subpackages__"],
deps = [
":base",
"//distributions:beta_bernoulli",
],
)

cc_test(
name = "bitflip_test",
srcs = ["bitflip_test.cc"],
deps = [
":bitflip",
"@boost//:algorithm",
"@boost//:test",
],
)

cc_test(
name = "sometimes_test",
srcs = ["sometimes_test.cc"],
deps = [
":bitflip",
":sometimes",
"@boost//:algorithm",
"@boost//:test",
],
)

# TODO(thomaswc): Fix and re-enable.
#cc_test(
# name = "gaussian_test",
Expand Down
42 changes: 42 additions & 0 deletions cxx/emissions/bitflip.hh
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
#pragma once

#include <cassert>
#include "emissions/base.hh"

// A *deterministic* Emission class that always emits not(clean).
// Most users will want to combine this with Sometimes.
class BitFlip : public Emission<bool> {
public:
BitFlip() {};

void incorporate(const std::pair<bool, bool>& x) {
assert(x.first != x.second);
++N;
}

void unincorporate(const std::pair<bool, bool>& x) {
assert(x.first != x.second);
--N;
}

double logp(const std::pair<bool, bool>& x) const {
assert(x.first != x.second);
return 0.0;
}

double logp_score() const {
return 0.0;
}

// No hyperparameters to transition!
void transition_hyperparameters() {}

bool sample_corrupted(const bool& clean, std::mt19937* unused_prng) {
return !clean;
}

bool propose_clean(const std::vector<bool>& corrupted,
std::mt19937* unused_prng) {
return !corrupted[0];
}
};
33 changes: 33 additions & 0 deletions cxx/emissions/bitflip_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Apache License, Version 2.0, refer to LICENSE.txt

#define BOOST_TEST_MODULE test BitFlip

#include <random>

#include "emissions/bitflip.hh"

#include <boost/test/included/unit_test.hpp>

BOOST_AUTO_TEST_CASE(test_simple) {
BitFlip bf;

BOOST_TEST(bf.logp_score() == 0.0);
BOOST_TEST(bf.N == 0);
bf.incorporate(std::make_pair<bool, bool>(true, false));
BOOST_TEST(bf.logp_score() == 0.0);
BOOST_TEST(bf.N == 1);
bf.unincorporate(std::make_pair<bool, bool>(true, false));
BOOST_TEST(bf.logp_score() == 0.0);
BOOST_TEST(bf.N == 0);
bf.incorporate(std::make_pair<bool, bool>(false, true));
bf.incorporate(std::make_pair<bool, bool>(false, true));
BOOST_TEST(bf.logp_score() == 0.0);
BOOST_TEST(bf.N == 2);

BOOST_TEST(bf.logp(std::make_pair<bool, bool>(true, false)) == 0.0);

std::mt19937 prng;
BOOST_TEST(bf.sample_corrupted(false, &prng));

BOOST_TEST(bf.propose_clean({false, false, false}, &prng));
}
75 changes: 75 additions & 0 deletions cxx/emissions/sometimes.hh
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
#pragma once

#include <unordered_map>

#include "distributions/beta_bernoulli.hh"
#include "emissions/base.hh"

// An Emission class that sometimes applies BaseEmissor and sometimes doesn't.
// BaseEmissor must (1) of type Emission<SampleType> and (2) assign zero
// probability to <clean, dirty> pairs with clean == dirty. [For example,
// BitFlip and Gaussian both satisfy #2].
template <typename BaseEmissor, typename SampleType = double>
class Sometimes : public Emission<SampleType> {
public:
BetaBernoulli bb;
BaseEmissor be;

Sometimes() : bb(nullptr) {};

void incorporate(const std::pair<SampleType, SampleType>& x) {
++(this->N);
bb.incorporate(x.first != x.second);
if (x.first != x.second) {
be.incorporate(x);
}
}

void unincorporate(const std::pair<SampleType, SampleType>& x) {
--(this->N);
bb.unincorporate(x.first != x.second);
if (x.first != x.second) {
be.unincorporate(x);
}
}

double logp(const std::pair<SampleType, SampleType>& x) const {
return bb.logp(x.first != x.second) + be.logp(x);
}

double logp_score() const {
return bb.logp_score() + be.logp_score();
}

void transition_hyperparameters() {
be.transition_hyperparameters();
bb.transition_hyperparameters();
}

SampleType sample_corrupted(const SampleType& clean, std::mt19937* prng) {
bb.prng = prng;
if (bb.sample()) {
return be.sample_corrupted(clean, prng);
}
return clean;
}

SampleType propose_clean(const std::vector<SampleType>& corrupted,
std::mt19937* prng) {
// We approximate the maximum likelihood estimate by taking the mode of
// corrupted. The full solution would construct BaseEmissor and
// BetaBernoulli instances for each choice of clean and picking the
// clean with the highest combined logp_score().
std::unordered_map<SampleType, int> counts;
SampleType mode;
int max_count = 0;
for (const SampleType& c: corrupted) {
++counts[c];
if (counts[c] > max_count) {
max_count = counts[c];
mode = c;
}
}
return mode;
}
};
33 changes: 33 additions & 0 deletions cxx/emissions/sometimes_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Apache License, Version 2.0, refer to LICENSE.txt

#define BOOST_TEST_MODULE test Sometimes

#include <random>

#include "emissions/bitflip.hh"
#include "emissions/sometimes.hh"

#include <boost/test/included/unit_test.hpp>

BOOST_AUTO_TEST_CASE(test_simple) {
Sometimes<BitFlip, bool> sbf;

double orig_lp = sbf.logp_score();
BOOST_TEST(sbf.N == 0);
sbf.incorporate(std::make_pair<bool, bool>(true, false));
BOOST_TEST(sbf.logp_score() < 0.0);
BOOST_TEST(sbf.N == 1);
sbf.unincorporate(std::make_pair<bool, bool>(true, false));
BOOST_TEST(sbf.logp_score() == orig_lp);
BOOST_TEST(sbf.N == 0);

sbf.incorporate(std::make_pair<bool, bool>(false, true));
sbf.incorporate(std::make_pair<bool, bool>(false, true));
BOOST_TEST(sbf.logp_score() < 0.0);
BOOST_TEST(sbf.N == 2);

BOOST_TEST(sbf.logp(std::make_pair<bool, bool>(true, false)) < 0.0);

std::mt19937 prng;
BOOST_TEST(sbf.propose_clean({true, true, false}, &prng));
}
Loading

0 comments on commit 3a9490d

Please sign in to comment.