SteinVI: New quadratic form kernel with Gauss Newton metric (#1953)
* added rgn_kernel

* added rgnkernel test

* updated docs

* fixed lint

* revert params for stein_bnn example

* added upperbound <0.5.0 on jax and jaxlib

* added bound on jaxlib
OlaRonning authored Jan 25, 2025
1 parent 5aca6cb commit f529fb0
Showing 8 changed files with 114 additions and 24 deletions.
Binary file added docs/source/_static/img/examples/stein_bnn.png
Binary file modified docs/source/_static/img/examples/stein_dmm.png
4 changes: 3 additions & 1 deletion docs/source/contrib.rst
@@ -43,6 +43,8 @@ The framework currently supports several kernels, including:
- `RandomFeatureKernel`
- `MixtureKernel`
- `GraphicalKernel`
- `RadialGaussNewtonKernel`


SteinVI-based examples include:

@@ -80,7 +82,7 @@ SteinVI Kernels
.. autoclass:: numpyro.contrib.einstein.stein_kernels.MixtureKernel
.. autoclass:: numpyro.contrib.einstein.stein_kernels.GraphicalKernel
.. autoclass:: numpyro.contrib.einstein.stein_kernels.ProbabilityProductKernel

.. autoclass:: numpyro.contrib.einstein.stein_kernels.RadialGaussNewtonKernel

Stochastic Support
~~~~~~~~~~~~~~~~~~
6 changes: 5 additions & 1 deletion examples/stein_bnn.py
@@ -4,8 +4,13 @@
"""
Example: Bayesian Neural Network with SteinVI
=============================================
We demonstrate how to use SteinVI to predict housing prices with a Bayesian neural network (BNN)
trained on the Boston Housing dataset from the UCI regression benchmarks.
.. image:: ../_static/img/examples/stein_bnn.png
:align: center
:scale: 60%
"""

import argparse
@@ -119,7 +124,6 @@ def main(args):

rng_key, inf_key = random.split(inf_key)

# We find that SteinVI benefits from a small radius when inferring BNNs.
guide = AutoNormal(model)

stein = SteinVI(
55 changes: 53 additions & 2 deletions numpyro/contrib/einstein/stein_kernels.py
@@ -6,9 +6,8 @@

import numpy as np

from jax import random
from jax import grad, numpy as jnp, random, vmap
from jax.lax import stop_gradient
import jax.numpy as jnp
import jax.scipy.linalg
import jax.scipy.stats

@@ -461,3 +460,55 @@ def kernel(x, y):
@property
def mode(self):
return self._mode


class RadialGaussNewtonKernel(SteinKernel):
r"""The Radial Gauss-Newton Kernel [1,2], also called the scaled Hessian kernel, is a scalar kernel defined as
.. math::
k(x,y) = \exp\left(-\frac{1}{2d}(x-y)^T M (x-y)\right),
where :math:`x,y \in R^d` are particles and :math:`M` is a metric matrix. :math:`M` approximates the expected
curvature of the log posterior using Hessian approximations :math:`A(x)`.
The matrix :math:`M` is computed using :math:`m` particles as follows [2, Eq. 19, p.5]:
.. math::
M= \frac{1}{m} \sum_{i=1}^m A(x_i)
with the Hessian approximation given by:
.. math::
A(x) = J(x) J(x)^T,
where :math:`J(x)` is the Jacobian of an ELBO at :math:`x`.
**References**:
1. Maken, Fahira Afzal, Fabio Ramos, and Lionel Ott. "Stein Particle Filter for Nonlinear,
**Non-Gaussian State Estimation."** IEEE Robotics and Automation Letters 7.2 (2022).
2. Detommaso, Gianluca, et al. "A Stein variational Newton method."
Advances in Neural Information Processing Systems 31 (2018).
"""

def __init__(self):
self._mode = "norm"

def compute(self, rng_key, particles, particle_info, loss_fn):
    n, d = particles.shape

    # Per-particle ELBO Jacobians J(x_i), stacked into an (n, d) array.
    Jx = vmap(grad(loss_fn, argnums=1))(
        random.split(rng_key, n), particles, jnp.arange(n)
    )

    def kernel(x, y):
        # Squared projections (J(x_i)^T (x - y))^2; their mean equals the
        # quadratic form (x - y)^T M (x - y) with M = (1/n) sum_i J_i J_i^T.
        dist = jnp.dot(stop_gradient(Jx), (x - y)) ** 2

        kernel_res = jnp.exp(-dist.mean() / (2 * d))
        return kernel_res

    return kernel

@property
def mode(self):
return self._mode
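For intuition, the kernel above can be reproduced in a few lines of plain JAX. The sketch below is not part of this commit: `elbo_grads` is a hypothetical stand-in for the per-particle ELBO Jacobians that `compute` obtains via `vmap(grad(loss_fn, argnums=1))`, chosen here as the gradient of `jnp.linalg.norm` to match the test case further down.

```python
import jax.numpy as jnp

particles = jnp.array([[1.0, 2.0], [10.0, 5.0], [7.0, 3.0], [2.0, -1.0]])
n, d = particles.shape

# Hypothetical stand-in Jacobians J(x_i): gradient of the mock ELBO z -> ||z||.
elbo_grads = particles / jnp.linalg.norm(particles, axis=1, keepdims=True)

# Gauss-Newton metric M = (1/m) * sum_i J(x_i) J(x_i)^T.
M = (elbo_grads[:, :, None] * elbo_grads[:, None, :]).mean(axis=0)

def rgn_kernel(x, y):
    diff = x - y
    return jnp.exp(-diff @ M @ diff / (2 * d))

print(rgn_kernel(particles[0], particles[1]))  # ~5.4574e-08
```

Note that `compute` never materializes `M`: averaging the squared projections `(Jx @ (x - y)) ** 2` yields the same quadratic form without forming the d-by-d matrix.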
19 changes: 11 additions & 8 deletions numpyro/contrib/einstein/steinvi.py
@@ -32,7 +32,7 @@ def _numel(shape):


class SteinVI:
"""Variational inference with Stein mixtures inference.
"""Variational inference with Stein mixtures inference [1].
**Example:**
@@ -95,11 +95,13 @@ class SteinVI:
:param static_kwargs: Static keyword arguments for the model and guide. These arguments cannot change
during inference.
**References:** (MLA style)
**References:**
1. Liu, Chang, et al. "Understanding and Accelerating Particle-Based Variational Inference."
1. Rønning, Ola, et al. "ELBOing Stein: Variational Bayes with Stein Mixture Inference."
arXiv preprint arXiv:2410.22948 (2024).
2. Liu, Chang, et al. "Understanding and Accelerating Particle-Based Variational Inference."
International Conference on Machine Learning. PMLR, 2019.
2. Wang, Dilin, and Qiang Liu. "Nonlinear Stein Variational Gradient Descent for Learning Diversified Mixture Models."
3. Wang, Dilin, and Qiang Liu. "Nonlinear Stein Variational Gradient Descent for Learning Diversified Mixture Models."
International Conference on Machine Learning. PMLR, 2019.
""" # noqa: E501

@@ -585,7 +587,7 @@ class SVGD(SteinVI):
:param Dict static_kwargs: Static keyword arguments for the model and guide. These arguments cannot
change during inference.
**References:** (MLA style)
**References:**
1. Liu, Qiang, and Dilin Wang. "Stein Variational Gradient Descent: A General Purpose Bayesian Inference Algorithm."
Advances in neural information processing systems 29 (2016).
@@ -678,7 +680,7 @@ class ASVGD(SVGD):
:param Dict static_kwargs: Static keyword arguments for the model and guide. These arguments cannot
change during inference.
**References:** (MLA style)
**References:**
1. D'Angelo, Francesco, and Vincent Fortuin. "Annealed Stein Variational Gradient Descent."
Third Symposium on Advances in Approximate Bayesian Inference, 2021.
@@ -712,8 +714,9 @@ def _cyclical_annealing(num_steps: int, num_cycles: int, trans_speed: int):
"""Cyclical annealing schedule as in eq. 4 of [1].
**References** (MLA)
1. D'Angelo, Francesco, and Vincent Fortuin. "Annealed Stein Variational Gradient Descent."
Third Symposium on Advances in Approximate Bayesian Inference, 2021.
1. D'Angelo, Francesco, and Vincent Fortuin. "Annealed Stein Variational Gradient Descent."
Third Symposium on Advances in Approximate Bayesian Inference, 2021.
:param num_steps: The total number of steps. Corresponds to $T$ in eq. 4 of [1].
:param num_cycles: The total number of cycles. Corresponds to $C$ in eq. 4 of [1].
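The body of `_cyclical_annealing` is collapsed in this diff, so the following is an illustrative sketch only, one plausible reading of eq. 4 of [1]: the annealing weight ramps from near 0 to 1 within each of the `num_cycles` cycles, and `trans_speed` controls how sharply it saturates. The commit's actual implementation may differ.

```python
def cyclical_annealing_sketch(num_steps: int, num_cycles: int, trans_speed: int):
    """Illustrative cyclical schedule; assumed form, not the commit's code."""
    cycle_len = num_steps // num_cycles  # steps per cycle, roughly T / C

    def gamma(t: int) -> float:
        frac = ((t % cycle_len) + 1) / cycle_len  # position within the current cycle
        return frac**trans_speed

    return gamma

sched = cyclical_annealing_sketch(num_steps=100, num_cycles=4, trans_speed=10)
print([round(sched(t), 4) for t in (0, 12, 24, 25)])  # resets every 25 steps
```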
4 changes: 2 additions & 2 deletions setup.py
@@ -9,8 +9,8 @@
from setuptools import find_packages, setup

PROJECT_PATH = os.path.dirname(os.path.abspath(__file__))
_jax_version_constraints = ">=0.4.25"
_jaxlib_version_constraints = ">=0.4.25"
_jax_version_constraints = ">=0.4.25,<0.5.0"
_jaxlib_version_constraints = ">=0.4.25,<0.5.0"

# Find version
for line in open(os.path.join(PROJECT_PATH, "numpyro", "version.py")):
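As a quick sanity check on the new pins, an installed environment can be validated with a few lines of Python (illustrative only; `packaging` is assumed to be available, as it ships with modern pip/setuptools installs):

```python
import jax
import jaxlib
from packaging.version import Version

for mod in (jax, jaxlib):
    # Mirrors the constraint ">=0.4.25,<0.5.0" from setup.py.
    assert Version("0.4.25") <= Version(mod.__version__) < Version("0.5.0"), mod.__name__
```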
50 changes: 40 additions & 10 deletions test/contrib/einstein/test_stein_kernels.py
@@ -18,14 +18,17 @@
LinearKernel,
MixtureKernel,
ProbabilityProductKernel,
RadialGaussNewtonKernel,
RandomFeatureKernel,
RBFKernel,
)
from numpyro.distributions import Normal
from numpyro.infer.autoguide import AutoNormal
from numpyro.optim import Adam

T = namedtuple("TestSteinKernel", ["kernel", "particle_info", "loss_fn", "kval"])
T = namedtuple(
"TestSteinKernel", ["name", "kernel", "particle_info", "loss_fn", "kval"]
)

PARTICLES = np.array([[1.0, 2.0], [10.0, 5.0], [7.0, 3.0], [2.0, -1]])

@@ -36,6 +39,7 @@ def MOCK_MODEL():

TEST_CASES = [
T(
"RBFKernel",
RBFKernel,
lambda d: {},
lambda x: x,
@@ -63,8 +67,15 @@ def MOCK_MODEL():
"matrix": np.array([[0.00490776, 0.0], [0.0, 0.00490776]]),
},
),
T(RandomFeatureKernel, lambda d: {}, lambda x: x, {"norm": 13.805723}),
T(
"RandomFeatureKernel",
RandomFeatureKernel,
lambda d: {},
lambda x: x,
{"norm": 13.805723},
),
T(
"IMQKernel",
IMQKernel,
lambda d: {},
lambda x: x,
@@ -84,6 +95,7 @@ def MOCK_MODEL():
},
),
T(
"LinearKernel",
LinearKernel,
lambda d: {},
lambda x: x,
@@ -96,6 +108,7 @@ def MOCK_MODEL():
},
),
T(
"MixtureKernel",
lambda mode: MixtureKernel(
mode=mode,
ws=np.array([0.2, 0.8]),
@@ -107,6 +120,7 @@ def MOCK_MODEL():
{"matrix": np.array([[0.00490776, 0.0], [0.0, 0.00490776]])},
),
T(
"GraphicalKernel",
lambda mode: GraphicalKernel(
mode=mode, local_kernel_fns={"p1": RBFKernel("norm")}
),
@@ -124,6 +138,7 @@ def MOCK_MODEL():
},
),
T(
"ProbibilityProductKernel",
lambda mode: ProbabilityProductKernel(mode=mode, guide=AutoNormal(MOCK_MODEL)),
lambda d: {"x_auto_loc": (0, 1), "x_auto_scale": (1, 2)},
lambda x: x,
@@ -139,18 +154,33 @@ def MOCK_MODEL():
# = 0.2544481
{"norm": 0.2544481},
),
T(
    "RadialGaussNewtonKernel",
    lambda mode: RadialGaussNewtonKernel(),
    lambda d: {},
    lambda key, particle, i: jnp.linalg.norm(particle),  # Mock ELBO
    # let
    #   J(z) = grad(norm)(z) = z / sqrt((z**2).sum())
    #   M = mean([outer(J(z), J(z)) for z in particles])
    #     = [[0.6612069 , 0.19051724],
    #        [0.19051724, 0.3387931 ]]
    #   diff = [1, 2] - [10, 5] = [-9, -3]
    #   quad_form = diff.T @ M @ diff = 66.89482758620689
    # in
    #   k(x, y) = exp(-1/(2*2) * quad_form) = 5.457407430444593e-08
    {"norm": 5.457407430444593e-08},
),
]


TEST_IDS = [t[0].__class__.__name__ for t in TEST_CASES]
TEST_IDS = [t.name for t in TEST_CASES]


@pytest.mark.parametrize("mode", ["norm", "vector", "matrix"])
@pytest.mark.parametrize(
"kernel, particle_info, loss_fn, kval", TEST_CASES, ids=TEST_IDS
"name, kernel, particle_info, loss_fn, kval", TEST_CASES, ids=TEST_IDS
)
@pytest.mark.parametrize("particles", [PARTICLES])
@pytest.mark.parametrize("mode", ["norm", "vector", "matrix"])
def test_kernel_forward(kernel, particles, particle_info, loss_fn, mode, kval):
def test_kernel_forward(name, kernel, particle_info, loss_fn, mode, kval):
particles = PARTICLES
if mode not in kval:
pytest.skip()
(d,) = particles[0].shape
@@ -162,11 +192,11 @@ def test_kernel_forward(kernel, particles, particle_info, loss_fn, mode, kval):


@pytest.mark.parametrize(
"kernel, particle_info, loss_fn, kval", TEST_CASES, ids=TEST_IDS
"name, kernel, particle_info, loss_fn, kval", TEST_CASES, ids=TEST_IDS
)
@pytest.mark.parametrize("mode", ["norm", "vector", "matrix"])
@pytest.mark.parametrize("particles", [PARTICLES])
def test_apply_kernel(kernel, particles, particle_info, loss_fn, mode, kval):
def test_apply_kernel(name, kernel, particle_info, loss_fn, mode, kval):
particles = PARTICLES
if mode not in kval:
pytest.skip()
(d,) = particles[0].shape
