Using PyPesto with jax #1428

Open
MaAl13 opened this issue Jul 11, 2024 · 11 comments
Labels
question Further information is requested

Comments

@MaAl13

MaAl13 commented Jul 11, 2024

Hello, I want to use your package for parameter estimation of ODEs and later compute confidence intervals with profile likelihood. However, the following code does not work for me on a toy example. Can you tell me what I am doing wrong? I want to use scatter search, since it has been shown to have better convergence properties than purely local or global methods.

import numpy as np
import jax
import time
import jax.numpy as jnp
import diffrax as dfx
import equinox as eqx
import pypesto
import pypesto.optimize as optimize
from pypesto.optimize import ScipyOptimizer
import multiprocessing

# Lotka-Volterra model
def vector_field(t, y, args):
    prey, predator = y
    α, β, γ, δ = args
    d_prey = α * prey - β * prey * predator
    d_predator = -γ * predator + δ * prey * predator
    return jnp.stack([d_prey, d_predator])

def solve(parameters, y0, ts):
    term = dfx.ODETerm(vector_field)
    solver = dfx.Tsit5()
    saveat = dfx.SaveAt(ts=ts)
    sol = dfx.diffeqsolve(
        term, solver, t0=ts[0], t1=ts[-1], dt0=0.1, y0=y0, args=parameters, saveat=saveat,
        adjoint=dfx.RecursiveCheckpointAdjoint(),
    )
    return sol.ys

# Generate synthetic data
def get_data():
    y0 = jnp.array([9.0, 9.0])
    true_parameters = jnp.array([0.1, 0.02, 0.4, 0.02])
    ts = jnp.linspace(0, 30, 20)
    values = solve(true_parameters, y0, ts)
    return y0, ts, values + 0.1 * jax.random.normal(jax.random.PRNGKey(0), values.shape)

y0, ts, noisy_values = get_data()

# Define objective function
@jax.jit
def objective(parameters):
    pred_values = solve(parameters, y0, ts)
    return jnp.sum((noisy_values - pred_values)**2)

#objective_with_grad = jax.value_and_grad(objective)

objective = pypesto.Objective(
    fun=objective,
    grad=jax.grad(objective)
)

problem1 = pypesto.Problem(objective=objective, lb=np.zeros((4, 1)), ub=np.ones((4, 1))*10)
default_ess_options = pypesto.optimize.get_default_ess_options(8, 4, local_optimizer=ScipyOptimizer(method='trust-constr'))
optimizer = pypesto.optimize.SacessOptimizer(ess_init_args = default_ess_options, max_walltime_s=600)
result_custom_problem = optimizer.minimize(problem=problem1)

@FFroehlich
Contributor

I think you will need to be a bit more specific about what you mean by "not working".

@MaAl13
Author

MaAl13 commented Jul 12, 2024

I get the following error when running the above code:

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/opt/anaconda3/envs/IVF_env/lib/python3.10/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/opt/anaconda3/envs/IVF_env/lib/python3.10/multiprocessing/spawn.py", line 125, in _main
    prepare(preparation_data)
  File "/opt/anaconda3/envs/IVF_env/lib/python3.10/multiprocessing/spawn.py", line 236, in prepare
    _fixup_main_from_path(data['init_main_from_path'])
  File "/opt/anaconda3/envs/IVF_env/lib/python3.10/multiprocessing/spawn.py", line 287, in _fixup_main_from_path
    main_content = runpy.run_path(main_path,
  File "/opt/anaconda3/envs/IVF_env/lib/python3.10/runpy.py", line 289, in run_path
    return _run_module_code(code, init_globals, run_name,
  File "/opt/anaconda3/envs/IVF_env/lib/python3.10/runpy.py", line 96, in _run_module_code
    _run_code(code, mod_globals, init_globals,
  File "/opt/anaconda3/envs/IVF_env/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/Users/malmansto/Documents/IVF_hormone_sim/code/test_ESS.py", line 56, in <module>
    result_custom_problem = optimizer.minimize(problem=problem1)
  File "/opt/anaconda3/envs/IVF_env/lib/python3.10/site-packages/pypesto/optimize/ess/sacess.py", line 210, in minimize
    with self.mp_ctx.Manager() as shmem_manager:
  File "/opt/anaconda3/envs/IVF_env/lib/python3.10/multiprocessing/context.py", line 57, in Manager
    m.start()
  File "/opt/anaconda3/envs/IVF_env/lib/python3.10/multiprocessing/managers.py", line 562, in start
    self._process.start()
  File "/opt/anaconda3/envs/IVF_env/lib/python3.10/multiprocessing/process.py", line 121, in start
    self._popen = self._Popen(self)
  File "/opt/anaconda3/envs/IVF_env/lib/python3.10/multiprocessing/context.py", line 288, in _Popen
    return Popen(process_obj)
  File "/opt/anaconda3/envs/IVF_env/lib/python3.10/multiprocessing/popen_spawn_posix.py", line 32, in __init__
    super().__init__(process_obj)
  File "/opt/anaconda3/envs/IVF_env/lib/python3.10/multiprocessing/popen_fork.py", line 19, in __init__
    self._launch(process_obj)
  File "/opt/anaconda3/envs/IVF_env/lib/python3.10/multiprocessing/popen_spawn_posix.py", line 42, in _launch
    prep_data = spawn.get_preparation_data(process_obj._name)
  File "/opt/anaconda3/envs/IVF_env/lib/python3.10/multiprocessing/spawn.py", line 154, in get_preparation_data
    _check_not_importing_main()
  File "/opt/anaconda3/envs/IVF_env/lib/python3.10/multiprocessing/spawn.py", line 134, in _check_not_importing_main
    raise RuntimeError('''
RuntimeError:
        An attempt has been made to start a new process before the
        current process has finished its bootstrapping phase.

        This probably means that you are not using fork to start your
        child processes and you have forgotten to use the proper idiom
        in the main module:

            if __name__ == '__main__':
                freeze_support()
                ...

        The "freeze_support()" line can be omitted if the program
        is not going to be frozen to produce an executable.

Traceback (most recent call last):
  File "/Users/malmansto/Documents/IVF_hormone_sim/code/test_ESS.py", line 56, in <module>
    result_custom_problem = optimizer.minimize(problem=problem1)
  File "/opt/anaconda3/envs/IVF_env/lib/python3.10/site-packages/pypesto/optimize/ess/sacess.py", line 210, in minimize
    with self.mp_ctx.Manager() as shmem_manager:
  File "/opt/anaconda3/envs/IVF_env/lib/python3.10/multiprocessing/context.py", line 57, in Manager
    m.start()
  File "/opt/anaconda3/envs/IVF_env/lib/python3.10/multiprocessing/managers.py", line 566, in start
    self._address = reader.recv()
  File "/opt/anaconda3/envs/IVF_env/lib/python3.10/multiprocessing/connection.py", line 250, in recv
    buf = self._recv_bytes()
  File "/opt/anaconda3/envs/IVF_env/lib/python3.10/multiprocessing/connection.py", line 414, in _recv_bytes
    buf = self._recv(4)
  File "/opt/anaconda3/envs/IVF_env/lib/python3.10/multiprocessing/connection.py", line 383, in _recv
    raise EOFError
EOFError

@FFroehlich
Contributor

Do you get the same error with the standard optimisers in pypesto? My guess is that this is about serialisation of the model: sacess relies on pickle/deepcopy, which doesn't play nicely with JAX.
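SacessOptimizer starts worker processes. The traceback above goes through `popen_spawn_posix`, i.e. the "spawn" start method, under which every object handed to a worker has to survive a pickle round trip. A pre-flight check along those lines can be sketched with the standard library alone. `check_picklable` is a hypothetical helper, not a pyPESTO function. The sketch reproduces the two patterns from this thread that break pickling. The first is a function defined inside another function. The second is a module-level function whose name is later rebound, as happens with `objective = pypesto.Objective(fun=objective, ...)`. Whether a `jax.jit`-wrapped function pickles on its own may depend on the JAX version, so checking the actual object given to the optimizer is the more reliable test.

```python
import pickle


def check_picklable(obj, label="object"):
    """Hypothetical pre-flight check: does `obj` survive a pickle round trip?

    The 'spawn' start method pickles everything handed to a worker process,
    so an object that fails here would also fail there.
    """
    try:
        pickle.loads(pickle.dumps(obj))
    except (pickle.PicklingError, AttributeError, TypeError) as err:
        print(f"{label} is not picklable: {err}")
        return False
    return True


def squared_norm(x):
    # Module-level function: pickled by reference, i.e. by its module and name.
    return sum(v * v for v in x)


def make_local_objective():
    def objective(x):
        # Local function ('make_local_objective.<locals>.objective'): pickle
        # cannot look it up by name, just like 'main.<locals>.objective'.
        return sum(v * v for v in x)
    return objective


def objective(x):
    return sum(v * v for v in x)


original_objective = objective
# Rebinding the name, as in `objective = pypesto.Objective(fun=objective, ...)`,
# means pickle finds a different object under 'objective' and refuses.
objective = {"wrapped": original_objective}

print(check_picklable(squared_norm, "squared_norm"))               # -> True
print(check_picklable(make_local_objective(), "local objective"))   # -> False
print(check_picklable(original_objective, "rebound objective"))     # -> False
```

In the script from this issue, calling something like `check_picklable(problem1, "problem1")` before `optimizer.minimize(...)` should surface the problem in the parent process, instead of deep inside the multiprocessing machinery.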

@MaAl13
Author

MaAl13 commented Jul 12, 2024

Thanks for getting back so fast! :)

import numpy as np
import jax
import time
import jax.numpy as jnp
import diffrax as dfx
import equinox as eqx
import pypesto
import pypesto.optimize as optimize
from pypesto.optimize import ScipyOptimizer
import multiprocessing

# Lotka-Volterra model
def vector_field(t, y, args):
    prey, predator = y
    α, β, γ, δ = args
    d_prey = α * prey - β * prey * predator
    d_predator = -γ * predator + δ * prey * predator
    return jnp.stack([d_prey, d_predator])

def solve(parameters, y0, ts):
    term = dfx.ODETerm(vector_field)
    solver = dfx.Tsit5()
    saveat = dfx.SaveAt(ts=ts)
    sol = dfx.diffeqsolve(
        term, solver, t0=ts[0], t1=ts[-1], dt0=0.1, y0=y0, args=parameters, saveat=saveat,
        adjoint=dfx.RecursiveCheckpointAdjoint(),
    )
    return sol.ys

# Generate synthetic data
def get_data():
    y0 = jnp.array([9.0, 9.0])
    true_parameters = jnp.array([0.1, 0.02, 0.4, 0.02])
    ts = jnp.linspace(0, 30, 20)
    values = solve(true_parameters, y0, ts)
    return y0, ts, values + 0.1 * jax.random.normal(jax.random.PRNGKey(0), values.shape)

y0, ts, noisy_values = get_data()

# Define objective function
@jax.jit
def objective(parameters):
    pred_values = solve(parameters, y0, ts)
    return jnp.sum((noisy_values - pred_values)**2)

#objective_with_grad = jax.value_and_grad(objective)

objective = pypesto.Objective(
    fun=objective,
    grad=jax.grad(objective)
)

problem1 = pypesto.Problem(objective=objective, lb=np.zeros((4, 1)), ub=np.ones((4, 1))*10)
optimizer = optimize.ScipyOptimizer()
engine = pypesto.engine.SingleCoreEngine()
n_starts = 20
result = optimize.minimize(
    problem=problem1, optimizer=optimizer, n_starts=n_starts, engine=engine
)

print(result.summary())

It now runs through, but the line search fails badly and the optimizer doesn't move away from the initial guesses:

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:02<00:00, 8.25it/s]

Optimization Result

  • number of starts: 20

  • best value: 9694.908203125, id=6

  • worst value: inf, id=19

  • number of non-finite values: 19

  • execution time summary:
    * Mean execution time: 0.121s
    * Maximum execution time: 1.605s, id=0
    * Minimum execution time: 0.012s, id=6

  • summary of optimizer messages:

    Count Message
    20 ABNORMAL_TERMINATION_IN_LNSRCH
  • best value found (approximately) 1 time(s)

  • number of plateaus found: 0

A summary of the best run:

Optimizer Result

  • optimizer used: <ScipyOptimizer method=L-BFGS-B options={'disp': False, 'maxfun': 1000}>
  • message: ABNORMAL_TERMINATION_IN_LNSRCH
  • number of evaluations: 2
  • time taken to optimize: 0.012s
  • startpoint: [2.46408138 2.725371 6.14559588 1.76943688]
  • endpoint: [2.46408138 2.725371 6.14559588 1.76943688]
  • final objective value: 9694.908203125
  • final gradient value: [ 8.141042e+07 7.118934e+08 9.747198e+07 -7.434768e+08]

@dweindl
Member

dweindl commented Jul 14, 2024

Hi @MaAl13, when using multiprocessing (directly or indirectly as here through SacessOptimizer), always protect your module-level code with if __name__ == '__main__': as suggested in the error message above.

I.e., in your case:

import numpy as np
import jax
import time
import jax.numpy as jnp
import diffrax as dfx
import equinox as eqx
import pypesto
import pypesto.optimize as optimize
from pypesto.optimize import ScipyOptimizer
import multiprocessing

# Lotka-Volterra model
def vector_field(t, y, args):
    prey, predator = y
    α, β, γ, δ = args
    d_prey = α * prey - β * prey * predator
    d_predator = -γ * predator + δ * prey * predator
    return jnp.stack([d_prey, d_predator])

def solve(parameters, y0, ts):
    term = dfx.ODETerm(vector_field)
    solver = dfx.Tsit5()
    saveat = dfx.SaveAt(ts=ts)
    sol = dfx.diffeqsolve(
        term, solver, t0=ts[0], t1=ts[-1], dt0=0.1, y0=y0, args=parameters, saveat=saveat,
        adjoint=dfx.RecursiveCheckpointAdjoint(),
    )
    return sol.ys

# Generate synthetic data
def get_data():
    y0 = jnp.array([9.0, 9.0])
    true_parameters = jnp.array([0.1, 0.02, 0.4, 0.02])
    ts = jnp.linspace(0, 30, 20)
    values = solve(true_parameters, y0, ts)
    return y0, ts, values + 0.1 * jax.random.normal(jax.random.PRNGKey(0), values.shape)

def main(): 
    y0, ts, noisy_values = get_data()
    
    # Define objective function
    @jax.jit
    def objective(parameters):
        pred_values = solve(parameters, y0, ts)
        return jnp.sum((noisy_values - pred_values)**2)
    
    #objective_with_grad = jax.value_and_grad(objective)
    
    objective = pypesto.Objective(
        fun=objective,
        grad=jax.grad(objective)
    )
    
    problem1 = pypesto.Problem(objective=objective, lb=np.zeros((4, 1)), ub=np.ones((4, 1))*10)
    default_ess_options = pypesto.optimize.get_default_ess_options(8, 4, local_optimizer=ScipyOptimizer(method='trust-constr'))
    optimizer = pypesto.optimize.SacessOptimizer(ess_init_args = default_ess_options, max_walltime_s=600)
    result_custom_problem = optimizer.minimize(problem=problem1)
    
if __name__ == '__main__':
    main()

Alternatively, SacessOptimizer(..., mp_start_method="fork") might solve this specific issue, but might introduce other problems.

@MaAl13
Author

MaAl13 commented Jul 15, 2024

Hi @dweindl, I tried running your code, but I still get the following error:

Traceback (most recent call last):
  File "/Users/malmansto/Documents/IVF_hormone_sim/code/test_ESS.py", line 60, in <module>
    main()
  File "/Users/malmansto/Documents/IVF_hormone_sim/code/test_ESS.py", line 57, in main
    result_custom_problem = optimizer.minimize(problem=problem1)
  File "/opt/anaconda3/envs/IVF_env/lib/python3.10/site-packages/pypesto/optimize/ess/sacess.py", line 246, in minimize
    p.start()
  File "/opt/anaconda3/envs/IVF_env/lib/python3.10/multiprocessing/process.py", line 121, in start
    self._popen = self._Popen(self)
  File "/opt/anaconda3/envs/IVF_env/lib/python3.10/multiprocessing/context.py", line 288, in _Popen
    return Popen(process_obj)
  File "/opt/anaconda3/envs/IVF_env/lib/python3.10/multiprocessing/popen_spawn_posix.py", line 32, in __init__
    super().__init__(process_obj)
  File "/opt/anaconda3/envs/IVF_env/lib/python3.10/multiprocessing/popen_fork.py", line 19, in __init__
    self._launch(process_obj)
  File "/opt/anaconda3/envs/IVF_env/lib/python3.10/multiprocessing/popen_spawn_posix.py", line 47, in _launch
    reduction.dump(process_obj, fp)
  File "/opt/anaconda3/envs/IVF_env/lib/python3.10/multiprocessing/reduction.py", line 60, in dump
    ForkingPickler(file, protocol).dump(obj)
AttributeError: Can't pickle local object 'main.<locals>.objective'

@dweindl
Member

dweindl commented Jul 15, 2024

That's now the problem @FFroehlich was referring to. I am not sufficiently familiar with Jax to help there.
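One possible way around the `Can't pickle local object` error above is to keep the JAX code in a module-level class. The error arises because `objective` is defined inside `main()`. In the class version, instances hold only picklable data (the model function by reference, NumPy arrays), and each process rebuilds the jitted functions lazily when it first needs them. The sketch below assumes that pattern. It has not been tried with SacessOptimizer itself.

- `LeastSquaresObjective` and `decay_model` are hypothetical names.
- The closed-form model stands in for the diffrax-based `solve`.
- `solve` is module-level, so it could be passed the same way, provided its signature is adapted to `model(theta, ts)` and its name is never rebound.

```python
import pickle

import numpy as np
import jax
import jax.numpy as jnp


def decay_model(theta, ts):
    # Hypothetical stand-in for solve(); module-level, so pickle finds it by name.
    return theta[0] * jnp.exp(-theta[1] * ts)


class LeastSquaresObjective:
    """Picklable sum-of-squares objective (hypothetical helper, not part of pyPESTO).

    Only plain data is pickled; the jitted value-and-gradient function is
    dropped before pickling and recompiled lazily after unpickling.
    """

    def __init__(self, model, ts, data):
        self.model = model
        self.ts = np.asarray(ts)
        self.data = np.asarray(data)
        self._value_and_grad = None  # compiled on first use, never pickled

    def _loss(self, theta):
        return jnp.sum((self.data - self.model(theta, self.ts)) ** 2)

    def _compiled(self):
        if self._value_and_grad is None:
            self._value_and_grad = jax.jit(jax.value_and_grad(self._loss))
        return self._value_and_grad

    def __getstate__(self):
        state = self.__dict__.copy()
        state["_value_and_grad"] = None  # drop the compiled JAX function
        return state

    def fun(self, theta):
        value, _ = self._compiled()(jnp.asarray(theta))
        return float(value)

    def grad(self, theta):
        _, gradient = self._compiled()(jnp.asarray(theta))
        return np.asarray(gradient)


ts = np.linspace(0.0, 10.0, 20)
data = np.asarray(decay_model(np.array([2.0, 0.3]), ts))
obj = LeastSquaresObjective(decay_model, ts, data)
print(obj.fun([1.0, 0.3]))                # compiles in this process
clone = pickle.loads(pickle.dumps(obj))   # compiled function dropped, data kept
print(clone.fun([1.0, 0.3]), clone.grad([1.0, 0.3]))  # recompiles lazily
```

With pyPESTO, the bound methods would then be passed, e.g. `pypesto.Objective(fun=obj.fun, grad=obj.grad)`. Bound methods of a picklable instance are pickled via the instance, so this should survive the spawn start method.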

@FFroehlich
Contributor


> It now runs through, but fails badly for the line search and doesn't get away from the initial guesses

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:02<00:00, 8.25it/s]

Optimization Result

  • number of starts: 20

  • best value: 9694.908203125, id=6

  • worst value: inf, id=19

  • number of non-finite values: 19

  • execution time summary:

    • Mean execution time: 0.121s
    • Maximum execution time: 1.605s, id=0
    • Minimum execution time: 0.012s, id=6
  • summary of optimizer messages:

    Count
    Message

    20
    ABNORMAL_TERMINATION_IN_LNSRCH

  • best value found (approximately) 1 time(s)

  • number of plateaus found: 0

A summary of the best run:

Optimizer Result

  • optimizer used: <ScipyOptimizer method=L-BFGS-B options={'disp': False, 'maxfun': 1000}>
  • message: ABNORMAL_TERMINATION_IN_LNSRCH
  • number of evaluations: 2
  • time taken to optimize: 0.012s
  • startpoint: [2.46408138 2.725371 6.14559588 1.76943688]
  • endpoint: [2.46408138 2.725371 6.14559588 1.76943688]
  • final objective value: 9694.908203125
  • final gradient value: [ 8.141042e+07 7.118934e+08 9.747198e+07 -7.434768e+08]

Hard to guess what the issue is here. This looks like incorrect gradients, but it likely goes beyond what we can help with in the context of an issue on GitHub.
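A standard way to test the "incorrect gradients" hypothesis is to compare the supplied gradient with central finite differences at a few points, for example at the start points listed in the summary. Below is a minimal NumPy sketch; `check_gradient` is a hypothetical helper. If the objective is evaluated in float32, which the reported value 9694.908203125 is consistent with, the finite differences are themselves noisy. The check is therefore more meaningful with JAX's `jax_enable_x64` switched on.

```python
import numpy as np


def check_gradient(fun, grad, x, rel_step=1e-6):
    """Hypothetical helper: largest relative deviation between grad(x) and
    central finite differences of fun at x."""
    x = np.asarray(x, dtype=np.float64)
    analytic = np.asarray(grad(x), dtype=np.float64)
    numeric = np.empty_like(x)
    for i in range(x.size):
        h = rel_step * max(1.0, abs(x[i]))  # step scaled to the parameter size
        e = np.zeros_like(x)
        e[i] = h
        numeric[i] = (fun(x + e) - fun(x - e)) / (2.0 * h)
    # Relative error with an absolute floor of 1 to avoid dividing by ~0.
    return float(np.max(np.abs(analytic - numeric) / np.maximum(np.abs(numeric), 1.0)))


def quadratic(x):
    return float(np.sum((x - 1.0) ** 2))


x0 = np.array([3.0, -2.0])
print(check_gradient(quadratic, lambda x: 2.0 * (x - 1.0), x0))  # correct: close to 0
print(check_gradient(quadratic, lambda x: x - 1.0, x0))          # wrong: about 0.5
```

Applied to the real problem, `fun` and `grad` would be the two callables passed to `pypesto.Objective`. A large deviation at a finite start point would point to the gradient, while non-finite values would point to the ODE solve itself.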

@FFroehlich
Contributor

> That's now the problem @FFroehlich was referring to. I am not sufficiently familiar with Jax to help there.

Correct. Equinox provides some guidance on serialisation, but I don't know how complex it would be to get that running with multiprocessing in sacess.

@MaAl13
Author

MaAl13 commented Jul 16, 2024

Okay, so you would not recommend using scatter search for parameter estimation of ODEs, then? Is there anything else in pypesto that you can recommend that is compatible with JAX and has good global properties? Since it is easy to get the gradient with diffrax, I think it would be a shame not to use it.

@FFroehlich
Contributor

> Okay, so you would not recommend using scatter search for parameter estimation of ODEs, then?

I wouldn’t go that far. In both cases, the issues you encountered should be salvageable. However, since they aren’t ‘bugs’ per se and will require some effort to resolve, providing ready-made solutions is beyond the support we can offer. That said, we always welcome contributions and are happy to provide guidance.

@PaulJonasJost PaulJonasJost added the question Further information is requested label Aug 5, 2024
4 participants