Skip to content

Commit

Permalink
Add tutorials using normalizing flows (#3302)
Browse files Browse the repository at this point in the history
* Update normalizing flow intro

* Add VAE with normalizing flow tutorial

* Add SVI with normalizing flow tutorial

* Move Zuko2Pyro to contrib.zuko

* Drop unmaintained disclaimer

* Add Zuko2Pyro test

* Fix linting

* Sort import block

* Shorten comment

* Address PR comments

* Fix doctests

* Address PR comments

* Fix dummy

* Fix weird linting issue

* Fix dummy (I hope)
  • Loading branch information
francois-rozet authored Jan 14, 2024
1 parent 670e9cb commit 4f17274
Show file tree
Hide file tree
Showing 8 changed files with 766 additions and 63 deletions.
5 changes: 5 additions & 0 deletions docs/source/contrib.zuko.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Zuko in Pyro
============

.. automodule:: pyro.contrib.zuko
:members:
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ Pyro Documentation
contrib.randomvariable
contrib.timeseries
contrib.tracking
contrib.zuko


Indices and tables
Expand Down
81 changes: 81 additions & 0 deletions pyro/contrib/zuko.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

"""
This file contains helpers to use `Zuko <https://zuko.readthedocs.io/>`_-based
normalizing flows within Pyro piplines.
Accompanying tutorials can be found at `tutorial/svi_flow_guide.ipynb` and
`tutorial/vae_flow_prior.ipynb`.
"""

import torch
from torch import Size, Tensor

import pyro


class ZukoToPyro(pyro.distributions.TorchDistribution):
r"""Wraps a Zuko distribution as a Pyro distribution.
If ``dist`` has an ``rsample_and_log_prob`` method, like Zuko's flows, it will be
used when sampling instead of ``rsample``. The returned log density will be cached
for later scoring.
:param dist: A distribution instance.
:type dist: torch.distributions.Distribution
.. code-block:: python
flow = zuko.flows.MAF(features=5)
# flow() is a torch.distributions.Distribution
dist = flow()
x = dist.sample((2, 3))
log_p = dist.log_prob(x)
# ZukoToPyro(flow()) is a pyro.distributions.Distribution
dist = ZukoToPyro(flow())
x = dist((2, 3))
log_p = dist.log_prob(x)
with pyro.plate("data", 42):
z = pyro.sample("z", dist)
"""

def __init__(self, dist: torch.distributions.Distribution):
self.dist = dist
self.cache = {}

@property
def has_rsample(self) -> bool:
return self.dist.has_rsample

@property
def event_shape(self) -> Size:
return self.dist.event_shape

@property
def batch_shape(self) -> Size:
return self.dist.batch_shape

def __call__(self, shape: Size = ()) -> Tensor:
if hasattr(self.dist, "rsample_and_log_prob"): # fast sampling + scoring
x, self.cache[x] = self.dist.rsample_and_log_prob(shape)
elif self.has_rsample:
x = self.dist.rsample(shape)
else:
x = self.dist.sample(shape)

return x

def log_prob(self, x: Tensor) -> Tensor:
if x in self.cache:
return self.cache[x]
else:
return self.dist.log_prob(x)

def expand(self, *args, **kwargs):
return ZukoToPyro(self.dist.expand(*args, **kwargs))
65 changes: 65 additions & 0 deletions tests/contrib/test_zuko.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0


import pytest
import torch

import pyro
from pyro.contrib.zuko import ZukoToPyro
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam


@pytest.mark.parametrize("multivariate", [True, False])
@pytest.mark.parametrize("rsample_and_log_prob", [True, False])
def test_ZukoToPyro(multivariate: bool, rsample_and_log_prob: bool):
# Distribution
if multivariate:
normal = torch.distributions.MultivariateNormal
mu = torch.zeros(3)
sigma = torch.eye(3)
else:
normal = torch.distributions.Normal
mu = torch.zeros(())
sigma = torch.ones(())

dist = normal(mu, sigma)

if rsample_and_log_prob:

def dummy(self, shape):
x = self.rsample(shape)
return x, self.log_prob(x)

dist.rsample_and_log_prob = dummy.__get__(dist)

# Sample
x1 = pyro.sample("x1", ZukoToPyro(dist))

assert x1.shape == dist.event_shape

# Sample within plate
with pyro.plate("data", 4):
x2 = pyro.sample("x2", ZukoToPyro(dist))

assert x2.shape == (4, *dist.event_shape)

# SVI
def model():
pyro.sample("a", ZukoToPyro(dist))

with pyro.plate("data", 4):
pyro.sample("b", ZukoToPyro(dist))

def guide():
mu_ = pyro.param("mu", mu)
sigma_ = pyro.param("sigma", sigma)

pyro.sample("a", ZukoToPyro(normal(mu_, sigma_)))

with pyro.plate("data", 4):
pyro.sample("b", ZukoToPyro(normal(mu_, sigma_)))

svi = SVI(model, guide, optim=Adam({"lr": 1e-3}), loss=Trace_ELBO())
svi.step()
4 changes: 3 additions & 1 deletion tutorial/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ List of Tutorials
jit
svi_horovod
svi_lightning
svi_flow_guide

.. toctree::
:maxdepth: 1
Expand All @@ -106,7 +107,8 @@ List of Tutorials
vae
ss-vae
cvae
normalizing_flows_i
normalizing_flows_intro
vae_flow_prior
dmm
air
cevae
Expand Down
Loading

0 comments on commit 4f17274

Please sign in to comment.