Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fetch rng_key using prng_key message in block handler. #1957

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from

Conversation

tillahoffmann
Copy link
Contributor

This PR seeks to address #1657 by fetching a rng_key after setting msg["stop"] = True in the block handler. This approach ensures that the values obtained from sample statements are always the same as if no block had been added to the stack.

We can simplify a number of block(seed(model, rng_seed=dummy_seed), ...) to block(model, ...) statements for blocked auto guides because the rng_keys now propagate.

Caveat: This currently only works for block(hide=...) because block(expose=...) blocks everything but the specified site, including prng_key messages. Not sure this is the right solution, but maybe a starting point.

Illustration

import numpyro
from numpyro.handlers import block, seed
import subprocess
import traceback


print("commit hash", subprocess.check_output(["git", "rev-parse", "HEAD"], text=True))


def model(prefix):
    try:
        samples = (
            numpyro.sample("x", numpyro.distributions.Normal(0, 1)),
            numpyro.sample("y", numpyro.distributions.Normal(0, 1)),
            numpyro.sample("z", numpyro.distributions.Normal(0, 1)),
        )
        print(f"{prefix}:\n{', '.join(f'{x:.2f}' for x in samples)}")
    except Exception:
        print(f"{prefix}: ERROR\n{traceback.format_exc()}")
    print()


# This always works.
with seed(rng_seed=0):
    model("seed 0")

# Re-seeding using the parent `seed` handler leads to different samples because
# `numpyro.prng_key` is executed at the time of handler construction instead of
# during model execution.
with seed(rng_seed=0), block(hide=["y"]), seed(rng_seed=numpyro.prng_key()):
    model("seed 0, block y, seed with `prng_key`")

# This fails on `master`.
with seed(rng_seed=0), block(hide=["y"]):
    model("seed 0, block y")

Results on master Branch

commit hash 4704656886185fe3ba2df2f25af5910b5441255e

seed 0:
-1.25, -0.59, 0.49

seed 0, block y, seed with `prng_key`:
1.28, 2.13, -0.44

seed 0, block y: ERROR
Traceback (most recent call last):
  File "/Users/till/git/numpyro/playground/issue_1657.py", line 14, in model
    numpyro.sample("y", numpyro.distributions.Normal(0, 1)),
  File "/Users/till/git/numpyro/numpyro/primitives.py", line 250, in sample
    msg = apply_stack(initial_msg)
  File "/Users/till/git/numpyro/numpyro/primitives.py", line 61, in apply_stack
    default_process_message(msg)
  File "/Users/till/git/numpyro/numpyro/primitives.py", line 32, in default_process_message
    msg["value"], msg["intermediates"] = msg["fn"](
  File "/Users/till/git/numpyro/numpyro/distributions/distribution.py", line 393, in __call__
    return self.sample_with_intermediates(key, *args, **kwargs)
  File "/Users/till/git/numpyro/numpyro/distributions/distribution.py", line 351, in sample_with_intermediates
    return self.sample(key, sample_shape=sample_shape), []
  File "/Users/till/git/numpyro/numpyro/distributions/continuous.py", line 2190, in sample
    assert is_prng_key(key)
AssertionError

Results on seed-block Branch

commit hash 102e08e6a6a4d24bc1e15fd7179ab07827173bf9

seed 0:
-1.25, -0.59, 0.49

seed 0, block y, seed with `prng_key`:
1.28, 2.13, -0.44

seed 0, block y:
-1.25, -0.59, 0.49

Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Beautiful solution! Thanks, Till! Could you grep all block handlers and remove the corresponding seeds?

@@ -310,8 +310,17 @@ def __init__(
super(block, self).__init__(fn)

def process_message(self, msg: Message) -> None:
if self.hide_fn(msg):
msg["stop"] = True
if not self.hide_fn(msg):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe only hide when the msg is not prng_key?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants