Commit 0180b5b: Add untracked files
ThomasColthurst committed Jun 10, 2024 (1 parent: cabb0db)

Showing 3 changed files with 238 additions and 0 deletions.
87 changes: 87 additions & 0 deletions cxx/distributions/zero_mean_normal.cc
@@ -0,0 +1,87 @@
// Copyright 2024
// See LICENSE.txt

#include "zero_mean_normal.hh"

#include <cmath>
#include <numbers>
#include <utility>
#include <vector>

// Returns the log density at x of a location-scale Student's t distribution
// with zero mean, v degrees of freedom, and squared scale `variance`.
double log_t_distribution(const double& x,
                          const double& v,
                          const double& variance) {
  // https://en.wikipedia.org/wiki/Student%27s_t-distribution#Density_and_first_two_moments
  double v_shift = (v + 1.0) / 2.0;
  return lgamma(v_shift)
         - lgamma(v / 2.0)
         - 0.5 * log(std::numbers::pi * v * variance)
         - v_shift * log(1.0 + x * x / variance / v);
}
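
// For reference, the density whose log is returned above is
//   p(x | v, s^2) = Gamma((v+1)/2) / (Gamma(v/2) * sqrt(pi * v * s^2))
//                   * (1 + x^2 / (v * s^2))^(-(v+1)/2),
// where s^2 is the `variance` argument.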

void ZeroMeanNormal::incorporate(const double& x) {
  ++N;
  var += (x * x - var) / N;
}

void ZeroMeanNormal::unincorporate(const double& x) {
  int old_N = N;
  --N;
  if (N == 0) {
    var = 0.0;
    return;
  }
  var = (var * old_N - x * x) / N;
}
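
// Invariant maintained by incorporate/unincorporate: after every call,
// var == (1/N) * sum_i x_i^2, the sufficient statistic of a zero-mean normal
// (with var == 0.0 when N == 0).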


double ZeroMeanNormal::logp(const double& x) const {
  // Posterior predictive density; see Equation (119) of
  // https://www.cs.ubc.ca/~murphyk/Papers/bayesGauss.pdf
  double alpha_n = alpha + N / 2.0;
  // beta_n = beta + (1/2) * sum_i x_i^2, and var == (1/N) * sum_i x_i^2.
  double beta_n = beta + 0.5 * var * N;
  double t_variance = beta_n / alpha_n;
  return log_t_distribution(x, 2.0 * alpha_n, t_variance);
}
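
// Derivation note: with an Inverse-Gamma(alpha, beta) prior on the variance
// and N zero-mean observations, the posterior is
// Inverse-Gamma(alpha + N/2, beta + (1/2) * sum_i x_i^2), and the posterior
// predictive is a Student's t with 2 * alpha_n degrees of freedom and
// squared scale beta_n / alpha_n, which log_t_distribution evaluates.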

double ZeroMeanNormal::logp_score() const {
  // Log marginal likelihood; see page 10 of
  // https://j-zin.github.io/files/Conjugate_Bayesian_analysis_of_common_distributions.pdf
  double alpha_n = alpha + N / 2.0;
  double beta_n = beta + 0.5 * var * N;
  return alpha * log(beta)
         - lgamma(alpha)
         - (N / 2.0) * log(2.0 * std::numbers::pi)
         + lgamma(alpha_n)
         - alpha_n * log(beta_n);
}
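
// Consistency checks implied by the closed form above: with no data the log
// marginal likelihood is 0 (since beta_n == beta and alpha_n == alpha), and
// logp(x) equals the change in logp_score() from incorporating x.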

double ZeroMeanNormal::sample() {
  // Draw from the posterior predictive: a scaled Student's t.
  double alpha_n = alpha + N / 2.0;
  double beta_n = beta + 0.5 * var * N;
  double t_variance = beta_n / alpha_n;
  std::student_t_distribution<double> d(2.0 * alpha_n);
  return d(*prng) * sqrt(t_variance);
}

#define ALPHA_GRID \
  { 1e-4, 1e-3, 1e-2, 1e-1, 1.0, 10.0, 100.0, 1000.0, 10000.0 }
#define BETA_GRID \
  { 1e-4, 1e-3, 1e-2, 1e-1, 1.0, 10.0, 100.0, 1000.0, 10000.0 }

void ZeroMeanNormal::transition_hyperparameters() {
  // Resample (alpha, beta) from the grid, weighting each pair by the
  // marginal likelihood of the incorporated data under that pair.
  std::vector<double> logps;
  std::vector<std::pair<double, double>> hypers;
  for (double a : ALPHA_GRID) {
    for (double b : BETA_GRID) {
      alpha = a;
      beta = b;
      double lp = logp_score();
      if (!std::isnan(lp)) {
        logps.push_back(lp);
        hypers.push_back(std::make_pair(a, b));
      }
    }
  }

  int i = sample_from_logps(logps, prng);
  alpha = std::get<0>(hypers[i]);
  beta = std::get<1>(hypers[i]);
}
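
For orientation, here is a minimal, hypothetical usage sketch (not part of this commit). It assumes the repository's cxx/ directory is on the include path and that the program links against the distribution sources.

#include <iostream>
#include <random>

#include "distributions/zero_mean_normal.hh"

int main() {
  std::mt19937 prng(42);
  ZeroMeanNormal zmn(&prng);

  // Observations assumed to be centered at zero.
  double xs[] = {1.5, -0.7, 0.3, -2.1};
  for (double x : xs) {
    zmn.incorporate(x);
  }

  // Posterior predictive log density at a new point, and a predictive sample.
  std::cout << "logp(1.0) = " << zmn.logp(1.0) << "\n";
  std::cout << "sample    = " << zmn.sample() << "\n";

  // Resample (alpha, beta) from the grid given the incorporated data.
  zmn.transition_hyperparameters();

  return 0;
}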
46 changes: 46 additions & 0 deletions cxx/distributions/zero_mean_normal.hh
@@ -0,0 +1,46 @@
// Copyright 2024
// See LICENSE.txt

#pragma once
#include <random>
#include <tuple>
#include <variant>

#include "base.hh"
#include "util_math.hh"

// A normal distribution that is known to have zero mean.
// This class is not intended to be directly used as a data model, but
// rather to support the Gaussian emissions model.
class ZeroMeanNormal : public Distribution<double> {
 public:
  // We use an Inverse-Gamma conjugate prior on the variance, so our
  // hyperparameters are:
  double alpha = 1.0;
  double beta = 1.0;

  // Sufficient statistic of the observed data: the running mean of the
  // squared observations, (1/N) * sum_i x_i^2.
  double var = 0.0;

  std::mt19937* prng;

  // ZeroMeanNormal does not take ownership of prng.
  ZeroMeanNormal(std::mt19937* prng) { this->prng = prng; }

  void incorporate(const double& x);

  void unincorporate(const double& x);

  void posterior_hypers(double* mprime, double* sprime) const;

  double logp(const double& x) const;

  double logp_score() const;

  double sample();

  void transition_hyperparameters();

  // Disable copying.
  ZeroMeanNormal& operator=(const ZeroMeanNormal&) = delete;
  ZeroMeanNormal(const ZeroMeanNormal&) = delete;
};
105 changes: 105 additions & 0 deletions cxx/distributions/zero_mean_normal_test.cc
@@ -0,0 +1,105 @@
// Apache License, Version 2.0, refer to LICENSE.txt

#define BOOST_TEST_MODULE test ZeroMeanNormal

#include "distributions/zero_mean_normal.hh"

#include <boost/test/included/unit_test.hpp>

#include "util_math.hh"
namespace bm = boost::math;
namespace tt = boost::test_tools;

BOOST_AUTO_TEST_CASE(simple) {
  std::mt19937 prng;
  ZeroMeanNormal nd(&prng);

  nd.incorporate(5.0);
  nd.incorporate(-2.0);
  BOOST_TEST(nd.N == 2);

  nd.unincorporate(5.0);
  nd.incorporate(7.0);
  BOOST_TEST(nd.N == 2);

  nd.unincorporate(-2.0);
  BOOST_TEST(nd.N == 1);
}

BOOST_AUTO_TEST_CASE(no_nan_after_incorporate_unincorporate) {
  std::mt19937 prng;
  ZeroMeanNormal nd(&prng);

  nd.incorporate(10.0);
  nd.unincorporate(10.0);

  BOOST_TEST(nd.N == 0);
  BOOST_TEST(!std::isnan(nd.var));
}

BOOST_AUTO_TEST_CASE(logp_before_incorporate) {
  std::mt19937 prng;
  ZeroMeanNormal nd(&prng);

  // With alpha = beta = 1 and no data, the predictive is a Student's t with
  // 2 degrees of freedom and unit scale, and the log marginal likelihood of
  // the empty data set is 0.
  BOOST_TEST(nd.logp(6.0) == -5.4563792395895785, tt::tolerance(1e-6));
  BOOST_TEST(nd.logp_score() == 0.0, tt::tolerance(1e-6));

  nd.incorporate(5.0);
  nd.unincorporate(5.0);

  BOOST_TEST(nd.N == 0);
  BOOST_TEST(nd.logp(6.0) == -5.4563792395895785, tt::tolerance(1e-6));
  BOOST_TEST(nd.logp_score() == 0.0, tt::tolerance(1e-6));
}

BOOST_AUTO_TEST_CASE(sample) {
  std::mt19937 prng;
  ZeroMeanNormal nd(&prng);

  for (int i = 0; i < 1000; ++i) {
    nd.incorporate(42.0);
  }

  double s = nd.sample();

  BOOST_TEST(std::abs(s) < 100.0);
}

BOOST_AUTO_TEST_CASE(incorporate_raises_logp) {
  std::mt19937 prng;
  ZeroMeanNormal nd(&prng);

  double old_lp = nd.logp(10.0);
  for (int i = 0; i < 8; ++i) {
    nd.incorporate(10.0);
    double lp = nd.logp(10.0);
    BOOST_TEST(lp > old_lp);
    old_lp = lp;
  }
}

BOOST_AUTO_TEST_CASE(prior_prefers_origin) {
  std::mt19937 prng;
  ZeroMeanNormal nd1(&prng), nd2(&prng);

  for (int i = 0; i < 100; ++i) {
    nd1.incorporate(0.0);
    nd2.incorporate(50.0);
  }

  BOOST_TEST(nd1.logp_score() > nd2.logp_score());
}

BOOST_AUTO_TEST_CASE(transition_hyperparameters) {
  std::mt19937 prng;
  ZeroMeanNormal nd(&prng);

  nd.transition_hyperparameters();

  for (int i = 0; i < 100; ++i) {
    nd.incorporate(5.0);
  }

  nd.transition_hyperparameters();
  BOOST_TEST(nd.beta > 1.0);
}
