Skip to content

Commit

Permalink
Move the Chinese Restaurant Process into the distributions directory …
Browse files Browse the repository at this point in the history
…and add tests.
  • Loading branch information
emilyfertig committed Jun 8, 2024
1 parent 911393b commit b0917fd
Show file tree
Hide file tree
Showing 7 changed files with 311 additions and 173 deletions.
21 changes: 21 additions & 0 deletions cxx/distributions/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ cc_library(
":base",
":beta_bernoulli",
":bigram",
":crp",
":dirichlet_categorical",
":normal",
],
Expand Down Expand Up @@ -53,6 +54,17 @@ cc_library(
],
)

cc_library(
name = "crp",
srcs = ["crp.cc"],
hdrs = ["crp.hh"],
visibility = ["//:__subpackages__"],
deps = [
"//:headers",
"//:util_math",
],
)

cc_library(
name = "dirichlet_categorical",
srcs = ["dirichlet_categorical.cc"],
Expand Down Expand Up @@ -107,6 +119,15 @@ cc_test(
],
)

cc_test(
name = "crp_test",
srcs = ["crp_test.cc"],
deps = [
":crp",
"@boost//:test",
],
)

cc_test(
name = "dirichlet_categorical_test",
srcs = ["dirichlet_categorical_test.cc"],
Expand Down
114 changes: 114 additions & 0 deletions cxx/distributions/crp.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
// Copyright 2024
// See LICENSE.txt

#include "distributions/crp.hh"

#include <cassert>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "util_math.hh"

void CRP::incorporate(const T_item& item, int table) {
assert(!assignments.contains(item));
if (!tables.contains(table)) {
tables[table] = std::unordered_set<T_item>();
}
tables.at(table).insert(item);
assignments[item] = table;
++N;
}

void CRP::unincorporate(const T_item& item) {
assert(assignments.contains(item));
int table = assignments.at(item);
tables.at(table).erase(item);
if (tables.at(table).empty()) {
tables.erase(table);
}
assignments.erase(item);
--N;
}

int CRP::sample() {
auto crp_dist = tables_weights();
std::vector<int> items(crp_dist.size());
std::vector<double> weights(crp_dist.size());
int i = 0;
for (const auto& [table, weight] : crp_dist) {
items[i] = table;
weights[i] = weight;
++i;
}
int idx = choice(weights, prng);
return items[idx];
}

double CRP::logp(int table) const {
auto dist = tables_weights();
if (!dist.contains(table)) {
return -std::numeric_limits<double>::infinity();
}
double numer = dist[table];
double denom = N + alpha;
return log(numer) - log(denom);
}

double CRP::logp_score() const {
double term1 = tables.size() * log(alpha);
double term2 = 0;
for (const auto& [table, customers] : tables) {
term2 += lgamma(customers.size());
}
double term3 = lgamma(alpha);
double term4 = lgamma(N + alpha);
double out = term1 + term2 + term3 - term4;
return out;
}

std::unordered_map<int, double> CRP::tables_weights() const {
std::unordered_map<int, double> dist;
if (N == 0) {
dist[0] = 1;
return dist;
}
int t_max = 0;
for (const auto& [table, customers] : tables) {
dist[table] = customers.size();
t_max = std::max(table, t_max);
}
dist[t_max + 1] = alpha;
return dist;
}

std::unordered_map<int, double> CRP::tables_weights_gibbs(int table) const {
assert(N > 0);
assert(tables.contains(table));
auto dist = tables_weights();
--dist.at(table);
if (dist.at(table) == 0) {
dist.at(table) = alpha;
int t_max = 0;
for (const auto& [table, weight] : dist) {
t_max = std::max(table, t_max);
}
dist.erase(t_max);
}
return dist;
}

void CRP::transition_alpha() {
if (N == 0) {
return;
}
std::vector<double> grid = log_linspace(1. / N, N + 1, 20, true);
std::vector<double> logps;
for (const double& g : grid) {
this->alpha = g;
double logp_g = logp_score();
logps.push_back(logp_g);
}
int idx = log_choice(logps, prng);
this->alpha = grid[idx];
}
39 changes: 39 additions & 0 deletions cxx/distributions/crp.hh
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// Copyright 2024
// See LICENSE.txt

#pragma once
#include <random>
#include <unordered_map>
#include <unordered_set>


typedef int T_item;

// TODO(emilyaf): Make this a distribution subclass.
class CRP {
public:
double alpha = 1.; // concentration parameter
int N = 0; // number of customers
std::unordered_map<int, std::unordered_set<T_item>>
tables; // map from table id to set of customers
std::unordered_map<T_item, int> assignments; // map from customer to table id
std::mt19937* prng;

CRP(std::mt19937* prng) { this->prng = prng; }

void incorporate(const T_item& item, int table);

void unincorporate(const T_item& item);

int sample();

double logp(int table) const;

double logp_score() const;

std::unordered_map<int, double> tables_weights() const;

std::unordered_map<int, double> tables_weights_gibbs(int table) const;

void transition_alpha();
};
134 changes: 134 additions & 0 deletions cxx/distributions/crp_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
// Apache License, Version 2.0, refer to LICENSE.txt

#define BOOST_TEST_MODULE test CRP

#include "distributions/crp.hh"

#include <boost/test/included/unit_test.hpp>

#include "util_math.hh"
namespace bm = boost::math;
namespace tt = boost::test_tools;
namespace bm = boost::math;

BOOST_AUTO_TEST_CASE(test_simple) {
std::mt19937 prng;
CRP crp(&prng);

T_item cat = 1;
T_item dog = 2;
T_item fish = 3;
T_item bat = 4;
T_item hamster = 5;
T_item snake = 6;

crp.incorporate(bat, 3);
crp.incorporate(cat, 0);
crp.incorporate(dog, 0);
crp.incorporate(fish, 1);
crp.incorporate(hamster, 0);
crp.incorporate(snake, 1);
BOOST_TEST(crp.N == 6);

crp.unincorporate(cat);
BOOST_TEST(crp.N == 5);

crp.incorporate(cat, 3);
crp.unincorporate(dog);
BOOST_TEST(crp.N == 5);

crp.incorporate(dog, 3);

// Table assignments are as expected from `incorporate` calls.
BOOST_TEST(crp.assignments.at(bat) == 3);
BOOST_TEST(crp.assignments.at(cat) == 3);
BOOST_TEST(crp.assignments.at(dog) == 3);
BOOST_TEST(crp.assignments.at(fish) == 1);
BOOST_TEST(crp.assignments.at(hamster) == 0);
BOOST_TEST(crp.assignments.at(snake) == 1);

// Table contents are as expected.
BOOST_TEST(crp.tables.size() == 3);
BOOST_TEST(crp.tables.at(0).size() == 1);
BOOST_TEST(crp.tables.at(0).contains(hamster));
BOOST_TEST(crp.tables.at(1).size() == 2);
BOOST_TEST(crp.tables.at(1).contains(fish));
BOOST_TEST(crp.tables.at(1).contains(snake));
BOOST_TEST(crp.tables.at(3).size() == 3);
BOOST_TEST(crp.tables.at(3).contains(bat));
BOOST_TEST(crp.tables.at(3).contains(cat));
BOOST_TEST(crp.tables.at(3).contains(dog));

// Table weights are as expected.
std::unordered_map<int, double> tw = crp.tables_weights();
BOOST_TEST(tw.size() == 4); // Three populated tables and one new one.
BOOST_TEST(tw[0] == crp.tables.at(0).size());
BOOST_TEST(tw[1] == crp.tables.at(1).size());
BOOST_TEST(tw[3] == crp.tables.at(3).size());
BOOST_TEST(tw[4] == crp.alpha);

// Table weights gibbs is as expected.
std::unordered_map<int, double> twg = crp.tables_weights_gibbs(1);
BOOST_TEST(tw[0] == twg[0]);
BOOST_TEST(tw[1] == twg[1] + 1.);
BOOST_TEST(tw[3] == twg[3]);
BOOST_TEST(tw[4] == twg[4]);

// We expect that this is log(table_size) - log(N + alpha).
BOOST_TEST(crp.logp(1) == log(2. / (crp.alpha + crp.N)), tt::tolerance(1e-6));
BOOST_TEST(crp.logp(0) == log(1. / (crp.alpha + crp.N)), tt::tolerance(1e-6));
}

BOOST_AUTO_TEST_CASE(test_log_prob) {
std::mt19937 prng;
CRP crp(&prng);

T_item desk = 1;
T_item chair = 2;
T_item bureau = 3;
T_item lamp = 4;
T_item sofa = 5;

crp.incorporate(desk, 3);
double logp_score0 = crp.logp_score();
// Only one configuration for a single item, so p_score==1.
BOOST_TEST(logp_score0 == 0., tt::tolerance(1e-9));

double log_cond1 = log(crp.alpha / (1. + crp.alpha)); // New cluster.
crp.incorporate(chair, 2);
double logp_score1 = crp.logp_score();
// Successive log scores should equal the sum of the previous log score
// and the conditional log prob of the next observation incorporated.
BOOST_TEST(logp_score1 == logp_score0 + log_cond1, tt::tolerance(1e-9));

double log_cond2 = crp.logp(3);
crp.incorporate(bureau, 3);
double logp_score2 = crp.logp_score();
BOOST_TEST(logp_score2 == logp_score1 + log_cond2, tt::tolerance(1e-9));

double log_cond3 = crp.logp(3);
crp.incorporate(lamp, 3);
double logp_score3 = crp.logp_score();
BOOST_TEST(logp_score3 == logp_score2 + log_cond3, tt::tolerance(1e-9));

double log_cond4 = crp.logp(2);
crp.incorporate(sofa, 2);
double logp_score4 = crp.logp_score();
BOOST_TEST(logp_score4 == logp_score3 + log_cond4, tt::tolerance(1e-9));
}

BOOST_AUTO_TEST_CASE(test_transition_hyperparameters) {
std::mt19937 prng;
CRP crp(&prng);

crp.transition_alpha();
double old_alpha = crp.alpha;

for (int i = 0; i < 100; ++i) {
crp.incorporate(i, 0);
}

crp.transition_alpha();
// Expect that since all items are at one table, the new alpha is low.
BOOST_TEST(crp.alpha < old_alpha);
}
Loading

0 comments on commit b0917fd

Please sign in to comment.