forked from probsys/hierarchical-irm
-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'master' into 061024-thomaswc-ssize
- Loading branch information
Showing
12 changed files
with
789 additions
and
515 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
// Copyright 2020 | ||
// See LICENSE.txt | ||
|
||
#pragma once | ||
#include <cassert> | ||
#include <string> | ||
#include <unordered_set> | ||
|
||
#include "distributions/crp.hh" | ||
|
||
// Identifier type for the items stored in a Domain.
using T_item = int;
|
||
class Domain { | ||
public: | ||
const std::string name; // human-readable name | ||
std::unordered_set<T_item> items; // set of items | ||
CRP crp; // clustering model for items | ||
std::mt19937* prng; | ||
|
||
Domain(const std::string& name, std::mt19937* prng) : name(name), crp(prng) { | ||
assert(!name.empty()); | ||
this->prng = prng; | ||
} | ||
void incorporate(const T_item& item, int table = -1) { | ||
if (items.contains(item)) { | ||
assert(table == -1); | ||
} else { | ||
items.insert(item); | ||
int t = 0 <= table ? table : crp.sample(); | ||
crp.incorporate(item, t); | ||
} | ||
} | ||
void unincorporate(const T_item& item) { | ||
printf("Not implemented\n"); | ||
exit(EXIT_FAILURE); | ||
// assert(items.count(item) == 1); | ||
// assert(items.at(item).count(relation) == 1); | ||
// items.at(item).erase(relation); | ||
// if (items.at(item).size() == 0) { | ||
// crp.unincorporate(item); | ||
// items.erase(item); | ||
// } | ||
} | ||
int get_cluster_assignment(const T_item& item) const { | ||
assert(items.contains(item)); | ||
return crp.assignments.at(item); | ||
} | ||
void set_cluster_assignment_gibbs(const T_item& item, int table) { | ||
assert(items.contains(item)); | ||
assert(crp.assignments.at(item) != table); | ||
crp.unincorporate(item); | ||
crp.incorporate(item, table); | ||
} | ||
std::unordered_map<int, double> tables_weights() const { | ||
return crp.tables_weights(); | ||
} | ||
std::unordered_map<int, double> tables_weights_gibbs( | ||
const T_item& item) const { | ||
int table = get_cluster_assignment(item); | ||
return crp.tables_weights_gibbs(table); | ||
} | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
// Apache License, Version 2.0, refer to LICENSE.txt | ||
|
||
#define BOOST_TEST_MODULE test Domain | ||
|
||
#include "domain.hh" | ||
|
||
#include <boost/test/included/unit_test.hpp> | ||
namespace tt = boost::test_tools; | ||
|
||
// Checks that Domain tracks incorporated items and their cluster
// assignments, including an explicit move and an explicit initial table.
BOOST_AUTO_TEST_CASE(test_domain) {
  std::mt19937 prng;
  Domain d("fruit", &prng);
  T_item banana = 1;
  T_item apple = 2;

  // banana is seated by the CRP, then moved explicitly to table 12.
  d.incorporate(banana);
  d.set_cluster_assignment_gibbs(banana, 12);
  // Incorporating an existing item again must leave the domain unchanged.
  d.incorporate(banana);
  // apple is seated directly at table 5.
  d.incorporate(apple, 5);

  BOOST_TEST(d.items.contains(banana));
  BOOST_TEST(d.items.contains(apple));
  BOOST_TEST(d.items.size() == 2);

  int cb = d.get_cluster_assignment(banana);
  int ca = d.get_cluster_assignment(apple);
  BOOST_TEST(ca == 5);
  BOOST_TEST(cb == 12);
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
#pragma once | ||
|
||
#include <cassert> | ||
#include "emissions/base.hh" | ||
|
||
// A *deterministic* Emission class that always emits not(clean).
// Most users will want to combine this with Sometimes.
//
// Every <clean, dirty> pair handed to this class must satisfy
// clean != dirty; the methods below assert that precondition.
class BitFlip : public Emission<bool> {
 public:
  BitFlip() {}

  // Records one observed <clean, dirty> pair by incrementing N.
  void incorporate(const std::pair<bool, bool>& x) {
    assert(x.first != x.second);
    ++N;
  }

  // Undoes a previous incorporate() of `x` by decrementing N.
  void unincorporate(const std::pair<bool, bool>& x) {
    assert(x.first != x.second);
    --N;
  }

  // Log probability of `x`: a flipped pair is the only possible outcome, so
  // the result is always log(1) = 0.
  double logp(const std::pair<bool, bool>& x) const {
    assert(x.first != x.second);
    return 0.0;
  }

  // Log probability of everything incorporated so far; always 0 because the
  // emission is deterministic.
  double logp_score() const {
    return 0.0;
  }

  // No hyperparameters to transition!
  void transition_hyperparameters() {}

  // Returns the corrupted version of `clean`, i.e. its negation. The
  // generator is not used.
  bool sample_corrupted(const bool& clean, std::mt19937* unused_prng) {
    return !clean;
  }

  // Proposes a clean value from the observed corrupted values by negating
  // the first one. `corrupted` must be non-empty: indexing an empty vector
  // is undefined behaviour, so the precondition is asserted. The generator
  // is not used.
  bool propose_clean(const std::vector<bool>& corrupted,
                     std::mt19937* unused_prng) {
    assert(!corrupted.empty());
    return !corrupted[0];
  }
};
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
// Apache License, Version 2.0, refer to LICENSE.txt | ||
|
||
#define BOOST_TEST_MODULE test BitFlip | ||
|
||
#include <random> | ||
|
||
#include "emissions/bitflip.hh" | ||
|
||
#include <boost/test/included/unit_test.hpp> | ||
|
||
// Walks a single BitFlip through incorporate/unincorporate and checks its
// counter, scores, sampling and clean-value proposal.
BOOST_AUTO_TEST_CASE(test_simple) {
  const std::pair<bool, bool> true_to_false(true, false);
  const std::pair<bool, bool> false_to_true(false, true);

  BitFlip bf;

  // A fresh instance has no observations and a zero score.
  BOOST_TEST(bf.logp_score() == 0.0);
  BOOST_TEST(bf.N == 0);

  // Adding and then removing one observation returns to the initial state.
  bf.incorporate(true_to_false);
  BOOST_TEST(bf.logp_score() == 0.0);
  BOOST_TEST(bf.N == 1);
  bf.unincorporate(true_to_false);
  BOOST_TEST(bf.logp_score() == 0.0);
  BOOST_TEST(bf.N == 0);

  // The same observation may be incorporated more than once.
  bf.incorporate(false_to_true);
  bf.incorporate(false_to_true);
  BOOST_TEST(bf.logp_score() == 0.0);
  BOOST_TEST(bf.N == 2);

  BOOST_TEST(bf.logp(true_to_false) == 0.0);

  // Sampling and proposing both negate their input.
  std::mt19937 prng;
  BOOST_TEST(bf.sample_corrupted(false, &prng));

  BOOST_TEST(bf.propose_clean({false, false, false}, &prng));
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
#pragma once | ||
|
||
#include <unordered_map> | ||
|
||
#include "distributions/beta_bernoulli.hh" | ||
#include "emissions/base.hh" | ||
|
||
// An Emission class that sometimes applies BaseEmissor and sometimes doesn't. | ||
// BaseEmissor must (1) of type Emission<SampleType> and (2) assign zero | ||
// probability to <clean, dirty> pairs with clean == dirty. [For example, | ||
// BitFlip and Gaussian both satisfy #2]. | ||
template <typename BaseEmissor, typename SampleType = double> | ||
class Sometimes : public Emission<SampleType> { | ||
public: | ||
BetaBernoulli bb; | ||
BaseEmissor be; | ||
|
||
Sometimes() : bb(nullptr) {}; | ||
|
||
void incorporate(const std::pair<SampleType, SampleType>& x) { | ||
++(this->N); | ||
bb.incorporate(x.first != x.second); | ||
if (x.first != x.second) { | ||
be.incorporate(x); | ||
} | ||
} | ||
|
||
void unincorporate(const std::pair<SampleType, SampleType>& x) { | ||
--(this->N); | ||
bb.unincorporate(x.first != x.second); | ||
if (x.first != x.second) { | ||
be.unincorporate(x); | ||
} | ||
} | ||
|
||
double logp(const std::pair<SampleType, SampleType>& x) const { | ||
return bb.logp(x.first != x.second) + be.logp(x); | ||
} | ||
|
||
double logp_score() const { | ||
return bb.logp_score() + be.logp_score(); | ||
} | ||
|
||
void transition_hyperparameters() { | ||
be.transition_hyperparameters(); | ||
bb.transition_hyperparameters(); | ||
} | ||
|
||
SampleType sample_corrupted(const SampleType& clean, std::mt19937* prng) { | ||
bb.prng = prng; | ||
if (bb.sample()) { | ||
return be.sample_corrupted(clean, prng); | ||
} | ||
return clean; | ||
} | ||
|
||
SampleType propose_clean(const std::vector<SampleType>& corrupted, | ||
std::mt19937* prng) { | ||
// We approximate the maximum likelihood estimate by taking the mode of | ||
// corrupted. The full solution would construct BaseEmissor and | ||
// BetaBernoulli instances for each choice of clean and picking the | ||
// clean with the highest combined logp_score(). | ||
std::unordered_map<SampleType, int> counts; | ||
SampleType mode; | ||
int max_count = 0; | ||
for (const SampleType& c: corrupted) { | ||
++counts[c]; | ||
if (counts[c] > max_count) { | ||
max_count = counts[c]; | ||
mode = c; | ||
} | ||
} | ||
return mode; | ||
} | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
// Apache License, Version 2.0, refer to LICENSE.txt | ||
|
||
#define BOOST_TEST_MODULE test Sometimes | ||
|
||
#include <random> | ||
|
||
#include "emissions/bitflip.hh" | ||
#include "emissions/sometimes.hh" | ||
|
||
#include <boost/test/included/unit_test.hpp> | ||
|
||
// Walks a Sometimes<BitFlip, bool> through incorporate/unincorporate and
// checks its counter, scores and clean-value proposal.
BOOST_AUTO_TEST_CASE(test_simple) {
  const std::pair<bool, bool> true_to_false(true, false);
  const std::pair<bool, bool> false_to_true(false, true);

  Sometimes<BitFlip, bool> sbf;

  // Remember the score of the empty model so the round trip can be checked.
  const double orig_lp = sbf.logp_score();
  BOOST_TEST(sbf.N == 0);

  // One observation lowers the score; removing it restores the original.
  sbf.incorporate(true_to_false);
  BOOST_TEST(sbf.logp_score() < 0.0);
  BOOST_TEST(sbf.N == 1);
  sbf.unincorporate(true_to_false);
  BOOST_TEST(sbf.logp_score() == orig_lp);
  BOOST_TEST(sbf.N == 0);

  // The same observation may be incorporated more than once.
  sbf.incorporate(false_to_true);
  sbf.incorporate(false_to_true);
  BOOST_TEST(sbf.logp_score() < 0.0);
  BOOST_TEST(sbf.N == 2);

  BOOST_TEST(sbf.logp(true_to_false) < 0.0);

  // The proposal is the mode of the corrupted values, here `true`.
  std::mt19937 prng;
  BOOST_TEST(sbf.propose_clean({true, true, false}, &prng));
}
Oops, something went wrong.