Skip to content

Commit

Permalink
Merge pull request #123 from danielward27/adapt_interval_bug
Browse files Browse the repository at this point in the history
Adapt bisection interval bug fix
  • Loading branch information
danielward27 authored Dec 15, 2023
2 parents da1ba84 + 92f3458 commit ec870ee
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions flowjax/bisection_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,13 +190,14 @@ def cond_fn(state):

def body_fn(state):
sign = state.lower_fn_sign # Note we know the signs match from cond_fn

lower_update = jnp.where(sign == 1, state.lower - state.expand_by, state.upper)
upper_update = jnp.where(sign == 1, state.lower, state.upper + state.expand_by)
return _State(
lower=jnp.where(sign == 1, state.lower - state.expand_by, state.upper),
upper=jnp.where(sign == 1, state.lower, state.upper + state.expand_by),
lower=lower_update,
upper=upper_update,
expand_by=state.expand_by * expand_factor,
lower_fn_sign=jnp.sign(func(state.lower)),
upper_fn_sign=jnp.sign(func(state.upper)),
lower_fn_sign=jnp.sign(func(lower_update)),
upper_fn_sign=jnp.sign(func(upper_update)),
iteration=state.iteration + 1,
)

Expand Down

0 comments on commit ec870ee

Please sign in to comment.