forked from probsys/hierarchical-irm
-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
cabb0db
commit 0180b5b
Showing
3 changed files
with
238 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
// Copyright 2024 | ||
// See LICENSE.txt | ||
|
||
#include "zero_mean_normal.hh" | ||
|
||
#include <cmath> | ||
#include <numbers> | ||
|
||
// Return log density of location-scaled T distribution with zero mean. | ||
double log_t_distribution(const double& x, | ||
const double& v, | ||
const double& variance) { | ||
// https://en.wikipedia.org/wiki/Student%27s_t-distribution#Density_and_first_two_moments | ||
double v_shift = (v + 1.0) / 2.0; | ||
return lgamma(v_shift) | ||
- lgamma(v / 2.0) | ||
- 0.5 * log(std::numbers::pi * v * variance) | ||
- v_shift * log(1.0 + x * x / variance / v); | ||
} | ||
|
||
void ZeroMeanNormal::incorporate(const double& x) { | ||
++N; | ||
var += (x * x - var) / N; | ||
} | ||
|
||
void ZeroMeanNormal::unincorporate(const double& x) { | ||
int old_N = N; | ||
--N; | ||
if (N == 0) { | ||
var = 0.0; | ||
return; | ||
} | ||
var = (var * old_N - x * x) / N; | ||
} | ||
|
||
|
||
double ZeroMeanNormal::logp(const double& x) const { | ||
// Equation (119) of https://www.cs.ubc.ca/~murphyk/Papers/bayesGauss.pdf | ||
double alpha_n = alpha + N / 2.0; | ||
double beta_n = beta + var * N; | ||
double t_variance = beta_n / alpha_n; | ||
return log_t_distribution(x, 2.0 * alpha_n, t_variance); | ||
} | ||
|
||
double ZeroMeanNormal::logp_score() const { | ||
// Marginal likelihood from Page 10 of | ||
// https://j-zin.github.io/files/Conjugate_Bayesian_analysis_of_common_distributions.pdf | ||
double alpha_n = alpha + N / 2.0; | ||
return alpha * log(beta) | ||
- lgamma(alpha) | ||
- (N / 2.0) * log(2.0 * std::numbers::pi) | ||
+ lgamma(alpha_n) | ||
- alpha_n * log(2.0 * beta + 0.5 * var * N); | ||
} | ||
|
||
double ZeroMeanNormal::sample() { | ||
double alpha_n = alpha + N / 2.0; | ||
double beta_n = beta + var * N; | ||
double t_variance = beta_n / alpha_n; | ||
std::student_t_distribution<double> d(2.0 * alpha_n); | ||
return d(*prng) * sqrt(t_variance); | ||
} | ||
|
||
#define ALPHA_GRID \ | ||
{ 1e-4, 1e-3, 1e-2, 1e-1, 1.0, 10.0, 100.0, 1000.0, 10000.0 } | ||
#define BETA_GRID \ | ||
{ 1e-4, 1e-3, 1e-2, 1e-1, 1.0, 10.0, 100.0, 1000.0, 10000.0 } | ||
|
||
void ZeroMeanNormal::transition_hyperparameters() { | ||
std::vector<double> logps; | ||
std::vector<std::pair<double, double>> hypers; | ||
for (double a : ALPHA_GRID) { | ||
for (double b : BETA_GRID) { | ||
alpha = a; | ||
beta = b; | ||
double lp = logp_score(); | ||
if (!std::isnan(lp)) { | ||
logps.push_back(logp_score()); | ||
hypers.push_back(std::make_pair(a, b)); | ||
} | ||
} | ||
} | ||
|
||
int i = sample_from_logps(logps, prng); | ||
alpha = std::get<0>(hypers[i]); | ||
beta = std::get<1>(hypers[i]); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
// Copyright 2024 | ||
// See LICENSE.txt | ||
|
||
#pragma once | ||
#include <random> | ||
#include <tuple> | ||
#include <variant> | ||
|
||
#include "base.hh" | ||
#include "util_math.hh" | ||
|
||
// A normal distribution that is known to have zero mean. | ||
// This class is not intended to be directly used as a data model, but | ||
// rather to support the Gaussian emissions model. | ||
class ZeroMeanNormal : public Distribution<double> { | ||
public: | ||
// We use an Inverse gamma conjugate prior, so our hyperparameters are | ||
double alpha = 1.0; | ||
double beta = 1.0; | ||
|
||
// Sufficient statistics of observed data. | ||
double var = 0.0; | ||
|
||
std::mt19937* prng; | ||
|
||
// ZeroMeanNormal does not take ownership of prng. | ||
ZeroMeanNormal(std::mt19937* prng) { this->prng = prng; } | ||
|
||
void incorporate(const double& x); | ||
|
||
void unincorporate(const double& x); | ||
|
||
void posterior_hypers(double* mprime, double* sprime) const; | ||
|
||
double logp(const double& x) const; | ||
|
||
double logp_score() const; | ||
|
||
double sample(); | ||
|
||
void transition_hyperparameters(); | ||
|
||
// Disable copying. | ||
ZeroMeanNormal& operator=(const ZeroMeanNormal&) = delete; | ||
ZeroMeanNormal(const ZeroMeanNormal&) = delete; | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
// Apache License, Version 2.0, refer to LICENSE.txt | ||
|
||
#define BOOST_TEST_MODULE test Normal | ||
|
||
#include "distributions/zero_mean_normal.hh" | ||
|
||
#include <boost/test/included/unit_test.hpp> | ||
|
||
#include "util_math.hh" | ||
namespace bm = boost::math; | ||
namespace tt = boost::test_tools; | ||
|
||
BOOST_AUTO_TEST_CASE(simple) { | ||
std::mt19937 prng; | ||
ZeroMeanNormal nd(&prng); | ||
|
||
nd.incorporate(5.0); | ||
nd.incorporate(-2.0); | ||
BOOST_TEST(nd.N == 2); | ||
|
||
nd.unincorporate(5.0); | ||
nd.incorporate(7.0); | ||
BOOST_TEST(nd.N == 2); | ||
|
||
nd.unincorporate(-2.0); | ||
BOOST_TEST(nd.N == 1); | ||
} | ||
|
||
BOOST_AUTO_TEST_CASE(no_nan_after_incorporate_unincorporate) { | ||
std::mt19937 prng; | ||
ZeroMeanNormal nd(&prng); | ||
|
||
nd.incorporate(10.0); | ||
nd.unincorporate(10.0); | ||
|
||
BOOST_TEST(nd.N == 0); | ||
BOOST_TEST(!std::isnan(nd.var)); | ||
} | ||
|
||
BOOST_AUTO_TEST_CASE(logp_before_incorporate) { | ||
std::mt19937 prng; | ||
ZeroMeanNormal nd(&prng); | ||
|
||
BOOST_TEST(nd.logp(6.0) == -5.4563792395895785, tt::tolerance(1e-6)); | ||
BOOST_TEST(nd.logp_score() == -0.69314718055994529, tt::tolerance(1e-6)); | ||
|
||
nd.incorporate(5.0); | ||
nd.unincorporate(5.0); | ||
|
||
BOOST_TEST(nd.N == 0); | ||
BOOST_TEST(nd.logp(6.0) == -5.4563792395895785, tt::tolerance(1e-6)); | ||
BOOST_TEST(nd.logp_score() == -0.69314718055994529, tt::tolerance(1e-6)); | ||
} | ||
|
||
BOOST_AUTO_TEST_CASE(sample) { | ||
std::mt19937 prng; | ||
ZeroMeanNormal nd(&prng); | ||
|
||
for (int i = 0; i < 1000; ++i) { | ||
nd.incorporate(42.0); | ||
} | ||
|
||
double s = nd.sample(); | ||
|
||
BOOST_TEST(std::abs(s) < 100.0); | ||
} | ||
|
||
BOOST_AUTO_TEST_CASE(incorporate_raises_logp) { | ||
std::mt19937 prng; | ||
ZeroMeanNormal nd(&prng); | ||
|
||
double old_lp = nd.logp(10.0); | ||
for (int i = 0; i < 8; ++i) { | ||
nd.incorporate(10.0); | ||
double lp = nd.logp(10.0); | ||
BOOST_TEST(lp > old_lp); | ||
old_lp = lp; | ||
} | ||
} | ||
|
||
BOOST_AUTO_TEST_CASE(prior_prefers_origin) { | ||
std::mt19937 prng; | ||
ZeroMeanNormal nd1(&prng), nd2(&prng); | ||
|
||
for (int i = 0; i < 100; ++i) { | ||
nd1.incorporate(0.0); | ||
nd2.incorporate(50.0); | ||
} | ||
|
||
BOOST_TEST(nd1.logp_score() > nd2.logp_score()); | ||
} | ||
|
||
BOOST_AUTO_TEST_CASE(transition_hyperparameters) { | ||
std::mt19937 prng; | ||
ZeroMeanNormal nd(&prng); | ||
|
||
nd.transition_hyperparameters(); | ||
|
||
for (int i = 0; i < 100; ++i) { | ||
nd.incorporate(5.0); | ||
} | ||
|
||
nd.transition_hyperparameters(); | ||
BOOST_TEST(nd.beta > 1.0); | ||
} |