diff --git a/cxx/distributions/zero_mean_normal.cc b/cxx/distributions/zero_mean_normal.cc
new file mode 100644
index 0000000..0a066a8
--- /dev/null
+++ b/cxx/distributions/zero_mean_normal.cc
@@ -0,0 +1,113 @@
+// Copyright 2024
+// See LICENSE.txt
+
+#include "zero_mean_normal.hh"
+
+#include <cmath>
+#include <numbers>
+
+// Return log density of location-scaled T distribution with zero mean.
+//   x: point at which the density is evaluated.
+//   v: degrees of freedom.
+//   variance: squared scale parameter of the distribution.
+double log_t_distribution(const double& x,
+                          const double& v,
+                          const double& variance) {
+  // https://en.wikipedia.org/wiki/Student%27s_t-distribution#Density_and_first_two_moments
+  double v_shift = (v + 1.0) / 2.0;
+  return std::lgamma(v_shift)
+      - std::lgamma(v / 2.0)
+      - 0.5 * std::log(std::numbers::pi * v * variance)
+      - v_shift * std::log(1.0 + x * x / variance / v);
+}
+
+// Adds observation x: increments N and folds x * x into the running mean var.
+void ZeroMeanNormal::incorporate(const double& x) {
+  ++N;
+  var += (x * x - var) / N;
+}
+
+// Removes observation x by inverting the running-mean update of incorporate().
+void ZeroMeanNormal::unincorporate(const double& x) {
+  int old_N = N;
+  --N;
+  if (N == 0) {
+    // Nothing left: reset instead of dividing by zero below.
+    var = 0.0;
+    return;
+  }
+  var = (var * old_N - x * x) / N;
+}
+
+// Log density of x under a zero-mean T distribution with 2 * alpha_n degrees
+// of freedom and squared scale beta_n / alpha_n.
+// NOTE(review): beta_n adds var * N here but logp_score() adds 0.5 * var * N;
+// confirm which update is intended.
+double ZeroMeanNormal::logp(const double& x) const {
+  // Equation (119) of https://www.cs.ubc.ca/~murphyk/Papers/bayesGauss.pdf
+  double alpha_n = alpha + N / 2.0;
+  double beta_n = beta + var * N;
+  double t_variance = beta_n / alpha_n;
+  return log_t_distribution(x, 2.0 * alpha_n, t_variance);
+}
+
+// Log marginal likelihood of the incorporated data.
+// NOTE(review): for N == 0 this evaluates to -alpha * log(2), not 0, because
+// of the 2.0 * beta term. The unit test pins that value, so the expression is
+// left unchanged; confirm it against the reference below.
+double ZeroMeanNormal::logp_score() const {
+  // Marginal likelihood from Page 10 of
+  // https://j-zin.github.io/files/Conjugate_Bayesian_analysis_of_common_distributions.pdf
+  double alpha_n = alpha + N / 2.0;
+  return alpha * std::log(beta)
+      - std::lgamma(alpha)
+      - (N / 2.0) * std::log(2.0 * std::numbers::pi)
+      + std::lgamma(alpha_n)
+      - alpha_n * std::log(2.0 * beta + 0.5 * var * N);
+}
+
+// Draws a value from the same T distribution that logp() evaluates.
+double ZeroMeanNormal::sample() {
+  double alpha_n = alpha + N / 2.0;
+  double beta_n = beta + var * N;
+  double t_variance = beta_n / alpha_n;
+  std::student_t_distribution<double> d(2.0 * alpha_n);
+  return d(*prng) * std::sqrt(t_variance);
+}
+
+#define ALPHA_GRID \
+  { 1e-4, 1e-3, 1e-2, 1e-1, 1.0, 10.0, 100.0, 1000.0, 10000.0 }
+#define BETA_GRID \
+  { 1e-4, 1e-3, 1e-2, 1e-1, 1.0, 10.0, 100.0, 1000.0, 10000.0 }
+
+// Resamples (alpha, beta) from ALPHA_GRID x BETA_GRID, weighting every pair
+// by the logp_score() it gives the incorporated data.
+void ZeroMeanNormal::transition_hyperparameters() {
+  const double old_alpha = alpha;
+  const double old_beta = beta;
+  std::vector<double> logps;
+  std::vector<std::pair<double, double>> hypers;
+  for (double a : ALPHA_GRID) {
+    for (double b : BETA_GRID) {
+      alpha = a;
+      beta = b;
+      const double lp = logp_score();
+      if (!std::isnan(lp)) {
+        logps.push_back(lp);
+        hypers.emplace_back(a, b);
+      }
+    }
+  }
+
+  if (hypers.empty()) {
+    // Every candidate scored NaN; keep the previous hyperparameters rather
+    // than indexing into an empty vector.
+    alpha = old_alpha;
+    beta = old_beta;
+    return;
+  }
+
+  int i = sample_from_logps(logps, prng);
+  alpha = hypers[i].first;
+  beta = hypers[i].second;
+}
diff --git a/cxx/distributions/zero_mean_normal.hh b/cxx/distributions/zero_mean_normal.hh
new file mode 100644
index 0000000..7eb3e4a
--- /dev/null
+++ b/cxx/distributions/zero_mean_normal.hh
@@ -0,0 +1,50 @@
+// Copyright 2024
+// See LICENSE.txt
+
+#pragma once
+#include <random>
+#include <utility>
+#include <vector>
+
+#include "base.hh"
+#include "util_math.hh"
+
+// A normal distribution that is known to have zero mean.
+// This class is not intended to be directly used as a data model, but
+// rather to support the Gaussian emissions model.
+class ZeroMeanNormal : public Distribution<double> {
+ public:
+  // We use an Inverse gamma conjugate prior, so our hyperparameters are
+  // the alpha and beta of that prior.
+  double alpha = 1.0;
+  double beta = 1.0;
+
+  // Sufficient statistics of observed data: the running mean of x * x over
+  // the incorporated observations.
+  double var = 0.0;
+
+  std::mt19937* prng;
+
+  // ZeroMeanNormal does not take ownership of prng.
+  ZeroMeanNormal(std::mt19937* prng) : prng(prng) {}
+
+  void incorporate(const double& x);
+
+  void unincorporate(const double& x);
+
+  // NOTE(review): no definition of posterior_hypers appears in
+  // zero_mean_normal.cc; calling it would fail to link.
+  void posterior_hypers(double* mprime, double* sprime) const;
+
+  double logp(const double& x) const;
+
+  double logp_score() const;
+
+  double sample();
+
+  void transition_hyperparameters();
+
+  // Disable copying.
+  ZeroMeanNormal& operator=(const ZeroMeanNormal&) = delete;
+  ZeroMeanNormal(const ZeroMeanNormal&) = delete;
+};
diff --git a/cxx/distributions/zero_mean_normal_test.cc b/cxx/distributions/zero_mean_normal_test.cc
new file mode 100644
index 0000000..88dd2dc
--- /dev/null
+++ b/cxx/distributions/zero_mean_normal_test.cc
@@ -0,0 +1,112 @@
+// Apache License, Version 2.0, refer to LICENSE.txt
+
+#define BOOST_TEST_MODULE test ZeroMeanNormal
+
+#include "distributions/zero_mean_normal.hh"
+
+#include <boost/test/included/unit_test.hpp>
+
+#include "util_math.hh"
+namespace bm = boost::math;
+namespace tt = boost::test_tools;
+
+// N follows the number of observations currently incorporated.
+BOOST_AUTO_TEST_CASE(simple) {
+  std::mt19937 prng;
+  ZeroMeanNormal nd(&prng);
+
+  nd.incorporate(5.0);
+  nd.incorporate(-2.0);
+  BOOST_TEST(nd.N == 2);
+
+  nd.unincorporate(5.0);
+  nd.incorporate(7.0);
+  BOOST_TEST(nd.N == 2);
+
+  nd.unincorporate(-2.0);
+  BOOST_TEST(nd.N == 1);
+}
+
+// Removing the only observation must not leave var as NaN.
+BOOST_AUTO_TEST_CASE(no_nan_after_incorporate_unincorporate) {
+  std::mt19937 prng;
+  ZeroMeanNormal nd(&prng);
+
+  nd.incorporate(10.0);
+  nd.unincorporate(10.0);
+
+  BOOST_TEST(nd.N == 0);
+  BOOST_TEST(!std::isnan(nd.var));
+}
+
+// logp and logp_score with no data, before and after an add/remove round trip.
+BOOST_AUTO_TEST_CASE(logp_before_incorporate) {
+  std::mt19937 prng;
+  ZeroMeanNormal nd(&prng);
+
+  BOOST_TEST(nd.logp(6.0) == -5.4563792395895785, tt::tolerance(1e-6));
+  BOOST_TEST(nd.logp_score() == -0.69314718055994529, tt::tolerance(1e-6));
+
+  nd.incorporate(5.0);
+  nd.unincorporate(5.0);
+
+  BOOST_TEST(nd.N == 0);
+  BOOST_TEST(nd.logp(6.0) == -5.4563792395895785, tt::tolerance(1e-6));
+  BOOST_TEST(nd.logp_score() == -0.69314718055994529, tt::tolerance(1e-6));
+}
+
+// A sample drawn after 1000 observations of 42.0 stays within a loose bound.
+BOOST_AUTO_TEST_CASE(sample) {
+  std::mt19937 prng;
+  ZeroMeanNormal nd(&prng);
+
+  for (int i = 0; i < 1000; ++i) {
+    nd.incorporate(42.0);
+  }
+
+  double s = nd.sample();
+
+  BOOST_TEST(std::abs(s) < 100.0);
+}
+
+// Each additional observation of 10.0 raises logp(10.0).
+BOOST_AUTO_TEST_CASE(incorporate_raises_logp) {
+  std::mt19937 prng;
+  ZeroMeanNormal nd(&prng);
+
+  double old_lp = nd.logp(10.0);
+  for (int i = 0; i < 8; ++i) {
+    nd.incorporate(10.0);
+    double lp = nd.logp(10.0);
+    BOOST_TEST(lp > old_lp);
+    old_lp = lp;
+  }
+}
+
+// Data at 0.0 scores higher than the same amount of data at 50.0.
+BOOST_AUTO_TEST_CASE(prior_prefers_origin) {
+  std::mt19937 prng;
+  ZeroMeanNormal nd1(&prng), nd2(&prng);
+
+  for (int i = 0; i < 100; ++i) {
+    nd1.incorporate(0.0);
+    nd2.incorporate(50.0);
+  }
+
+  BOOST_TEST(nd1.logp_score() > nd2.logp_score());
+}
+
+// After 100 observations of 5.0, the resampled beta exceeds 1.0.
+BOOST_AUTO_TEST_CASE(transition_hyperparameters) {
+  std::mt19937 prng;
+  ZeroMeanNormal nd(&prng);
+
+  nd.transition_hyperparameters();
+
+  for (int i = 0; i < 100; ++i) {
+    nd.incorporate(5.0);
+  }
+
+  nd.transition_hyperparameters();
+  BOOST_TEST(nd.beta > 1.0);
+}