-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Initial implementation of transition_hyperparameters #25
Changes from 10 commits
0b36ebc
a52e846
25b75fc
83995eb
9c975b8
51663db
0d259cf
5df8a60
f614f12
64d4735
bc3bd4f
638ef45
8a348c1
cdd11e7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,57 +3,81 @@ | |
|
||
#pragma once | ||
#include <algorithm> | ||
#include <cassert> | ||
#include <random> | ||
|
||
#include "base.hh" | ||
#include "util_math.hh" | ||
|
||
#define ALPHA_GRID {1e-4, 1e-3, 1e-2, 1e-1, 1.0, 10.0, 100.0, 1000.0, 10000.0} | ||
|
||
// Categorical distribution over k categories with a symmetric Dirichlet
// prior whose single concentration parameter is `alpha`.
//
// Observations are category indices passed as doubles; only the per-category
// counts and the running total are stored.
class DirichletCategorical : public Distribution<double> {
 public:
  double alpha = 1;         // hyperparameter (applies to all categories)
  std::vector<int> counts;  // counts of observed categories
  int n;                    // Total number of observations.
  std::mt19937* prng;

  // DirichletCategorical does not take ownership of prng.
  // k is the number of categories; all counts start at zero.
  DirichletCategorical(std::mt19937* prng, int k)
      : counts(k, 0), n(0), prng(prng) {}

  // Records one observation of category x (x is truncated to an index).
  void incorporate(const double& x) {
    assert(x >= 0 && x < counts.size());
    counts[size_t(x)] += 1;
    ++n;
  }

  // Removes one previously incorporated observation of category x.
  void unincorporate(const double& x) {
    const size_t y = x;
    assert(y < counts.size());
    counts[y] -= 1;
    --n;
    assert(0 <= counts[y]);
    assert(0 <= n);
  }

  // Returns log((alpha + counts[x]) / (n + alpha * k)), where k is the number
  // of categories.
  double logp(const double& x) const {
    assert(x >= 0 && x < counts.size());
    const double numer = log(alpha + counts[size_t(x)]);
    const double denom = log(n + alpha * counts.size());
    return numer - denom;
  }

  // Returns
  //   lgamma(alpha*k) - lgamma(alpha*k + n)
  //     + sum_i lgamma(counts[i] + alpha) - k * lgamma(alpha)
  // for the currently incorporated counts.
  double logp_score() const {
    const size_t k = counts.size();
    const double a = alpha * k;
    // The init value must be a double: std::transform_reduce accumulates in
    // the type of its init argument, so an int literal would truncate every
    // partial sum of lgamma terms.
    const double lg = std::transform_reduce(
        counts.cbegin(), counts.cend(), 0.0, std::plus{},
        [&](int count) -> double { return lgamma(count + alpha); });
    return lgamma(a) - lgamma(a + n) + lg - k * lgamma(alpha);
  }

  // Draws a category index using weights counts[i] + alpha and returns it as
  // a double. NOTE(review): presumably `choice` samples an index with
  // probability proportional to its weight -- confirm in util_math.hh.
  double sample() {
    std::vector<double> weights(counts.size());
    std::transform(counts.begin(), counts.end(), weights.begin(),
                   [&](int count) -> double { return count + alpha; });
    int idx = choice(weights, prng);
    return double(idx);
  }

  // Resamples alpha from the candidates in ALPHA_GRID, scoring each candidate
  // with logp_score(). NOTE(review): presumably `sample_from_logps` picks an
  // index with probability proportional to exp(logp) -- confirm in
  // util_math.hh.
  void transition_hyperparameters() {
    const double previous_alpha = alpha;
    std::vector<double> logps;
    std::vector<double> alphas;
    // C++ doesn't yet allow range for-loops over existing variables. Sigh.
    for (double alphat : ALPHA_GRID) {
      alpha = alphat;  // logp_score() reads the member, so set it first.
      const double lp = logp_score();
      // Review note: no NaN scores have been observed for this class with the
      // current unit test; they were previously seen in BetaBernoulli's
      // transition_hyperparameters, so the guard is kept defensively.
      if (!std::isnan(lp)) {
        logps.push_back(lp);
        alphas.push_back(alphat);
      }
    }
    if (alphas.empty()) {
      // Every candidate scored NaN: keep the previous value rather than
      // indexing into an empty vector.
      alpha = previous_alpha;
      return;
    }
    const int i = sample_from_logps(logps, prng);
    alpha = alphas[i];
  }
};
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
// Apache License, Version 2.0, refer to LICENSE.txt | ||
|
||
#define BOOST_TEST_MODULE test DirichletCategorical | ||
|
||
#include <boost/test/included/unit_test.hpp> | ||
#include "distributions/dirichlet_categorical.hh" | ||
namespace tt = boost::test_tools; | ||
|
||
// Incorporates each of ten categories once, retracts the even-numbered ones,
// and checks logp / logp_score against fixed reference values.
BOOST_AUTO_TEST_CASE(test_simple)
{
  constexpr int kNumCategories = 10;
  std::mt19937 prng;
  DirichletCategorical dc(&prng, kNumCategories);

  // Observe every category exactly once...
  for (int category = 0; category != kNumCategories; ++category) {
    dc.incorporate(category);
  }
  // ...then remove the observations of categories 0, 2, 4, 6, 8.
  for (int category = 0; category < kNumCategories; category += 2) {
    dc.unincorporate(category);
  }

  BOOST_TEST(dc.logp(1) == -2.0149030205422647, tt::tolerance(1e-6));
  BOOST_TEST(dc.logp_score() == -12.389393702657209, tt::tolerance(1e-6));
}
|
||
// Runs transition_hyperparameters() once with no data, then again after 100
// observations spread evenly over ten categories, and checks the resulting
// alpha.
BOOST_AUTO_TEST_CASE(test_transition_hyperparameters)
{
  constexpr int kNumCategories = 10;
  constexpr int kRounds = 10;
  std::mt19937 prng;
  DirichletCategorical dc(&prng, kNumCategories);

  // First transition happens before any data has been incorporated.
  dc.transition_hyperparameters();

  // Cycle through categories 0..9 ten times (100 observations in total).
  for (int round = 0; round < kRounds; ++round) {
    for (int category = 0; category < kNumCategories; ++category) {
      dc.incorporate(category);
    }
  }

  dc.transition_hyperparameters();
  BOOST_TEST(dc.alpha > 1.0);
}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Our current declaration / implementation of `lbeta` in `util_math.hh` and `util_math.cc` declares the arguments as `int`s. My compiler issued warnings that floating-point numbers were being coerced. I think now that we have hyperparameters that can be non-integer values, we'll want to change the argument types of `lbeta` (I think the implementation should stay the same).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed.