Skip to content

Commit

Permalink
Merge CRP tests from hirm_test into crp_test.
Browse files Browse the repository at this point in the history
  • Loading branch information
emilyfertig committed Jun 8, 2024
1 parent b0917fd commit 21c2600
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 106 deletions.
11 changes: 0 additions & 11 deletions cxx/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -57,17 +57,6 @@ cc_library(
deps = [":headers"],
)

# Test target built from hirm_test.cc. Besides the package-local ":headers"
# target it links "//distributions" and the Boost algorithm and test libraries.
cc_test(
name = "hirm_test",
srcs = ["hirm_test.cc"],
deps = [
":headers",
"//distributions",
"@boost//:algorithm",
"@boost//:test",
],
)

cc_test(
name = "util_math_test",
srcs = ["util_math_test.cc"],
Expand Down
1 change: 1 addition & 0 deletions cxx/distributions/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ cc_test(
srcs = ["crp_test.cc"],
deps = [
":crp",
"@boost//:algorithm",
"@boost//:test",
],
)
Expand Down
47 changes: 47 additions & 0 deletions cxx/distributions/crp_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

#include "distributions/crp.hh"

#include <boost/range/algorithm.hpp>
#include <boost/range/numeric.hpp>
#include <boost/test/included/unit_test.hpp>

#include "util_math.hh"
Expand Down Expand Up @@ -77,6 +79,10 @@ BOOST_AUTO_TEST_CASE(test_simple) {
// We expect that this is log(table_size) - log(N + alpha).
BOOST_TEST(crp.logp(1) == log(2. / (crp.alpha + crp.N)), tt::tolerance(1e-6));
BOOST_TEST(crp.logp(0) == log(1. / (crp.alpha + crp.N)), tt::tolerance(1e-6));
BOOST_TEST(crp.logp(3) == log(3. / (crp.alpha + crp.N)), tt::tolerance(1e-6));

// This should be log(alpha ^ 3 * 0! * 2! * 1! * 0! / (6!)) = -log(360)
BOOST_CHECK_CLOSE(crp.logp_score(), -log(360.), 1e-6);
}

BOOST_AUTO_TEST_CASE(test_log_prob) {
Expand Down Expand Up @@ -132,3 +138,44 @@ BOOST_AUTO_TEST_CASE(test_transition_hyperparameters) {
// Expect that since all items are at one table, the new alpha is low.
BOOST_TEST(crp.alpha < old_alpha);
}

// Checks that the empirical frequencies of CRP::sample() match the seating
// probabilities implied by the current table occupancy.
BOOST_AUTO_TEST_CASE(test_crp_sample) {
  std::mt19937 prng;
  CRP crp(&prng);

  // Seat items 0-9 at table 0 and then items 10-14 at table 1, in increasing
  // item order.
  for (int item = 0; item < 15; ++item) {
    crp.incorporate(item, item < 10 ? 0 : 1);
  }

  // With 10 customers at the first table and 5 at the second, a new customer
  // is expected to pick the first table with probability 10 / 16, the second
  // with probability 5 / 16, and a fresh table with probability 1 / 16.
  // Tally where repeated draws land so those frequencies can be compared.
  const int num_draws = 3000;
  std::map<int, int> table_count;
  for (int draw = 0; draw != num_draws; ++draw) {
    ++table_count[crp.sample()];
  }

  // Each draw either hits a given table or not, so the hit count is Binomial
  // and the observed frequency has standard deviation sqrt(p * (1 - p) / n).
  // Require the observed frequency to fall within one such standard deviation
  // of the expected probability p.
  // NOTE(review): a one-sigma band is tight for arbitrary seeds; the test
  // presumably stays stable only because `prng` is default-seeded and hence
  // deterministic -- confirm before changing the seed or the sampler.
  const auto check_within_one_sigma = [&](int table, double expected) {
    BOOST_TEST_CONTEXT("table " << table) {
      const double observed =
          table_count[table] / static_cast<double>(num_draws);
      const double sigma = sqrt(expected * (1. - expected) / num_draws);
      BOOST_TEST(observed <= expected + sigma);
      BOOST_TEST(observed >= expected - sigma);
    }
  };

  check_within_one_sigma(0, 5 / 8.);   // 10 / 16 for the first table.
  check_within_one_sigma(1, 5 / 16.);  // 5 / 16 for the second table.
  check_within_one_sigma(2, 1 / 16.);  // 1 / 16 for a brand new table.
}
95 changes: 0 additions & 95 deletions cxx/hirm_test.cc

This file was deleted.

0 comments on commit 21c2600

Please sign in to comment.