SteinVI: Recompute Score Function For Each Particle Interaction in the Attractive Force (#1947)

* updated attractive force computation and added comments

* removed comment

* formatting with ruff 0.9 (up from 0.7.1)
OlaRonning authored Jan 12, 2025
1 parent d1ca868 commit 986da95
Showing 26 changed files with 179 additions and 195 deletions.
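
Before the per-file diff, the gist of the change as a minimal self-contained JAX sketch (the names neg_loss, rbf, attractive_old, and attractive_new are hypothetical stand-ins, not NumPyro's SteinLoss or RBFKernel): previously the score-function gradient was computed once per particle and reused across all kernel interactions; after this commit it is recomputed, with a fresh PRNG key, for every particle pair inside a lax.scan. With a deterministic loss the two variants agree; with a stochastic Stein loss they differ.

import jax.numpy as jnp
from jax import grad, random, vmap
from jax.lax import scan

# Toy stand-ins (hypothetical): in SteinVI the per-particle loss is
# stochastic, i.e. it depends on a PRNG key and a particle index.
def neg_loss(key, x, i):
    noise = random.normal(key)  # makes the loss stochastic, like the Stein loss
    return -0.5 * jnp.sum((x - 0.1 * noise) ** 2)

def rbf(x, y):
    return jnp.exp(-jnp.sum((x - y) ** 2))

def attractive_old(key, particles):
    # Old behavior: one score-function gradient per particle, computed
    # once and reused for every interaction pair (x, y).
    n = particles.shape[0]
    keys = random.split(key, n)
    grads = vmap(grad(neg_loss, argnums=1))(keys, particles, jnp.arange(n))
    return vmap(
        lambda y: jnp.sum(vmap(lambda x, g: rbf(x, y) * g)(particles, grads), axis=0)
    )(particles)

def attractive_new(key, particles):
    # New behavior: the gradient is recomputed inside a scan, with a fresh
    # key, once per (x, y) interaction, mirroring compute_attr_force below.
    n = particles.shape[0]

    def body(acc, state, y):
        k, x, i = state
        return acc + rbf(x, y) * grad(neg_loss, argnums=1)(k, x, i), None

    force, _ = vmap(
        lambda y, k: scan(
            lambda acc, s: body(acc, s, y),
            jnp.zeros_like(particles[0]),
            (random.split(k, n), particles, jnp.arange(n)),
        )
    )(particles, random.split(key, n))
    return force

particles = random.normal(random.PRNGKey(0), (4, 2))
print(attractive_old(random.PRNGKey(1), particles))
print(attractive_new(random.PRNGKey(1), particles))

Note the new variant performs n gradient evaluations per particle (n² in total) instead of n, but each interaction now sees an independent estimate of the score function.
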
6 changes: 3 additions & 3 deletions examples/hmcecs.py
@@ -75,9 +75,9 @@ def run_hmc(mcmc_key, args, data, obs, kernel):


def main(args):
assert (
11_000_000 >= args.num_datapoints
), "11,000,000 data points in the Higgs dataset"
assert 11_000_000 >= args.num_datapoints, (
"11,000,000 data points in the Higgs dataset"
)
# full dataset takes hours for plain hmc!
if args.dataset == "higgs":
_, fetch = load_dataset(
6 changes: 3 additions & 3 deletions examples/ssbvm_mixture.py
@@ -293,7 +293,7 @@ def main(args):
parser.add_argument("--device", default="gpu", type=str, help='use "cpu" or "gpu".')

args = parser.parse_args()
assert all(
aa in AMINO_ACIDS for aa in args.amino_acids
), f"{list(filter(lambda aa: aa not in AMINO_ACIDS, args.amino_acids))} are not amino acids."
assert all(aa in AMINO_ACIDS for aa in args.amino_acids), (
f"{list(filter(lambda aa: aa not in AMINO_ACIDS, args.amino_acids))} are not amino acids."
)
main(args)
5 changes: 2 additions & 3 deletions examples/stein_bnn.py
@@ -25,7 +25,6 @@
from numpyro.contrib.einstein import MixtureGuidePredictive, RBFKernel, SteinVI
from numpyro.distributions import Gamma, Normal
from numpyro.examples.datasets import BOSTON_HOUSING, load_dataset
from numpyro.infer import init_to_uniform
from numpyro.infer.autoguide import AutoNormal
from numpyro.optim import Adagrad

@@ -121,12 +120,12 @@ def main(args):
rng_key, inf_key = random.split(inf_key)

# We find that SteinVI benefits from a small radius when inferring BNNs.
guide = AutoNormal(model, init_loc_fn=partial(init_to_uniform, radius=0.1))
guide = AutoNormal(model)

stein = SteinVI(
model,
guide,
Adagrad(0.5),
Adagrad(1.0),
RBFKernel(),
repulsion_temperature=args.repulsion,
num_stein_particles=args.num_stein_particles,
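
For context, the updated example wires SteinVI together roughly as follows — a minimal sketch assembled from the imports and call visible in this diff; the toy model and data are hypothetical stand-ins for the example's BNN:

from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro.contrib.einstein import RBFKernel, SteinVI
from numpyro.infer.autoguide import AutoNormal
from numpyro.optim import Adagrad


def model(x):
    # Hypothetical toy model standing in for the example's BNN.
    loc = numpyro.sample("loc", dist.Normal(0.0, 1.0))
    with numpyro.plate("data", x.shape[0]):
        numpyro.sample("obs", dist.Normal(loc, 1.0), obs=x)


guide = AutoNormal(model)  # default init_loc_fn; the explicit small radius is gone

stein = SteinVI(
    model,
    guide,
    Adagrad(1.0),  # step size raised from 0.5 in this commit
    RBFKernel(),
    num_stein_particles=5,
)

data = random.normal(random.PRNGKey(1), (100,))
result = stein.run(random.PRNGKey(0), 200, data)
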
10 changes: 5 additions & 5 deletions notebooks/source/lotka_volterra_multiple.ipynb
@@ -404,10 +404,10 @@
"source": [
"print(f\"The dataset has the shape {data.shape}, (n_datasets, n_points, n_observables)\")\n",
"print(f\"The time matrix has the shape {ts.shape}, (n_datasets, n_timepoints)\")\n",
"print(f\"The time matrix has different spacing between timepoints: \\n {ts[:,:5]}\")\n",
"print(f\"The final timepoints are: {jnp.nanmax(ts,1)} years.\")\n",
"print(f\"The time matrix has different spacing between timepoints: \\n {ts[:, :5]}\")\n",
"print(f\"The final timepoints are: {jnp.nanmax(ts, 1)} years.\")\n",
"print(\n",
" f\"The dataset has {jnp.sum(jnp.isnan(data))/jnp.size(data):.0%} missing observations\"\n",
" f\"The dataset has {jnp.sum(jnp.isnan(data)) / jnp.size(data):.0%} missing observations\"\n",
")\n",
"print(f\"True params mean: {sample['theta'][0]}\")"
]
@@ -550,7 +550,7 @@
"mcmc.print_summary()\n",
"\n",
"print(f\"True params mean: {sample['theta'][0]}\")\n",
"print(f\"Estimated params mean: {jnp.mean(mcmc.get_samples()['theta'], axis = 0)}\")"
"print(f\"Estimated params mean: {jnp.mean(mcmc.get_samples()['theta'], axis=0)}\")"
]
},
{
@@ -591,7 +591,7 @@
"\n",
"\n",
"print(f\"True params mean: {sample['theta'][0]}\")\n",
"print(f\"Estimated params mean: {jnp.mean(mcmc.get_samples()['theta'], axis = 0)}\")"
"print(f\"Estimated params mean: {jnp.mean(mcmc.get_samples()['theta'], axis=0)}\")"
]
},
{
2 changes: 1 addition & 1 deletion notebooks/source/ordinal_regression.ipynb
@@ -104,7 +104,7 @@
"print(df.Y.value_counts())\n",
"\n",
"for i in range(nclasses):\n",
" print(f\"mean(X) for Y == {i}: {X[np.where(Y==i)].mean():.3f}\")"
" print(f\"mean(X) for Y == {i}: {X[np.where(Y == i)].mean():.3f}\")"
]
},
{
2 changes: 1 addition & 1 deletion numpyro/contrib/control_flow/scan.py
@@ -82,7 +82,7 @@ def _subs_wrapper(subs_map, i, length, site):
)
else:
raise RuntimeError(
f"Something goes wrong. Expected ndim = {fn_ndim} or {fn_ndim+1},"
f"Something goes wrong. Expected ndim = {fn_ndim} or {fn_ndim + 1},"
f" but got {value_ndim}. This might happen when you use nested scan,"
" which is currently not supported. Please report the issue to us!"
)
6 changes: 2 additions & 4 deletions numpyro/contrib/ecs_proxies.py
@@ -12,13 +12,11 @@

TaylorTwoProxyState = namedtuple(
"TaylorProxyState",
"ref_subsample_log_liks,"
"ref_subsample_log_lik_grads,"
"ref_subsample_log_lik_hessians",
"ref_subsample_log_liks,ref_subsample_log_lik_grads,ref_subsample_log_lik_hessians",
)

TaylorOneProxyState = namedtuple(
"TaylorOneProxyState", "ref_subsample_log_liks," "ref_subsample_log_lik_grads,"
"TaylorOneProxyState", "ref_subsample_log_liks,ref_subsample_log_lik_grads,"
)


140 changes: 66 additions & 74 deletions numpyro/contrib/einstein/steinvi.py
@@ -10,6 +10,7 @@

from jax import grad, numpy as jnp, random, tree, vmap
from jax.flatten_util import ravel_pytree
from jax.lax import scan

from numpyro import handlers
from numpyro.contrib.einstein.stein_loss import SteinLoss
@@ -33,7 +34,6 @@ def _numel(shape):
class SteinVI:
"""Variational inference with Stein mixtures inference.
**Example:**
.. doctest::
@@ -138,9 +138,9 @@ def __init__(
if isinstance(guide.init_loc_fn, partial):
init_fn_name = guide.init_loc_fn.func.__name__
if init_fn_name == "init_to_uniform":
assert (
guide.init_loc_fn.keywords.get("radius", None) != 0.0
), init_loc_error_message
assert guide.init_loc_fn.keywords.get("radius", None) != 0.0, (
init_loc_error_message
)
else:
init_fn_name = guide.init_loc_fn.__name__
assert init_fn_name not in [
@@ -230,104 +230,84 @@ def local_trace(key):
return vmap(local_trace)(random.split(particle_seed, self.num_stein_particles))

def _svgd_loss_and_grads(self, rng_key, unconstr_params, *args, **kwargs):
# 0. Separate model and guide parameters, since only guide parameters are updated using Stein
non_mixture_uparams = { # Includes any marked guide parameters and all model parameters
# Separate model and guide parameters, since only guide parameters are updated using Stein
# Split parameters into model and guide components - only unflagged guide parameters are
# optimized via Stein forces.
nonmix_uparams = { # Includes any marked guide parameters and all model parameters
p: v
for p, v in unconstr_params.items()
if p not in self.guide_sites or self.non_mixture_params_fn(p)
}

stein_uparams = {
p: v for p, v in unconstr_params.items() if p not in non_mixture_uparams
p: v for p, v in unconstr_params.items() if p not in nonmix_uparams
}

# 1. Collect each guide parameter into monolithic particles that capture correlations
# between parameter values across each individual particle
# Collect guide parameters into a monolithic particle.
stein_particles, unravel_pytree, unravel_pytree_batched = batch_ravel_pytree(
stein_uparams, nbatch_dims=1
)

# Kernel behavior varies based on particle site locations. The particle_info dictionary
# maps site names to their corresponding dimensional ranges as (start, end) tuples.
particle_info, _ = self._calc_particle_info(
stein_uparams, stein_particles.shape[0]
)
attractive_key, classic_key = random.split(rng_key)

def particle_transform_fn(particle):
params = unravel_pytree(particle)
ctparams = self.constrain_fn(self.particle_transform_fn(params))
ctparticle, _ = ravel_pytree(ctparams)
return ctparticle

# 2. Calculate gradients for each particle
def kernel_particles_loss_fn(rng_key, particles):
particle_keys = random.split(rng_key, self.stein_loss.stein_num_particles)
grads = vmap(
lambda i: grad(
lambda particle: self.stein_loss.particle_loss(
rng_key=particle_keys[i],
model=handlers.scale(
self._inference_model, self.loss_temperature
),
guide=self.guide,
selected_particle=self.constrain_fn(unravel_pytree(particle)),
unravel_pytree=unravel_pytree,
flat_particles=vmap(particle_transform_fn)(particles),
select_index=i,
model_args=args,
model_kwargs=kwargs,
param_map=self.constrain_fn(non_mixture_uparams),
)
)(particles[i])
)(jnp.arange(self.stein_loss.stein_num_particles))

return grads

# 2.1 Compute particle gradients (for attractive force)
particle_ljp_grads = kernel_particles_loss_fn(attractive_key, stein_particles)

# 2.3 Lift particles to constraint space
ctstein_particles = vmap(particle_transform_fn)(stein_particles)

# 2.4 Compute non-mixture parameter gradients
non_mixture_param_grads = grad(
lambda cps: -self.stein_loss.loss(
classic_key,
self.constrain_fn(cps),
handlers.scale(self._inference_model, self.loss_temperature),
self.guide,
unravel_pytree_batched(ctstein_particles),
*args,
**kwargs,
)
)(non_mixture_uparams)
model = handlers.scale(self._inference_model, self.loss_temperature)

# 3. Calculate kernel of particles
def loss_fn(particle, i):
def stein_loss_fn(key, particle, particle_idx):
return self.stein_loss.particle_loss(
rng_key=rng_key,
model=handlers.scale(self._inference_model, self.loss_temperature),
rng_key=key,
model=model,
guide=self.guide,
# Stein particles evolve in unconstrained space, but gradient computations must account
# for the transformation to constrained space
selected_particle=self.constrain_fn(unravel_pytree(particle)),
unravel_pytree=unravel_pytree,
flat_particles=ctstein_particles,
select_index=i,
flat_particles=vmap(particle_transform_fn)(stein_particles),
select_index=particle_idx,
model_args=args,
model_kwargs=kwargs,
param_map=self.constrain_fn(non_mixture_uparams),
param_map=self.constrain_fn(nonmix_uparams),
)

kernel = self.kernel_fn.compute(
rng_key, stein_particles, particle_info, loss_fn
rng_key, stein_particles, particle_info, stein_loss_fn
)
attractive_key, classic_key = random.split(rng_key)

# 4. Calculate the attractive force and repulsive force on the particles
attractive_force = vmap(
lambda y: jnp.sum(
vmap(
lambda x, x_ljp_grad: self._apply_kernel(kernel, x, y, x_ljp_grad)
)(stein_particles, particle_ljp_grads),
axis=0,
)
)(stein_particles)
def compute_attr_force(rng_key, particles):
# Second term of eq. 9 from https://arxiv.org/pdf/2410.22948.
def body(attr_force, state, y):
key, x, i = state
x_grad = grad(stein_loss_fn, argnums=1)(key, x, i)
attr_force = attr_force + self._apply_kernel(kernel, x, y, x_grad)
return attr_force, None

particle_keys = random.split(rng_key, self.stein_loss.stein_num_particles)
init = jnp.zeros_like(particles[0])
idxs = jnp.arange(self.num_stein_particles)

attr_force, _ = vmap(
lambda y, key: scan(
partial(body, y=y),
init,
(random.split(key, self.num_stein_particles), particles, idxs),
)
)(particles, particle_keys)

return attr_force

attractive_force = compute_attr_force(attractive_key, stein_particles)

# Third term of eq. 9 from https://arxiv.org/pdf/2410.22948.
repulsive_force = vmap(
lambda y: jnp.mean(
vmap(
@@ -338,15 +318,27 @@ def loss_fn(particle, i):
)
)(stein_particles)

# 6. Compute the stein force
particle_grads = attractive_force + repulsive_force

# 7. Decompose the monolithic particle forces back to concrete parameter values
stein_param_grads = unravel_pytree_batched(particle_grads)
# Compute non-mixture parameter gradients.
nonmix_uparam_grads = grad(
lambda cps: -self.stein_loss.loss(
classic_key,
self.constrain_fn(cps),
model,
self.guide,
unravel_pytree_batched(vmap(particle_transform_fn)(stein_particles)),
*args,
**kwargs,
)
)(nonmix_uparams)

# Decompose the monolithic particle forces back to concrete parameter values.
stein_uparam_grads = unravel_pytree_batched(particle_grads)

# 8. Return loss and gradients (based on parameter forces)
# Return loss and gradients (based on parameter forces).
res_grads = tree.map(
lambda x: -x, {**non_mixture_param_grads, **stein_param_grads}
lambda x: -x, {**nonmix_uparam_grads, **stein_uparam_grads}
)
return jnp.linalg.norm(particle_grads), res_grads

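
The comments in this hunk cite eq. 9 of https://arxiv.org/pdf/2410.22948 for the attractive and repulsive terms. For orientation only, here is the classic SVGD direction (Liu & Wang, 2016) that the Stein-mixture force generalizes — the standard formula, not a transcription of eq. 9:

\phi^*(y) = \frac{1}{N} \sum_{i=1}^{N} \Big[ \underbrace{k(x_i, y)\, \nabla_{x_i} \log p(x_i)}_{\text{attractive}} + \underbrace{\nabla_{x_i} k(x_i, y)}_{\text{repulsive}} \Big]

In the Stein-mixture setting the score \nabla_x \log p(x) is only available through a stochastic particle loss, so compute_attr_force above re-evaluates it with a fresh PRNG key for every (x_i, y) interaction instead of reusing one estimate per x_i — the behavior named in the commit title.
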
6 changes: 3 additions & 3 deletions numpyro/contrib/funsor/enum_messenger.py
@@ -436,9 +436,9 @@ class BaseEnumMessenger(NamedMessenger):
"""

def __init__(self, fn=None, first_available_dim=None):
assert (
first_available_dim is None or first_available_dim < 0
), first_available_dim
assert first_available_dim is None or first_available_dim < 0, (
first_available_dim
)
self.first_available_dim = first_available_dim
super().__init__(fn)

6 changes: 3 additions & 3 deletions numpyro/contrib/tfp/mcmc.py
@@ -56,9 +56,9 @@ def log_prob_fn(x):
class _TFPKernelMeta(ABCMeta):
def __getitem__(cls, kernel_class):
assert issubclass(kernel_class, tfp.mcmc.TransitionKernel)
assert (
"target_log_prob_fn" in inspect.getfullargspec(kernel_class).args
), f"the first argument of {kernel_class} must be `target_log_prob_fn`"
assert "target_log_prob_fn" in inspect.getfullargspec(kernel_class).args, (
f"the first argument of {kernel_class} must be `target_log_prob_fn`"
)

_PyroKernel = type(kernel_class.__name__, (TFPKernel,), {})
_PyroKernel.kernel_class = kernel_class