Skip to content

Commit

Permalink
Upgrade unordered_maps to maps in crp so we can do O(1) max key
Browse files Browse the repository at this point in the history
  • Loading branch information
ThomasColthurst committed Oct 2, 2024
1 parent e090b92 commit 2303a24
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 14 deletions.
11 changes: 6 additions & 5 deletions cxx/distributions/crp.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,13 @@ double CRP::logp_score() const {

int CRP::max_table() const {
if (N == 0) {
return 0;
return -1;
}
return tables.rbegin()->first;
}

std::unordered_map<int, double> CRP::tables_weights() const {
std::unordered_map<int, double> dist;
std::map<int, double> CRP::tables_weights() const {
std::map<int, double> dist;
if (N == 0) {
dist[0] = 1;
return dist;
Expand All @@ -89,14 +89,15 @@ std::unordered_map<int, double> CRP::tables_weights() const {
return dist;
}

std::unordered_map<int, double> CRP::tables_weights_gibbs(int table) const {
std::map<int, double> CRP::tables_weights_gibbs(int table) const {
assert(N > 0);
assert(tables.contains(table));
auto dist = tables_weights();
--dist.at(table);
if (dist.at(table) == 0) {
dist.at(table) = alpha;
dist.erase(max_table());
int t_max = dist.rbegin()->first;
dist.erase(t_max);
}
return dist;
}
Expand Down
5 changes: 3 additions & 2 deletions cxx/distributions/crp.hh
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,12 @@ class CRP {

double logp_score() const;

// Returns the highest table entry in tables, or -1 if tables is empty.
int max_table() const;

std::unordered_map<int, double> tables_weights() const;
std::map<int, double> tables_weights() const;

std::unordered_map<int, double> tables_weights_gibbs(int table) const;
std::map<int, double> tables_weights_gibbs(int table) const;

void transition_alpha(std::mt19937* prng);
};
6 changes: 3 additions & 3 deletions cxx/distributions/crp_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ namespace bm = boost::math;
BOOST_AUTO_TEST_CASE(test_simple) {
CRP crp;

BOOST_TEST(crp.max_table() == 0);
BOOST_TEST(crp.max_table() == -1);

T_item cat = 1;
T_item dog = 2;
Expand Down Expand Up @@ -64,15 +64,15 @@ BOOST_AUTO_TEST_CASE(test_simple) {
BOOST_TEST(crp.tables.at(3).contains(dog));

// Table weights are as expected.
std::unordered_map<int, double> tw = crp.tables_weights();
std::map<int, double> tw = crp.tables_weights();
BOOST_TEST(tw.size() == 4); // Three populated tables and one new one.
BOOST_TEST(tw[0] == crp.tables.at(0).size());
BOOST_TEST(tw[1] == crp.tables.at(1).size());
BOOST_TEST(tw[3] == crp.tables.at(3).size());
BOOST_TEST(tw[4] == crp.alpha);

// Table weights gibbs is as expected.
std::unordered_map<int, double> twg = crp.tables_weights_gibbs(1);
std::map<int, double> twg = crp.tables_weights_gibbs(1);
BOOST_TEST(tw[0] == twg[0]);
BOOST_TEST(tw[1] == twg[1] + 1.);
BOOST_TEST(tw[3] == twg[3]);
Expand Down
5 changes: 3 additions & 2 deletions cxx/domain.hh
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#pragma once
#include <cassert>
#include <map>
#include <string>
#include <unordered_set>

Expand Down Expand Up @@ -44,10 +45,10 @@ class Domain {
crp.unincorporate(item);
crp.incorporate(item, table);
}
std::unordered_map<int, double> tables_weights() const {
std::map<int, double> tables_weights() const {
return crp.tables_weights();
}
std::unordered_map<int, double> tables_weights_gibbs(
std::map<int, double> tables_weights_gibbs(
const T_item& item) const {
int table = get_cluster_assignment(item);
return crp.tables_weights_gibbs(table);
Expand Down
2 changes: 1 addition & 1 deletion cxx/gendb.cc
Original file line number Diff line number Diff line change
Expand Up @@ -582,7 +582,7 @@ void GenDB::transition_reference(std::mt19937* prng,
std::get<ClassVar>(schema.classes.at(class_name).vars.at(ref_field).spec)
.class_name;
int init_refval = reference_values.at({class_name, ref_field, class_item});
std::unordered_map<int, double> crp_dist =
std::map<int, double> crp_dist =
domain_crps[ref_class].tables_weights_gibbs(init_refval);

// For each relation, get the indices (in the items vector) of the reference
Expand Down
2 changes: 1 addition & 1 deletion cxx/gendb_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -920,7 +920,7 @@ BOOST_AUTO_TEST_CASE(test_unincorporate_reincorporate_new) {

// Now find the singleton value.
int refval = -1;
std::unordered_map<int, double> crp_dist =
std::map<int, double> crp_dist =
gendb.domain_crps[ref_class].tables_weights_gibbs(non_singleton_refval);
for (auto [t, w] : crp_dist) {
if (!gendb.domain_crps[ref_class].tables.contains(t)) {
Expand Down

0 comments on commit 2303a24

Please sign in to comment.