Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Bijectors #16

Merged
merged 7 commits into from
Nov 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ lsbi: Linear Simulation Based Inference
=======================================
:lsbi: Linear Simulation Based Inference
:Author: Will Handley & David Yallup
:Version: 0.7.0
:Version: 0.8.0
:Homepage: https://github.com/handley-lab/lsbi
:Documentation: http://lsbi.readthedocs.io/

Expand Down
9 changes: 9 additions & 0 deletions docs/source/lsbi.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,12 @@ lsbi.stats module
:undoc-members:


lsbi.utils module
-----------------

.. automodule:: lsbi.utils
:members:
:undoc-members:
:show-inheritance:


2 changes: 1 addition & 1 deletion lsbi/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.7.0'
__version__ = '0.8.0'
8 changes: 2 additions & 6 deletions lsbi/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,8 @@
import numpy as np
from lsbi.stats import (mixture_multivariate_normal,
multivariate_normal)
from numpy.linalg import solve, inv, slogdet


def logdet(A):
"""log(abs(det(A)))."""
return slogdet(A)[1]
from numpy.linalg import solve, inv
from lsbi.utils import logdet


class LinearModel(object):
Expand Down
86 changes: 86 additions & 0 deletions lsbi/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
import numpy as np
import scipy.stats
from scipy.stats._multivariate import multivariate_normal_frozen
from scipy.special import logsumexp, erf
from numpy.linalg import inv
from lsbi.utils import bisect


class multivariate_normal(multivariate_normal_frozen): # noqa: D101
Expand Down Expand Up @@ -43,6 +45,33 @@ def _bar(self, indices):
k[indices] = False
return k

def bijector(self, x, inverse=False):
"""Bijector between U([0, 1])^d and the distribution.

- x in [0, 1]^d is the hypercube space.
- theta in R^d is the physical space.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps this docstring could clarify as this method is valid on likelihood or posterior/prior what physical space is. It feels most natural to define this for parameter space (theta doubly suggesting this) transformations or some comment on it's dual usage (if it is intended/makes sense to use on data distributions too)


Computes the transformation from x to theta or theta to x depending on
the value of inverse.

Parameters
----------
x : array_like, shape (..., d)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this want/benefit stricter shape checking? I note for example if I have a 100D data likelihood I can

model.likelihood(theta).bijector(np.random.rand(1)[...,None])

i.e. I'm passing something of shape (1,1), rather than (100,) or (1,100) and it returns a valid data draw, but I'm not sure what this actually is!

if inverse: x is theta
else: x is x
inverse : bool, optional, default=False
If True: compute the inverse transformation from physical to
hypercube space.
"""
L = np.linalg.cholesky(self.cov)
if inverse:
Linv = inv(L)
y = np.einsum('ij,...j->...i', Linv, x-self.mean)
return scipy.stats.norm.cdf(y)
else:
y = scipy.stats.norm.ppf(x)
return self.mean + np.einsum('ij,...j->...i', L, y)


class mixture_multivariate_normal(object):
"""Mixture of multivariate normal distributions.
Expand Down Expand Up @@ -136,3 +165,60 @@ def _bar(self, indices):
k = np.ones(self.means.shape[-1], dtype=bool)
k[indices] = False
return k

def bijector(self, x, inverse=False):
"""Bijector between U([0, 1])^d and the distribution.

- x in [0, 1]^d is the hypercube space.
- theta in R^d is the physical space.

Computes the transformation from x to theta or theta to x depending on
the value of inverse.

