diff --git a/cxx/BUILD b/cxx/BUILD
index 14de76d..811cc71 100644
--- a/cxx/BUILD
+++ b/cxx/BUILD
@@ -16,6 +16,14 @@ cc_library(
     deps = [],
 )
 
+cc_library(
+    name = "domain",
+    hdrs = ["domain.hh"],
+    deps = [
+        "//distributions",
+    ],
+)
+
 cc_binary(
     name = "hirm",
     srcs = ["hirm.cc"],
@@ -29,6 +37,17 @@ cc_binary(
     ],
 )
 
+cc_library(
+    name = "relation",
+    hdrs = ["relation.hh"],
+    deps = [
+        ":domain",
+        ":util_hash",
+        ":util_math",
+        "//distributions:base",
+    ],
+)
+
 cc_library(
     name = "util_hash",
     hdrs = ["util_hash.hh"],
@@ -57,6 +76,26 @@ cc_library(
     deps = [":headers"],
 )
 
+cc_test(
+    name = "domain_test",
+    srcs = ["domain_test.cc"],
+    deps = [
+        ":domain",
+        "@boost//:test",
+    ],
+)
+
+cc_test(
+    name = "relation_test",
+    srcs = ["relation_test.cc"],
+    deps = [
+        ":domain",
+        ":relation",
+        "//distributions",
+        "@boost//:test",
+    ],
+)
+
 cc_test(
     name = "util_math_test",
     srcs = ["util_math_test.cc"],
diff --git a/cxx/domain.hh b/cxx/domain.hh
new file mode 100644
index 0000000..68cc938
--- /dev/null
+++ b/cxx/domain.hh
@@ -0,0 +1,78 @@
+// Copyright 2020
+// See LICENSE.txt
+
+#pragma once
+#include
+#include
+#include
+
+#include "distributions/crp.hh"
+
+// Identifier of a single item inside a domain.
+typedef int T_item;
+
+// A named set of items together with the clustering model (CRP) that assigns
+// each item to a table.
+class Domain {
+ public:
+  const std::string name;  // human-readable name
+  std::unordered_set items;  // set of items
+  CRP crp;  // clustering model for items
+  std::mt19937* prng;  // same generator that is passed to crp
+
+  // `name` must be non-empty (asserted).
+  Domain(const std::string& name, std::mt19937* prng)
+      : name(name), crp(prng), prng(prng) {
+    assert(!name.empty());
+  }
+  // Adds `item` to the domain if it is not present yet. A non-negative
+  // `table` is passed to the CRP as the item's table, otherwise the table is
+  // sampled from the CRP. For an item that is already present nothing changes
+  // and `table` must be left at its default of -1 (asserted).
+  void incorporate(const T_item& item, int table = -1) {
+    if (items.contains(item)) {
+      assert(table == -1);
+    } else {
+      items.insert(item);
+      int t = 0 <= table ? table : crp.sample();
+      crp.incorporate(item, t);
+    }
+  }
+  // Not implemented: prints a message and terminates the process.
+  void unincorporate(const T_item& item) {
+    printf("Not implemented\n");
+    exit(EXIT_FAILURE);
+    // assert(items.count(item) == 1);
+    // assert(items.at(item).count(relation) == 1);
+    // items.at(item).erase(relation);
+    // if (items.at(item).size() == 0) {
+    //   crp.unincorporate(item);
+    //   items.erase(item);
+    // }
+  }
+  // Returns the CRP table currently assigned to `item`, which must already
+  // be in the domain (asserted).
+  int get_cluster_assignment(const T_item& item) const {
+    assert(items.contains(item));
+    return crp.assignments.at(item);
+  }
+  // Moves `item` to `table`, which must differ from its current table
+  // (asserted).
+  void set_cluster_assignment_gibbs(const T_item& item, int table) {
+    assert(items.contains(item));
+    assert(crp.assignments.at(item) != table);
+    crp.unincorporate(item);
+    crp.incorporate(item, table);
+  }
+  // Forwards to CRP::tables_weights().
+  std::unordered_map tables_weights() const {
+    return crp.tables_weights();
+  }
+  // Forwards to CRP::tables_weights_gibbs() for the table that currently
+  // holds `item`.
+  std::unordered_map tables_weights_gibbs(
+      const T_item& item) const {
+    int table = get_cluster_assignment(item);
+    return crp.tables_weights_gibbs(table);
+  }
+};
diff --git a/cxx/domain_test.cc b/cxx/domain_test.cc
new file mode 100644
index 0000000..3c09516
--- /dev/null
+++ b/cxx/domain_test.cc
@@ -0,0 +1,28 @@
+// Apache License, Version 2.0, refer to LICENSE.txt
+
+#define BOOST_TEST_MODULE test Domain
+
+#include "domain.hh"
+
+#include
+namespace tt = boost::test_tools;
+
+// Incorporates two items and checks membership and cluster assignments.
+BOOST_AUTO_TEST_CASE(test_domain) {
+  std::mt19937 prng;
+  Domain d("fruit", &prng);
+  T_item banana = 1;
+  T_item apple = 2;
+  d.incorporate(banana);
+  d.set_cluster_assignment_gibbs(banana, 12);
+  d.incorporate(banana);
+  d.incorporate(apple, 5);
+  BOOST_TEST(d.items.contains(banana));
+  BOOST_TEST(d.items.contains(apple));
+  BOOST_TEST(d.items.size() == 2);
+
+  int cb = d.get_cluster_assignment(banana);
+  int ca = d.get_cluster_assignment(apple);
+  BOOST_TEST(ca == 5);
+  BOOST_TEST(cb == 12);
+}
diff --git a/cxx/hirm.hh b/cxx/hirm.hh
index 8890dde..083f910 100644
--- a/cxx/hirm.hh
+++ b/cxx/hirm.hh @@ -14,429 +14,14 @@ #include "distributions/crp.hh" #include "distributions/dirichlet_categorical.hh" #include "distributions/normal.hh" -#include "util_hash.hh" -#include "util_math.hh" +#include "relation.hh" -typedef int T_item; -typedef std::vector T_items; -typedef VectorIntHash H_items; - -// T_relation is the text we get from reading a line of the schema file; -// hirm.hh:Relation is the object that does the work. -class T_relation { - public: - // The relation is a map from the domains to the space .distribution - // is a distribution over. - std::vector domains; - - // Must be the name of a distribution in distributions/. - std::string distribution; -}; // Map from names to T_relation's. typedef std::map T_schema; using ObservationVariant = std::variant; -class Domain { - public: - const std::string name; // human-readable name - std::unordered_set items; // set of items - CRP crp; // clustering model for items - std::mt19937* prng; - - Domain(const std::string& name, std::mt19937* prng) : name(name), crp(prng) { - assert(!name.empty()); - this->prng = prng; - } - void incorporate(const T_item& item, int table = -1) { - if (items.contains(item)) { - assert(table == -1); - } else { - items.insert(item); - int t = 0 <= table ? 
table : crp.sample(); - crp.incorporate(item, t); - } - } - void unincorporate(const T_item& item) { - printf("Not implemented\n"); - exit(EXIT_FAILURE); - // assert(items.count(item) == 1); - // assert(items.at(item).count(relation) == 1); - // items.at(item).erase(relation); - // if (items.at(item).size() == 0) { - // crp.unincorporate(item); - // items.erase(item); - // } - } - int get_cluster_assignment(const T_item& item) const { - assert(items.contains(item)); - return crp.assignments.at(item); - } - void set_cluster_assignment_gibbs(const T_item& item, int table) { - assert(items.contains(item)); - assert(crp.assignments.at(item) != table); - crp.unincorporate(item); - crp.incorporate(item, table); - } - std::unordered_map tables_weights() const { - return crp.tables_weights(); - } - std::unordered_map tables_weights_gibbs( - const T_item& item) const { - int table = get_cluster_assignment(item); - return crp.tables_weights_gibbs(table); - } -}; - -template -class Relation { - public: - using ValueType = typename DistributionType::SampleType; - using DType = DistributionType; - static_assert(std::is_base_of, DType>::value, - "DistributionType must inherit from Distribution."); - // human-readable name - const std::string name; - // Distribution over the relation's codomain. 
- const std::string distribution; - // list of domain pointers - const std::vector domains; - // map from cluster multi-index to Distribution pointer - std::unordered_map, DistributionType*, VectorIntHash> - clusters; - // map from item to observed data - std::unordered_map data; - // map from domain name to reverse map from item to - // set of items that include that item - std::unordered_map< - std::string, - std::unordered_map>> - data_r; - std::mt19937* prng; - - Relation(const std::string& name, const std::string& distribution, - const std::vector& domains, std::mt19937* prng) - : name(name), distribution(distribution), domains(domains) { - assert(!domains.empty()); - assert(!name.empty()); - this->prng = prng; - for (const Domain* const d : domains) { - this->data_r[d->name] = - std::unordered_map>(); - } - } - - ~Relation() { - for (auto [z, cluster] : clusters) { - delete cluster; - } - } - - T_relation get_T_relation() { - T_relation trel; - trel.distribution = distribution; - for (const auto& d : domains) { - trel.domains.push_back(d->name); - } - return trel; - } - - void incorporate(const T_items& items, ValueType value) { - assert(!data.contains(items)); - data[items] = value; - for (int i = 0; i < std::ssize(domains); ++i) { - domains[i]->incorporate(items[i]); - if (!data_r.at(domains[i]->name).contains(items[i])) { - data_r.at(domains[i]->name)[items[i]] = - std::unordered_set(); - } - data_r.at(domains[i]->name).at(items[i]).insert(items); - } - T_items z = get_cluster_assignment(items); - if (!clusters.contains(z)) { - // Invalid discussion as using pointers now; - // Cannot use clusters[z] because BetaBernoulli - // does not have a default constructor, whereas operator[] - // calls default constructor when the key does not exist. 
- clusters[z] = new DistributionType(prng); - } - clusters.at(z)->incorporate(value); - } - - void unincorporate(const T_items& items) { - printf("Not implemented\n"); - exit(EXIT_FAILURE); - // auto x = data.at(items); - // auto z = get_cluster_assignment(items); - // clusters.at(z)->unincorporate(x); - // if (clusters.at(z)->N == 0) { - // delete clusters.at(z); - // clusters.erase(z); - // } - // for (int i = 0; i < domains.size(); i++) { - // const std::string &n = domains[i]->name; - // if (data_r.at(n).count(items[i]) > 0) { - // data_r.at(n).at(items[i]).erase(items); - // if (data_r.at(n).at(items[i]).size() == 0) { - // data_r.at(n).erase(items[i]); - // domains[i]->unincorporate(name, items[i]); - // } - // } - // } - // data.erase(items); - } - - std::vector get_cluster_assignment(const T_items& items) const { - assert(items.size() == domains.size()); - std::vector z(domains.size()); - for (int i = 0; i < std::ssize(domains); ++i) { - z[i] = domains[i]->get_cluster_assignment(items[i]); - } - return z; - } - - std::vector get_cluster_assignment_gibbs(const T_items& items, - const Domain& domain, - const T_item& item, - int table) const { - assert(items.size() == domains.size()); - std::vector z(domains.size()); - int hits = 0; - for (int i = 0; i < std::ssize(domains); ++i) { - if ((domains[i]->name == domain.name) && (items[i] == item)) { - z[i] = table; - ++hits; - } else { - z[i] = domains[i]->get_cluster_assignment(items[i]); - } - } - assert(hits > 0); - return z; - } - - // Implementation of approximate Gibbs data probabilities (faster). 
- - double logp_gibbs_approx_current(const Domain& domain, const T_item& item) { - double logp = 0.; - for (const T_items& items : data_r.at(domain.name).at(item)) { - ValueType x = data.at(items); - T_items z = get_cluster_assignment(items); - auto cluster = clusters.at(z); - cluster->unincorporate(x); - double lp = cluster->logp(x); - cluster->incorporate(x); - logp += lp; - } - return logp; - } - - double logp_gibbs_approx_variant(const Domain& domain, const T_item& item, - int table) { - double logp = 0.; - for (const T_items& items : data_r.at(domain.name).at(item)) { - ValueType x = data.at(items); - T_items z = get_cluster_assignment_gibbs(items, domain, item, table); - double lp; - if (!clusters.contains(z)) { - DistributionType cluster(prng); - lp = cluster.logp(x); - } else { - lp = clusters.at(z)->logp(x); - } - logp += lp; - } - return logp; - } - - double logp_gibbs_approx(const Domain& domain, const T_item& item, - int table) { - int table_current = domain.get_cluster_assignment(item); - return table_current == table - ? logp_gibbs_approx_current(domain, item) - : logp_gibbs_approx_variant(domain, item, table); - } - - // Implementation of exact Gibbs data probabilities. 
- - std::unordered_map const, std::vector, - VectorIntHash> - get_cluster_to_items_list(Domain const& domain, const T_item& item) { - std::unordered_map, std::vector, - VectorIntHash> - m; - for (const T_items& items : data_r.at(domain.name).at(item)) { - T_items z = get_cluster_assignment(items); - m[z].push_back(items); - } - return m; - } - - double logp_gibbs_exact_current(const std::vector& items_list) { - assert(!items_list.empty()); - T_items z = get_cluster_assignment(items_list[0]); - auto cluster = clusters.at(z); - double logp0 = cluster->logp_score(); - for (const T_items& items : items_list) { - ValueType x = data.at(items); - // assert(z == get_cluster_assignment(items)); - cluster->unincorporate(x); - } - double logp1 = cluster->logp_score(); - for (const T_items& items : items_list) { - ValueType x = data.at(items); - cluster->incorporate(x); - } - assert(cluster->logp_score() == logp0); - return logp0 - logp1; - } - - double logp_gibbs_exact_variant(const Domain& domain, const T_item& item, - int table, - const std::vector& items_list) { - assert(!items_list.empty()); - T_items z = - get_cluster_assignment_gibbs(items_list[0], domain, item, table); - - DistributionType aux(prng); - DistributionType* cluster = clusters.contains(z) ? 
clusters.at(z) : &aux; - // auto cluster = self.clusters.get(z, self.aux()) - double logp0 = cluster->logp_score(); - for (const T_items& items : items_list) { - // assert(z == get_cluster_assignment_gibbs(items, domain, item, table)); - ValueType x = data.at(items); - cluster->incorporate(x); - } - const double logp1 = cluster->logp_score(); - for (const T_items& items : items_list) { - ValueType x = data.at(items); - cluster->unincorporate(x); - } - assert(cluster->logp_score() == logp0); - return logp1 - logp0; - } - - std::vector logp_gibbs_exact(const Domain& domain, const T_item& item, - std::vector tables) { - auto cluster_to_items_list = get_cluster_to_items_list(domain, item); - int table_current = domain.get_cluster_assignment(item); - std::vector logps; // size this? - logps.reserve(tables.size()); - double lp_cluster; - for (const int& table : tables) { - double lp_table = 0; - for (const auto& [z, items_list] : cluster_to_items_list) { - lp_cluster = - (table == table_current) - ? 
logp_gibbs_exact_current(items_list) - : logp_gibbs_exact_variant(domain, item, table, items_list); - lp_table += lp_cluster; - } - logps.push_back(lp_table); - } - return logps; - } - - double logp(const T_items& items, ValueType value) { - // TODO: Falsely assumes cluster assignments of items - // from same domain are identical, see note in hirm.py - assert(items.size() == domains.size()); - std::vector> tabl_list; - std::vector> wght_list; - std::vector> indx_list; - for (int i = 0; i < std::ssize(domains); ++i) { - Domain* domain = domains.at(i); - T_item item = items.at(i); - std::vector t_list; - std::vector w_list; - std::vector i_list; - if (domain->items.contains(item)) { - int z = domain->get_cluster_assignment(item); - t_list = {z}; - w_list = {0}; - i_list = {0}; - } else { - auto tables_weights = domain->tables_weights(); - double Z = log(domain->crp.alpha + domain->crp.N); - int idx = 0; - for (const auto& [t, w] : tables_weights) { - t_list.push_back(t); - w_list.push_back(log(w) - Z); - i_list.push_back(idx++); - } - assert(idx == std::ssize(t_list)); - } - tabl_list.push_back(t_list); - wght_list.push_back(w_list); - indx_list.push_back(i_list); - } - std::vector logps; - for (const auto& indexes : product(indx_list)) { - assert(indexes.size() == domains.size()); - std::vector z; - z.reserve(domains.size()); - double logp_w = 0; - for (int i = 0; i < std::ssize(domains); ++i) { - T_item zi = tabl_list.at(i).at(indexes[i]); - double wi = wght_list.at(i).at(indexes[i]); - z.push_back(zi); - logp_w += wi; - } - DistributionType aux(prng); - DistributionType* cluster = clusters.contains(z) ? 
clusters.at(z) : &aux; - double logp_z = cluster->logp(value); - double logp_zw = logp_z + logp_w; - logps.push_back(logp_zw); - } - return logsumexp(logps); - } - - double logp_score() const { - double logp = 0.0; - for (const auto& [_, cluster] : clusters) { - logp += cluster->logp_score(); - } - return logp; - } - - void set_cluster_assignment_gibbs(const Domain& domain, const T_item& item, - int table) { - int table_current = domain.get_cluster_assignment(item); - assert(table != table_current); - for (const T_items& items : data_r.at(domain.name).at(item)) { - ValueType x = data.at(items); - // Remove from current cluster. - T_items z_prev = get_cluster_assignment(items); - auto cluster_prev = clusters.at(z_prev); - cluster_prev->unincorporate(x); - if (cluster_prev->N == 0) { - delete clusters.at(z_prev); - clusters.erase(z_prev); - } - // Move to desired cluster. - T_items z_new = get_cluster_assignment_gibbs(items, domain, item, table); - if (!clusters.contains(z_new)) { - // Move to fresh cluster. - clusters[z_new] = new DistributionType(prng); - clusters.at(z_new)->incorporate(x); - } else { - // Move to existing cluster. - assert((clusters.at(z_new)->N > 0)); - clusters.at(z_new)->incorporate(x); - } - } - // Caller should invoke domain.set_cluster_gibbs - } - - bool has_observation(const Domain& domain, const T_item& item) { - return data_r.at(domain.name).contains(item); - } - - // Disable copying. 
-  Relation& operator=(const Relation&) = delete;
-  Relation(const Relation&) = delete;
-};
-
 using RelationVariant =
     std::variant*, Relation*,
                  // Relation*,
diff --git a/cxx/relation.hh b/cxx/relation.hh
new file mode 100644
index 0000000..4127bd2
--- /dev/null
+++ b/cxx/relation.hh
@@ -0,0 +1,410 @@
+// Copyright 2020
+// See LICENSE.txt
+
+#pragma once
+
+#include
+#include
+#include
+
+#include "distributions/base.hh"
+#include "domain.hh"
+#include "util_hash.hh"
+#include "util_math.hh"
+
+typedef std::vector T_items;
+typedef VectorIntHash H_items;
+
+// T_relation is the text we get from reading a line of the schema file;
+// Relation (defined below) is the object that does the work.
+class T_relation {
+ public:
+  // The relation is a map from the domains to the space .distribution
+  // is a distribution over.
+  std::vector domains;
+
+  // Must be the name of a distribution in distributions/.
+  std::string distribution;
+};
+
+// A relation over an ordered list of domains. Each observation maps a tuple
+// of items (one per domain) to a value; the value is incorporated into the
+// cluster keyed by the tuple of the items' domain cluster assignments, and
+// every cluster is a separately allocated DistributionType.
+template
+class Relation {
+ public:
+  using ValueType = typename DistributionType::SampleType;
+  using DType = DistributionType;
+  static_assert(std::is_base_of, DType>::value,
+                "DistributionType must inherit from Distribution.");
+  // human-readable name
+  const std::string name;
+  // Distribution over the relation's codomain.
+  const std::string distribution;
+  // list of domain pointers
+  const std::vector domains;
+  // map from cluster multi-index to Distribution pointer
+  std::unordered_map, DistributionType*, VectorIntHash>
+      clusters;
+  // map from item to observed data
+  std::unordered_map data;
+  // map from domain name to reverse map from item to
+  // set of items that include that item
+  std::unordered_map<
+      std::string,
+      std::unordered_map>>
+      data_r;
+  std::mt19937* prng;  // handed to each DistributionType constructed here
+
+  // Requires a non-empty name and at least one domain (both asserted).
+  Relation(const std::string& name, const std::string& distribution,
+           const std::vector& domains, std::mt19937* prng)
+      : name(name), distribution(distribution), domains(domains), prng(prng) {
+    assert(!domains.empty());
+    assert(!name.empty());
+    for (const Domain* const d : domains) {
+      this->data_r[d->name] =
+          std::unordered_map>();
+    }
+  }
+
+  // Frees every cluster distribution still held in `clusters`.
+  ~Relation() {
+    for (const auto& [_, cluster] : clusters) {
+      delete cluster;
+    }
+  }
+
+  // Returns the distribution name and domain names as a T_relation.
+  T_relation get_T_relation() {
+    T_relation trel;
+    trel.distribution = distribution;
+    for (const auto& d : domains) {
+      trel.domains.push_back(d->name);
+    }
+    return trel;
+  }
+
+  // Records `value` for the tuple `items`, which must not have been observed
+  // before (asserted), incorporates each item into its domain, and adds the
+  // value to the cluster selected by the items' current assignments.
+  void incorporate(const T_items& items, ValueType value) {
+    assert(!data.contains(items));
+    data[items] = value;
+    for (int i = 0; i < std::ssize(domains); ++i) {
+      domains[i]->incorporate(items[i]);
+      if (!data_r.at(domains[i]->name).contains(items[i])) {
+        data_r.at(domains[i]->name)[items[i]] =
+            std::unordered_set();
+      }
+      data_r.at(domains[i]->name).at(items[i]).insert(items);
+    }
+    T_items z = get_cluster_assignment(items);
+    if (!clusters.contains(z)) {
+      // clusters stores raw pointers, so operator[] would only insert a
+      // null entry; the distribution object is allocated explicitly here.
+      // Relation owns it: it is deleted in ~Relation() or as soon as the
+      // cluster becomes empty in set_cluster_assignment_gibbs().
+      clusters[z] = new DistributionType(prng);
+    }
+    clusters.at(z)->incorporate(value);
+  }
+
+  // Not implemented: prints a message and terminates the process.
+  void unincorporate(const T_items& items) {
+    printf("Not implemented\n");
+    exit(EXIT_FAILURE);
+    // auto x = data.at(items);
+    // auto z = get_cluster_assignment(items);
+    // clusters.at(z)->unincorporate(x);
+    // if (clusters.at(z)->N == 0) {
+    //   delete clusters.at(z);
+    //   clusters.erase(z);
+    // }
+    // for (int i = 0; i < domains.size(); i++) {
+    //   const std::string &n = domains[i]->name;
+    //   if (data_r.at(n).count(items[i]) > 0) {
+    //     data_r.at(n).at(items[i]).erase(items);
+    //     if (data_r.at(n).at(items[i]).size() == 0) {
+    //       data_r.at(n).erase(items[i]);
+    //       domains[i]->unincorporate(name, items[i]);
+    //     }
+    //   }
+    // }
+    // data.erase(items);
+  }
+
+  // Returns, position by position, each item's cluster in its domain.
+  std::vector get_cluster_assignment(const T_items& items) const {
+    assert(items.size() == domains.size());
+    std::vector z(domains.size());
+    for (int i = 0; i < std::ssize(domains); ++i) {
+      z[i] = domains[i]->get_cluster_assignment(items[i]);
+    }
+    return z;
+  }
+
+  // Like get_cluster_assignment(), except that every position holding `item`
+  // in `domain` is reported as `table`; at least one must exist (asserted).
+  std::vector get_cluster_assignment_gibbs(const T_items& items,
+                                           const Domain& domain,
+                                           const T_item& item,
+                                           int table) const {
+    assert(items.size() == domains.size());
+    std::vector z(domains.size());
+    int hits = 0;
+    for (int i = 0; i < std::ssize(domains); ++i) {
+      if ((domains[i]->name == domain.name) && (items[i] == item)) {
+        z[i] = table;
+        ++hits;
+      } else {
+        z[i] = domains[i]->get_cluster_assignment(items[i]);
+      }
+    }
+    assert(hits > 0);
+    return z;
+  }
+
+  // Implementation of approximate Gibbs data probabilities (faster).
+
+  // Sums, over the observations involving `item`, the log probability of
+  // each value under its current cluster with that value temporarily removed.
+  double logp_gibbs_approx_current(const Domain& domain, const T_item& item) {
+    double logp = 0.;
+    for (const T_items& items : data_r.at(domain.name).at(item)) {
+      ValueType x = data.at(items);
+      T_items z = get_cluster_assignment(items);
+      auto cluster = clusters.at(z);
+      cluster->unincorporate(x);
+      double lp = cluster->logp(x);
+      cluster->incorporate(x);
+      logp += lp;
+    }
+    return logp;
+  }
+
+  // Sums, over the observations involving `item`, the log probability of
+  // each value under the cluster it would belong to with `item` at `table`
+  // (a fresh DistributionType when that cluster does not exist yet).
+  double logp_gibbs_approx_variant(const Domain& domain, const T_item& item,
+                                   int table) {
+    double logp = 0.;
+    for (const T_items& items : data_r.at(domain.name).at(item)) {
+      ValueType x = data.at(items);
+      T_items z = get_cluster_assignment_gibbs(items, domain, item, table);
+      double lp;
+      if (!clusters.contains(z)) {
+        DistributionType cluster(prng);
+        lp = cluster.logp(x);
+      } else {
+        lp = clusters.at(z)->logp(x);
+      }
+      logp += lp;
+    }
+    return logp;
+  }
+
+  // Dispatches on whether `table` is the current table of `item`.
+  double logp_gibbs_approx(const Domain& domain, const T_item& item,
+                           int table) {
+    int table_current = domain.get_cluster_assignment(item);
+    return table_current == table
+               ? logp_gibbs_approx_current(domain, item)
+               : logp_gibbs_approx_variant(domain, item, table);
+  }
+
+  // Implementation of exact Gibbs data probabilities.
+
+  // Groups the observations involving `item` by their current cluster.
+  std::unordered_map const, std::vector,
+                     VectorIntHash>
+  get_cluster_to_items_list(Domain const& domain, const T_item& item) {
+    std::unordered_map, std::vector,
+                       VectorIntHash>
+        m;
+    for (const T_items& items : data_r.at(domain.name).at(item)) {
+      T_items z = get_cluster_assignment(items);
+      m[z].push_back(items);
+    }
+    return m;
+  }
+
+  // Returns the cluster's score minus its score with all of `items_list`
+  // removed; the observations are re-incorporated before returning.
+  double logp_gibbs_exact_current(const std::vector& items_list) {
+    assert(!items_list.empty());
+    T_items z = get_cluster_assignment(items_list[0]);
+    auto cluster = clusters.at(z);
+    double logp0 = cluster->logp_score();
+    for (const T_items& items : items_list) {
+      ValueType x = data.at(items);
+      // assert(z == get_cluster_assignment(items));
+      cluster->unincorporate(x);
+    }
+    double logp1 = cluster->logp_score();
+    for (const T_items& items : items_list) {
+      ValueType x = data.at(items);
+      cluster->incorporate(x);
+    }
+    assert(cluster->logp_score() == logp0);
+    return logp0 - logp1;
+  }
+
+  // Returns the score gained by adding all of `items_list` to the cluster
+  // they would belong to with `item` at `table` (a temporary distribution if
+  // that cluster does not exist); the additions are undone before returning.
+  double logp_gibbs_exact_variant(const Domain& domain, const T_item& item,
+                                  int table,
+                                  const std::vector& items_list) {
+    assert(!items_list.empty());
+    T_items z =
+        get_cluster_assignment_gibbs(items_list[0], domain, item, table);
+
+    DistributionType aux(prng);
+    DistributionType* cluster = clusters.contains(z) ? clusters.at(z) : &aux;
+    // auto cluster = self.clusters.get(z, self.aux())
+    double logp0 = cluster->logp_score();
+    for (const T_items& items : items_list) {
+      // assert(z == get_cluster_assignment_gibbs(items, domain, item, table));
+      ValueType x = data.at(items);
+      cluster->incorporate(x);
+    }
+    const double logp1 = cluster->logp_score();
+    for (const T_items& items : items_list) {
+      ValueType x = data.at(items);
+      cluster->unincorporate(x);
+    }
+    assert(cluster->logp_score() == logp0);
+    return logp1 - logp0;
+  }
+
+  // Returns one value per entry of `tables`: the sum over the item's current
+  // clusters of the _current term (same table) or _variant term (otherwise).
+  std::vector logp_gibbs_exact(const Domain& domain, const T_item& item,
+                               std::vector tables) {
+    auto cluster_to_items_list = get_cluster_to_items_list(domain, item);
+    int table_current = domain.get_cluster_assignment(item);
+    std::vector logps;  // size this?
+    logps.reserve(tables.size());
+    double lp_cluster;
+    for (const int& table : tables) {
+      double lp_table = 0;
+      for (const auto& [z, items_list] : cluster_to_items_list) {
+        lp_cluster =
+            (table == table_current)
+                ? logp_gibbs_exact_current(items_list)
+                : logp_gibbs_exact_variant(domain, item, table, items_list);
+        lp_table += lp_cluster;
+      }
+      logps.push_back(lp_table);
+    }
+    return logps;
+  }
+
+  // Log probability of `value` at `items`. Items already in their domain use
+  // their assigned table; otherwise the tables_weights() tables are summed.
+  double logp(const T_items& items, ValueType value) {
+    // TODO: Falsely assumes cluster assignments of items
+    // from same domain are identical, see note in hirm.py
+    assert(items.size() == domains.size());
+    std::vector> tabl_list;
+    std::vector> wght_list;
+    std::vector> indx_list;
+    for (int i = 0; i < std::ssize(domains); ++i) {
+      Domain* domain = domains.at(i);
+      T_item item = items.at(i);
+      std::vector t_list;
+      std::vector w_list;
+      std::vector i_list;
+      if (domain->items.contains(item)) {
+        int z = domain->get_cluster_assignment(item);
+        t_list = {z};
+        w_list = {0};
+        i_list = {0};
+      } else {
+        auto tables_weights = domain->tables_weights();
+        double Z = log(domain->crp.alpha + domain->crp.N);
+        int idx = 0;
+        for (const auto& [t, w] : tables_weights) {
+          t_list.push_back(t);
+          w_list.push_back(log(w) - Z);
+          i_list.push_back(idx++);
+        }
+        assert(idx == std::ssize(t_list));
+      }
+      tabl_list.push_back(t_list);
+      wght_list.push_back(w_list);
+      indx_list.push_back(i_list);
+    }
+    std::vector logps;
+    for (const auto& indexes : product(indx_list)) {
+      assert(indexes.size() == domains.size());
+      std::vector z;
+      z.reserve(domains.size());
+      double logp_w = 0;
+      for (int i = 0; i < std::ssize(domains); ++i) {
+        T_item zi = tabl_list.at(i).at(indexes[i]);
+        double wi = wght_list.at(i).at(indexes[i]);
+        z.push_back(zi);
+        logp_w += wi;
+      }
+      DistributionType aux(prng);
+      DistributionType* cluster = clusters.contains(z) ? clusters.at(z) : &aux;
+      double logp_z = cluster->logp(value);
+      double logp_zw = logp_z + logp_w;
+      logps.push_back(logp_zw);
+    }
+    return logsumexp(logps);
+  }
+
+  // Sum of logp_score() over all clusters.
+  double logp_score() const {
+    double logp = 0.0;
+    for (const auto& [_, cluster] : clusters) {
+      logp += cluster->logp_score();
+    }
+    return logp;
+  }
+
+  // Moves every observation involving `item` out of its current cluster and
+  // into the cluster implied by `item` being at `table`; clusters left empty
+  // are deleted. `table` must differ from the current table (asserted).
+  void set_cluster_assignment_gibbs(const Domain& domain, const T_item& item,
+                                    int table) {
+    int table_current = domain.get_cluster_assignment(item);
+    assert(table != table_current);
+    for (const T_items& items : data_r.at(domain.name).at(item)) {
+      ValueType x = data.at(items);
+      // Remove from current cluster.
+      T_items z_prev = get_cluster_assignment(items);
+      auto cluster_prev = clusters.at(z_prev);
+      cluster_prev->unincorporate(x);
+      if (cluster_prev->N == 0) {
+        delete clusters.at(z_prev);
+        clusters.erase(z_prev);
+      }
+      // Move to desired cluster.
+      T_items z_new = get_cluster_assignment_gibbs(items, domain, item, table);
+      if (!clusters.contains(z_new)) {
+        // Move to fresh cluster.
+        clusters[z_new] = new DistributionType(prng);
+        clusters.at(z_new)->incorporate(x);
+      } else {
+        // Move to existing cluster.
+        assert((clusters.at(z_new)->N > 0));
+        clusters.at(z_new)->incorporate(x);
+      }
+    }
+    // Caller should invoke domain.set_cluster_assignment_gibbs(item, table).
+  }
+
+  // True if `item` of `domain` occurs in at least one incorporated tuple.
+  bool has_observation(const Domain& domain, const T_item& item) {
+    return data_r.at(domain.name).contains(item);
+  }
+
+  // Disable copying.
+  Relation& operator=(const Relation&) = delete;
+  Relation(const Relation&) = delete;
+};
diff --git a/cxx/relation_test.cc b/cxx/relation_test.cc
new file mode 100644
index 0000000..dbd8512
--- /dev/null
+++ b/cxx/relation_test.cc
@@ -0,0 +1,58 @@
+// Apache License, Version 2.0, refer to LICENSE.txt
+
+#define BOOST_TEST_MODULE test Relation
+
+#include "relation.hh"
+
+#include
+#include
+
+#include "distributions/beta_bernoulli.hh"
+#include "distributions/bigram.hh"
+#include "domain.hh"
+
+namespace tt = boost::test_tools;
+
+BOOST_AUTO_TEST_CASE(test_relation) {
+  std::mt19937 prng;
+  Domain D1("D1", &prng);
+  Domain D2("D2", &prng);
+  Domain D3("D3", &prng);
+  D1.incorporate(0);
+  D2.incorporate(1);
+  D3.incorporate(3);
+  Relation R1("R1", "bernoulli", {&D1, &D2, &D3}, &prng);
+  R1.incorporate({0, 1, 3}, 1);
+  R1.incorporate({1, 1, 3}, 1);
+  R1.incorporate({3, 1, 3}, 1);
+  R1.incorporate({4, 1, 3}, 1);
+  R1.incorporate({5, 1, 3}, 1);
+  R1.incorporate({0, 1, 4}, 0);
+  R1.incorporate({0, 1, 6}, 1);
+  auto z1 = R1.get_cluster_assignment({0, 1, 3});
+  BOOST_TEST(z1.size() == 3);
+  BOOST_TEST(z1[0] == 0);
+  BOOST_TEST(z1[1] == 0);
+  BOOST_TEST(z1[2] == 0);
+
+  auto z2 = R1.get_cluster_assignment_gibbs({0, 1, 3}, D2, 1, 191);
+  BOOST_TEST(z2.size() == 3);
+  BOOST_TEST(z2[0] == 0);
+  BOOST_TEST(z2[1] == 191);
+  BOOST_TEST(z2[2] == 0);
+
+  double lpg = R1.logp_gibbs_approx(D1, 0, 1);
+  lpg = R1.logp_gibbs_approx(D1, 0, 0);
+  lpg = R1.logp_gibbs_approx(D1, 0, 10);
+  R1.set_cluster_assignment_gibbs(D1, 0, 1);
+
+  Relation R2("R2", "bigram", {&D2, &D3}, &prng);
+  R2.incorporate({1, 3}, "cat");
+  R2.incorporate({1, 2}, "dog");
+  R2.incorporate({1, 4}, "catt");
+  R2.incorporate({2, 6}, "fish");
+
+  lpg = R2.logp_gibbs_approx(D2, 2, 0);
+  R2.set_cluster_assignment_gibbs(D3, 3, 1);
+  D1.set_cluster_assignment_gibbs(0, 1);
+}
diff --git a/cxx/tests/test_misc.cc b/cxx/tests/test_misc.cc
index 2c457b2..0e664c0 100644
---
a/cxx/tests/test_misc.cc +++ b/cxx/tests/test_misc.cc @@ -14,6 +14,7 @@ #include #include "hirm.hh" +#include "relation.hh" #include "util_hash.hh" #include "util_io.hh" #include "util_math.hh" @@ -22,105 +23,6 @@ int main(int argc, char** argv) { srand(1); std::mt19937 prng(1); - Bigram bg(&prng); - bg.incorporate("foo"); - bg.incorporate("foo"); - bg.incorporate("_Hello!~"); - bg.unincorporate("foo"); - printf("%f\n", exp(bg.logp("bar"))); - printf("%f\n", exp(bg.logp_score())); - for (int i = 0; i < 10; i++) { - printf("%s\n", bg.sample().c_str()); - } - printf("\n"); - - printf("=== DOMAIN === \n"); - Domain d("foo", &prng); - std::string relation1 = "ali"; - std::string relation2 = "mubarak"; - T_item salman = 1; - T_item mansour = 2; - d.incorporate(salman); - for (auto& item : d.items) { - printf("item %d: ", item); - } - d.set_cluster_assignment_gibbs(salman, 12); - d.incorporate(salman); - d.incorporate(mansour, 5); - for (auto& item : d.items) { - printf("item %d: ", item); - } - // d.unincorporate(salman); - for (auto& item : d.items) { - printf("item %d: ", item); - } - // d.unincorporate(relation2, salman); - // assert (d.items.size() == 0); - // d.items[01].insert("foo"); - - std::unordered_map> m; - m[1].insert(10); - m[1] = std::unordered_set(); - for (auto& ir : m) { - printf("%d\n", ir.first); - for (auto& x : ir.second) { - printf("%d\n", x); - } - } - - printf("== RELATION == \n"); - Domain D1("D1", &prng); - Domain D2("D2", &prng); - Domain D3("D3", &prng); - D1.incorporate(0); - D2.incorporate(1); - D3.incorporate(3); - Relation R1("R1", "bernoulli", {&D1, &D2, &D3}, &prng); - printf("arity %ld\n", R1.domains.size()); - R1.incorporate({0, 1, 3}, 1); - R1.incorporate({1, 1, 3}, 1); - R1.incorporate({3, 1, 3}, 1); - R1.incorporate({4, 1, 3}, 1); - R1.incorporate({5, 1, 3}, 1); - R1.incorporate({0, 1, 4}, 0); - R1.incorporate({0, 1, 6}, 1); - auto z1 = R1.get_cluster_assignment({0, 1, 3}); - for (int x : z1) { - printf("%d,", x); - } - auto z2 = 
R1.get_cluster_assignment_gibbs({0, 1, 3}, D2, 1, 191); - printf("\n"); - for (int x : z2) { - printf("%d,", x); - } - printf("\n"); - - Relation R2("R1", "bigram", {&D2, &D3}, &prng); - printf("arity %ld\n", R1.domains.size()); - R2.incorporate({1, 3}, "cat"); - R2.incorporate({1, 2}, "dog"); - R2.incorporate({1, 4}, "catt"); - R2.incorporate({2, 6}, "fish"); - - double lpg = R1.logp_gibbs_approx(D1, 0, 1); - printf("logp gibbs %f\n", lpg); - lpg = R1.logp_gibbs_approx(D1, 0, 0); - printf("logp gibbs %f\n", lpg); - lpg = R1.logp_gibbs_approx(D1, 0, 10); - printf("logp gibbs %f\n", lpg); - lpg = R2.logp_gibbs_approx(D2, 2, 0); - printf("logp gibbs %f\n", lpg); - - printf("calling set_cluster_assignment_gibbs\n"); - R1.set_cluster_assignment_gibbs(D1, 0, 1); - R2.set_cluster_assignment_gibbs(D3, 3, 1); - printf("new cluster %d\n", D1.get_cluster_assignment(0)); - D1.set_cluster_assignment_gibbs(0, 1); - - printf("%lu\n", R1.data.size()); - // R1.unincorporate({0, 1, 3}); - printf("%lu\n", R1.data.size()); - printf("== HASHING UTIL == \n"); std::unordered_map, int, VectorIntHash> map_int; map_int[{1, 2}] = 7;