diff --git a/cxx/distributions/crp.cc b/cxx/distributions/crp.cc index 9912953..200c423 100644 --- a/cxx/distributions/crp.cc +++ b/cxx/distributions/crp.cc @@ -71,13 +71,13 @@ double CRP::logp_score() const { int CRP::max_table() const { if (N == 0) { - return 0; + return -1; } return tables.rbegin()->first; } -std::unordered_map CRP::tables_weights() const { - std::unordered_map dist; +std::map CRP::tables_weights() const { + std::map dist; if (N == 0) { dist[0] = 1; return dist; @@ -89,14 +89,15 @@ std::unordered_map CRP::tables_weights() const { return dist; } -std::unordered_map CRP::tables_weights_gibbs(int table) const { +std::map CRP::tables_weights_gibbs(int table) const { assert(N > 0); assert(tables.contains(table)); auto dist = tables_weights(); --dist.at(table); if (dist.at(table) == 0) { dist.at(table) = alpha; - dist.erase(max_table()); + int t_max = dist.rbegin()->first; + dist.erase(t_max); } return dist; } diff --git a/cxx/distributions/crp.hh b/cxx/distributions/crp.hh index 4fd283d..beed619 100644 --- a/cxx/distributions/crp.hh +++ b/cxx/distributions/crp.hh @@ -32,11 +32,12 @@ class CRP { double logp_score() const; + // Returns the highest table entry in tables, or -1 if tables is empty. int max_table() const; - std::unordered_map tables_weights() const; + std::map tables_weights() const; - std::unordered_map tables_weights_gibbs(int table) const; + std::map tables_weights_gibbs(int table) const; void transition_alpha(std::mt19937* prng); }; diff --git a/cxx/distributions/crp_test.cc b/cxx/distributions/crp_test.cc index f6db7f5..afa37c7 100644 --- a/cxx/distributions/crp_test.cc +++ b/cxx/distributions/crp_test.cc @@ -16,7 +16,7 @@ namespace bm = boost::math; BOOST_AUTO_TEST_CASE(test_simple) { CRP crp; - BOOST_TEST(crp.max_table() == 0); + BOOST_TEST(crp.max_table() == -1); T_item cat = 1; T_item dog = 2; @@ -64,7 +64,7 @@ BOOST_AUTO_TEST_CASE(test_simple) { BOOST_TEST(crp.tables.at(3).contains(dog)); // Table weights are as expected. - std::unordered_map tw = crp.tables_weights(); + std::map tw = crp.tables_weights(); BOOST_TEST(tw.size() == 4); // Three populated tables and one new one. BOOST_TEST(tw[0] == crp.tables.at(0).size()); BOOST_TEST(tw[1] == crp.tables.at(1).size()); @@ -72,7 +72,7 @@ BOOST_AUTO_TEST_CASE(test_simple) { BOOST_TEST(tw[4] == crp.alpha); // Table weights gibbs is as expected. - std::unordered_map twg = crp.tables_weights_gibbs(1); + std::map twg = crp.tables_weights_gibbs(1); BOOST_TEST(tw[0] == twg[0]); BOOST_TEST(tw[1] == twg[1] + 1.); BOOST_TEST(tw[3] == twg[3]); diff --git a/cxx/domain.hh b/cxx/domain.hh index c45918d..395a562 100644 --- a/cxx/domain.hh +++ b/cxx/domain.hh @@ -3,6 +3,7 @@ #pragma once #include +#include #include #include @@ -44,10 +45,10 @@ class Domain { crp.unincorporate(item); crp.incorporate(item, table); } - std::unordered_map tables_weights() const { + std::map tables_weights() const { return crp.tables_weights(); } - std::unordered_map tables_weights_gibbs( + std::map tables_weights_gibbs( const T_item& item) const { int table = get_cluster_assignment(item); return crp.tables_weights_gibbs(table); diff --git a/cxx/gendb.cc b/cxx/gendb.cc index 62e3971..cb99298 100644 --- a/cxx/gendb.cc +++ b/cxx/gendb.cc @@ -582,7 +582,7 @@ void GenDB::transition_reference(std::mt19937* prng, std::get(schema.classes.at(class_name).vars.at(ref_field).spec) .class_name; int init_refval = reference_values.at({class_name, ref_field, class_item}); - std::unordered_map crp_dist = + std::map crp_dist = domain_crps[ref_class].tables_weights_gibbs(init_refval); // For each relation, get the indices (in the items vector) of the reference diff --git a/cxx/gendb_test.cc b/cxx/gendb_test.cc index 1aad1ab..f070348 100644 --- a/cxx/gendb_test.cc +++ b/cxx/gendb_test.cc @@ -920,7 +920,7 @@ BOOST_AUTO_TEST_CASE(test_unincorporate_reincorporate_new) { // Now find the singleton value. int refval = -1; - std::unordered_map crp_dist = + std::map crp_dist = gendb.domain_crps[ref_class].tables_weights_gibbs(non_singleton_refval); for (auto [t, w] : crp_dist) { if (!gendb.domain_crps[ref_class].tables.contains(t)) {