-
-
Notifications
You must be signed in to change notification settings - Fork 985
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add tutorials using normalizing flows (#3302)
* Update normalizing flow intro * Add VAE with normalizing flow tutorial * Add SVI with normalizing flow tutorial * Move Zuko2Pyro to contrib.zuko * Drop unmaintained disclaimer * Add Zuko2Pyro test * Fix linting * Sort import block * Shorten comment * Address PR comments * Fix doctests * Address PR comments * Fix dummy * Fix weird linting issue * Fix dummy (I hope)
- Loading branch information
1 parent
670e9cb
commit 4f17274
Showing
8 changed files
with
766 additions
and
63 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
Zuko in Pyro | ||
============ | ||
|
||
.. automodule:: pyro.contrib.zuko | ||
:members: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
# Copyright Contributors to the Pyro project. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
""" | ||
This file contains helpers to use `Zuko <https://zuko.readthedocs.io/>`_-based | ||
normalizing flows within Pyro piplines. | ||
Accompanying tutorials can be found at `tutorial/svi_flow_guide.ipynb` and | ||
`tutorial/vae_flow_prior.ipynb`. | ||
""" | ||
|
||
import torch | ||
from torch import Size, Tensor | ||
|
||
import pyro | ||
|
||
|
||
class ZukoToPyro(pyro.distributions.TorchDistribution): | ||
r"""Wraps a Zuko distribution as a Pyro distribution. | ||
If ``dist`` has an ``rsample_and_log_prob`` method, like Zuko's flows, it will be | ||
used when sampling instead of ``rsample``. The returned log density will be cached | ||
for later scoring. | ||
:param dist: A distribution instance. | ||
:type dist: torch.distributions.Distribution | ||
.. code-block:: python | ||
flow = zuko.flows.MAF(features=5) | ||
# flow() is a torch.distributions.Distribution | ||
dist = flow() | ||
x = dist.sample((2, 3)) | ||
log_p = dist.log_prob(x) | ||
# ZukoToPyro(flow()) is a pyro.distributions.Distribution | ||
dist = ZukoToPyro(flow()) | ||
x = dist((2, 3)) | ||
log_p = dist.log_prob(x) | ||
with pyro.plate("data", 42): | ||
z = pyro.sample("z", dist) | ||
""" | ||
|
||
def __init__(self, dist: torch.distributions.Distribution): | ||
self.dist = dist | ||
self.cache = {} | ||
|
||
@property | ||
def has_rsample(self) -> bool: | ||
return self.dist.has_rsample | ||
|
||
@property | ||
def event_shape(self) -> Size: | ||
return self.dist.event_shape | ||
|
||
@property | ||
def batch_shape(self) -> Size: | ||
return self.dist.batch_shape | ||
|
||
def __call__(self, shape: Size = ()) -> Tensor: | ||
if hasattr(self.dist, "rsample_and_log_prob"): # fast sampling + scoring | ||
x, self.cache[x] = self.dist.rsample_and_log_prob(shape) | ||
elif self.has_rsample: | ||
x = self.dist.rsample(shape) | ||
else: | ||
x = self.dist.sample(shape) | ||
|
||
return x | ||
|
||
def log_prob(self, x: Tensor) -> Tensor: | ||
if x in self.cache: | ||
return self.cache[x] | ||
else: | ||
return self.dist.log_prob(x) | ||
|
||
def expand(self, *args, **kwargs): | ||
return ZukoToPyro(self.dist.expand(*args, **kwargs)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
# Copyright Contributors to the Pyro project. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
|
||
import pytest | ||
import torch | ||
|
||
import pyro | ||
from pyro.contrib.zuko import ZukoToPyro | ||
from pyro.infer import SVI, Trace_ELBO | ||
from pyro.optim import Adam | ||
|
||
|
||
@pytest.mark.parametrize("multivariate", [True, False]) | ||
@pytest.mark.parametrize("rsample_and_log_prob", [True, False]) | ||
def test_ZukoToPyro(multivariate: bool, rsample_and_log_prob: bool): | ||
# Distribution | ||
if multivariate: | ||
normal = torch.distributions.MultivariateNormal | ||
mu = torch.zeros(3) | ||
sigma = torch.eye(3) | ||
else: | ||
normal = torch.distributions.Normal | ||
mu = torch.zeros(()) | ||
sigma = torch.ones(()) | ||
|
||
dist = normal(mu, sigma) | ||
|
||
if rsample_and_log_prob: | ||
|
||
def dummy(self, shape): | ||
x = self.rsample(shape) | ||
return x, self.log_prob(x) | ||
|
||
dist.rsample_and_log_prob = dummy.__get__(dist) | ||
|
||
# Sample | ||
x1 = pyro.sample("x1", ZukoToPyro(dist)) | ||
|
||
assert x1.shape == dist.event_shape | ||
|
||
# Sample within plate | ||
with pyro.plate("data", 4): | ||
x2 = pyro.sample("x2", ZukoToPyro(dist)) | ||
|
||
assert x2.shape == (4, *dist.event_shape) | ||
|
||
# SVI | ||
def model(): | ||
pyro.sample("a", ZukoToPyro(dist)) | ||
|
||
with pyro.plate("data", 4): | ||
pyro.sample("b", ZukoToPyro(dist)) | ||
|
||
def guide(): | ||
mu_ = pyro.param("mu", mu) | ||
sigma_ = pyro.param("sigma", sigma) | ||
|
||
pyro.sample("a", ZukoToPyro(normal(mu_, sigma_))) | ||
|
||
with pyro.plate("data", 4): | ||
pyro.sample("b", ZukoToPyro(normal(mu_, sigma_))) | ||
|
||
svi = SVI(model, guide, optim=Adam({"lr": 1e-3}), loss=Trace_ELBO()) | ||
svi.step() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.