Parameters
----------
x : array_like, shape (..., d)
if inverse: x is theta
else: x is x
inverse : bool, optional, default=False
If True: compute the inverse transformation from physical to
hypercube space.
"""
theta = np.empty_like(x)
if inverse:
theta[:] = x
x = np.empty_like(x)

for i in range(x.shape[-1]):
m = self.means[..., :, i] + np.einsum('ia,iab,...ib->...i',
self.covs[:, i, :i],
inv(self.covs[:, :i, :i]),
theta[..., None, :i]
- self.means[:, :i])
c = self.covs[:, i, i] - np.einsum('ia,iab,ib->i',
self.covs[:, i, :i],
inv(self.covs[:, :i, :i]),
self.covs[:, i, :i])
dist = mixture_multivariate_normal(self.means[:, :i],
self.covs[:, :i, :i],
self.logA)
logA = (self.logA + dist.logpdf(theta[..., :i], reduce=False)
- dist.logpdf(theta[..., :i])[..., None])
A = np.exp(logA - logsumexp(logA, axis=-1)[..., None])

def f(t):
return (A * 0.5 * (1 + erf((t[..., None] - m)/np.sqrt(2 * c)))
).sum(axis=-1) - y

if inverse:
y = 0
x[..., i] = f(theta[..., i])
else:
y = x[..., i]
a = (m - 10 * np.sqrt(c)).min(axis=-1)
b = (m + 10 * np.sqrt(c)).max(axis=-1)
theta[..., i] = bisect(f, a, b)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this would benefit from some inline comments/ more expanded docstring explaining what is going on here as an additional moving part.

if inverse:
return x
else:
return theta
57 changes: 57 additions & 0 deletions lsbi/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
"""Utility functions for lsbi."""
import numpy as np


def logdet(A):
    """log(abs(det(A))).

    Uses slogdet for numerical stability; the sign component is discarded,
    so the result is the log of the absolute determinant.
    """
    _, logabsdet = np.linalg.slogdet(A)
    return logabsdet


def quantise(f, x, tol=1e-8):
    """Quantise f(x) to zero within tolerance tol.

    Parameters
    ----------
    f : callable
        Function to evaluate.
    x : array_like
        Point(s) at which to evaluate f.
    tol : float, optional
        Values of f(x) with absolute value below tol are snapped to
        exactly 0.

    Returns
    -------
    y : ndarray
        f(x), at least one-dimensional, with near-zero values replaced
        by 0 so that exact sign tests (== 0, fa * fb > 0) are reliable.
    """
    y = np.atleast_1d(f(x))
    return np.where(np.abs(y) < tol, 0, y)


def bisect(f, a, b, args=(), tol=1e-8):
    """Vectorised simple bisection search.

    The shape of the output is the broadcasted shape of a and b.

    Parameters
    ----------
    f : callable
        Function to find the root of.
    a : array_like
        Lower bound of the search interval.
    b : array_like
        Upper bound of the search interval.
    args : tuple, optional
        Extra arguments to `f`.
    tol : float, optional
        (absolute) tolerance of the solution

    Returns
    -------
    x : ndarray
        Solution to the equation f(x) = 0.

    Raises
    ------
    ValueError
        If f(a) and f(b) do not have opposite signs.
    """
    # BUG FIX: `args` was documented but never forwarded to f; wrap f so
    # the extra arguments are actually passed through.
    def fargs(t):
        return f(t, *args)

    a = np.array(a)
    b = np.array(b)
    while np.abs(a - b).max() > tol:
        fa = quantise(fargs, a, tol)
        fb = quantise(fargs, b, tol)
        # Collapse any interval whose endpoint is already a root.
        a = np.where(fb == 0, b, a)
        b = np.where(fa == 0, a, b)

        if np.any(fa * fb > 0):
            raise ValueError("f(a) and f(b) must have opposite signs")
        # Standard bisection step: keep whichever half brackets the root.
        q = (a + b) / 2
        fq = quantise(fargs, q, tol)

        a = np.where(fq == 0, q, a)
        a = np.where(fa * fq > 0, q, a)

        b = np.where(fq == 0, q, b)
        b = np.where(fb * fq > 0, q, b)
    return (a + b) / 2
66 changes: 66 additions & 0 deletions tests/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,72 @@ def test_mixture_multivariate_normal(k, d):
assert mvns[0].logpdf(x).shape == mixture.logpdf(x).shape


def test_mixture_multivariate_normal_bijector():
    """Check the mixture bijector: inversion, distribution, and shapes."""
    # Seed the global RNG: the KS-based assertions below are statistical
    # (p > 1e-5) and would otherwise be flaky across runs.
    np.random.seed(0)
    k = 4
    d = 10
    covs = scipy.stats.wishart.rvs(d, np.eye(d), size=k)
    means = np.random.randn(k, d)
    logA = np.log(scipy.stats.dirichlet.rvs(np.ones(k))[0])
    model = mixture_multivariate_normal(means, covs, logA)

    # Test inversion: bijector and its inverse should round-trip.
    x = np.random.rand(1000, d)
    theta = model.bijector(x)
    assert_allclose(model.bijector(theta, inverse=True), x, atol=1e-6)

    # Test sampling: uniform samples pushed through the bijector should be
    # distributed like direct draws (KS test per marginal).
    samples = model.rvs(1000)
    for i in range(d):
        p = scipy.stats.kstest(theta[:, i], samples[:, i]).pvalue
        assert p > 1e-5

    p = scipy.stats.kstest(model.logpdf(samples), model.logpdf(theta)).pvalue
    assert p > 1e-5

    # Test shapes: both unbatched and batched inputs preserve shape.
    x = np.random.rand(d)
    theta = model.bijector(x)
    assert theta.shape == x.shape
    assert model.bijector(theta, inverse=True).shape == x.shape

    x = np.random.rand(3, 4, d)
    theta = model.bijector(x)
    assert theta.shape == x.shape
    assert model.bijector(theta, inverse=True).shape == x.shape


def test_multivariate_normal_bijector():
    """Check the Gaussian bijector: inversion, distribution, and shapes."""
    # Seed the global RNG: the KS-based assertions below are statistical
    # (p > 1e-5) and would otherwise be flaky across runs.
    np.random.seed(0)
    d = 10
    cov = scipy.stats.wishart.rvs(d, np.eye(d))
    mean = np.random.randn(d)
    model = multivariate_normal(mean, cov)

    # Test inversion: bijector and its inverse should round-trip.
    x = np.random.rand(1000, d)
    theta = model.bijector(x)
    assert_allclose(model.bijector(theta, inverse=True), x, atol=1e-6)

    # Test sampling: uniform samples pushed through the bijector should be
    # distributed like direct draws (KS test per marginal).
    samples = model.rvs(1000)
    for i in range(d):
        p = scipy.stats.kstest(theta[:, i], samples[:, i]).pvalue
        assert p > 1e-5

    p = scipy.stats.kstest(model.logpdf(samples), model.logpdf(theta)).pvalue
    assert p > 1e-5

    # Test shapes: both unbatched and batched inputs preserve shape.
    x = np.random.rand(d)
    theta = model.bijector(x)
    assert theta.shape == x.shape
    assert model.bijector(theta, inverse=True).shape == x.shape

    x = np.random.rand(3, 4, d)
    theta = model.bijector(x)
    assert theta.shape == x.shape
    assert model.bijector(theta, inverse=True).shape == x.shape


def test_marginalise_condition_multivariate_normal():
d = 5
mean = np.random.randn(d)
Expand Down
17 changes: 17 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from lsbi.utils import bisect
from numpy.testing import assert_allclose
import pytest


def test_bisect():
    """Exercise bisect on scalar roots, bad brackets, and vector roots."""
    def shifted(t):
        return t - 5

    # Scalar root inside the bracket.
    assert bisect(shifted, 0, 10) == 5

    # Endpoints with the same sign must raise.
    with pytest.raises(ValueError):
        bisect(shifted, 0, 4)

    def vector_shifted(t):
        return t - [1, 2]

    # Broadcasting: one bracket, two simultaneous roots.
    assert_allclose(bisect(vector_shifted, 0, 10), [1, 2])