Using PyPesto with jax #1428
I think you will need to be a bit more specific about what you mean by "not working".
I get the following error running the above code:
Traceback (most recent call last):
...
Do you get the same error with standard optimisers in pypesto? My guess is that this is about serialisation of the model, where sacess relies on ...
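A quick, hypothetical way to check whether serialisation is the culprit (this is not pypesto API; pickle is what multiprocessing uses, and cloudpickle is included only for comparison): try to serialise a jitted JAX function directly.

import pickle

import cloudpickle
import jax
import jax.numpy as jnp

@jax.jit
def toy_objective(x):
    return jnp.sum(x ** 2)

# If plain pickle fails but cloudpickle succeeds, the failure is likely about
# serialising the (jitted) closure, not about the optimiser itself.
for name, dumps in [('pickle', pickle.dumps), ('cloudpickle', cloudpickle.dumps)]:
    try:
        dumps(toy_objective)
        print(f'{name}: OK')
    except Exception as exc:
        print(f'{name}: failed with {type(exc).__name__}: {exc}')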
Thanks for getting back so fast! :)
It now runs through, but fails badly in the line search and doesn't get away from the initial guesses:
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:02<00:00, 8.25it/s]
Optimization Result
A summary of the best run: Optimizer Result
...
Hi @MaAl13, when using multiprocessing (directly or indirectly, as here through the sacess optimizer), the code that starts the optimization needs to be guarded by if __name__ == '__main__':. I.e., in your case:

import numpy as np
import jax
import time
import jax.numpy as jnp
import diffrax as dfx
import equinox as eqx
import pypesto
import pypesto.optimize as optimize
from pypesto.optimize import ScipyOptimizer
import multiprocessing


# Lotka-Volterra model
def vector_field(t, y, args):
    prey, predator = y
    α, β, γ, δ = args
    d_prey = α * prey - β * prey * predator
    d_predator = -γ * predator + δ * prey * predator
    return jnp.stack([d_prey, d_predator])


def solve(parameters, y0, ts):
    term = dfx.ODETerm(vector_field)
    solver = dfx.Tsit5()
    saveat = dfx.SaveAt(ts=ts)
    sol = dfx.diffeqsolve(
        term, solver, t0=ts[0], t1=ts[-1], dt0=0.1, y0=y0, args=parameters,
        saveat=saveat, adjoint=dfx.RecursiveCheckpointAdjoint(),
    )
    return sol.ys


# Generate synthetic data
def get_data():
    y0 = jnp.array([9.0, 9.0])
    true_parameters = jnp.array([0.1, 0.02, 0.4, 0.02])
    ts = jnp.linspace(0, 30, 20)
    values = solve(true_parameters, y0, ts)
    return y0, ts, values + 0.1 * jax.random.normal(jax.random.PRNGKey(0), values.shape)


def main():
    y0, ts, noisy_values = get_data()

    # Define objective function
    @jax.jit
    def objective(parameters):
        pred_values = solve(parameters, y0, ts)
        return jnp.sum((noisy_values - pred_values) ** 2)

    # objective_with_grad = jax.value_and_grad(objective)
    objective = pypesto.Objective(
        fun=objective,
        grad=jax.grad(objective),
    )

    problem1 = pypesto.Problem(objective=objective, lb=np.zeros((4, 1)), ub=np.ones((4, 1)) * 10)
    default_ess_options = pypesto.optimize.get_default_ess_options(8, 4, local_optimizer=ScipyOptimizer(method='trust-constr'))
    optimizer = pypesto.optimize.SacessOptimizer(ess_init_args=default_ess_options, max_walltime_s=600)
    result_custom_problem = optimizer.minimize(problem=problem1)


if __name__ == '__main__':
    main()

Alternatively, ...
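As a general note on the guard (a generic Python detail, not pypesto-specific): with the 'spawn' start method, worker processes re-import the main module, so any unguarded top-level code would run again in every worker. A minimal, self-contained sketch of the pattern, independent of the example above:

import multiprocessing

def square(x):
    return x * x

def main():
    # Creating the pool only inside main() keeps spawned workers from
    # re-executing this code when they re-import the module.
    with multiprocessing.Pool(processes=2) as pool:
        print(pool.map(square, range(4)))

if __name__ == '__main__':
    main()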
Hi @dweindl, I tried running your code but still get the following error:
Traceback (most recent call last):
...
That's now the problem @FFroehlich was referring to. I am not sufficiently familiar with JAX to help there.
Hard to guess what the issue here is. This looks like incorrect gradients, but it likely goes beyond what we can help with in the context of an issue on GitHub.
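One generic way to narrow this down (a sketch, not something prescribed in this thread): compare the JAX gradient against a central finite-difference approximation at a few points inside the bounds. If I remember correctly, pypesto objectives also offer a check_grad helper for the same purpose.

import jax
import jax.numpy as jnp
import numpy as np

def finite_difference_grad(fun, x, eps=1e-5):
    # Central finite differences, one parameter at a time.
    x = np.asarray(x, dtype=float)
    grad = np.zeros_like(x)
    for i in range(x.size):
        step = np.zeros_like(x)
        step[i] = eps
        grad[i] = (float(fun(jnp.asarray(x + step)))
                   - float(fun(jnp.asarray(x - step)))) / (2 * eps)
    return grad

# 'objective' stands for the jitted scalar loss from the example above
# (an assumption; substitute your own function):
# x_test = np.random.uniform(0.1, 1.0, size=4)
# print(np.asarray(jax.grad(objective)(jnp.asarray(x_test))))
# print(finite_difference_grad(objective, x_test))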
Correct, ...
Okay, so you would then not recommend using scatter search when doing parameter estimation of ODEs? Is there anything else you can recommend in pypesto that is compatible with JAX and has nice global properties? Since it is easy to get the gradient with diffrax, I think it would be a shame not to use it.
I wouldn’t go that far. In both cases, the issues you encountered should be salvageable. However, since they aren’t ‘bugs’ per se and will require some effort to resolve, providing ready-made solutions is beyond the support we can offer. That said, we always welcome contributions and are happy to provide guidance.
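For illustration, a sketch of plain multi-start local optimization with the JAX-provided gradient, assuming the problem1 defined in the code above; this is only one commonly used option, not a specific recommendation made in this thread:

import pypesto.optimize as optimize

# problem1 is assumed to be the pypesto.Problem built earlier in this thread.
local_optimizer = optimize.ScipyOptimizer(method='L-BFGS-B')
result = optimize.minimize(problem=problem1, optimizer=local_optimizer, n_starts=50)
print(result.optimize_result.summary())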
Hello, I want to use your package to do parameter estimation of ODEs and later on compute confidence intervals with profile likelihoods. However, the following code is not working for me on a toy example. Can you maybe tell me what I am doing wrong? I want to use scatter search, since it has been shown to have better convergence properties than purely local or global methods.
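For the profile-likelihood part mentioned here, a minimal sketch, assuming problem1, result, and local_optimizer from the snippets above:

import pypesto.profile as profile
import pypesto.visualize as visualize

# Compute likelihood profiles for the parameters, starting from the best fit.
profile_result = profile.parameter_profile(
    problem=problem1, result=result, optimizer=local_optimizer,
)
# Plot the profiles to read off confidence intervals.
visualize.profiles(profile_result)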