-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpredator_prey.py
60 lines (46 loc) · 1.98 KB
/
predator_prey.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import argparse
import pandas as pd
import jax
import optax
from distributions import PredatorPrey
from execute import run
from numpyro.util import enable_x64
enable_x64(use_x64=True)
from jax.config import config
# config.update("jax_debug_nans", True)
config.update("jax_enable_x64", True)
N_PARAM = 8
def main(args):
print("Loading predator-prey data...")
data = pd.read_table("lynxhare.txt", sep=" ", names=['year', 'prey', 'pred', ''])
# data = data.loc[data.year >= 1920]
print("Setting up predator-prey model...")
dist = PredatorPrey(
data.year.values, data.pred.values, data.prey.values,
)
[n_warm, n_iter] = args.sampling_param
schedule = optax.exponential_decay(init_value=1e-3,
transition_steps=n_warm-10, decay_rate=.1, transition_begin=10)
optim = optax.adam(schedule)
run(dist, args, optim, N_PARAM, batch_fn=jax.vmap)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--seed", type=int, default=0)
parser.add_argument('-m', '--max-iter', type=int, default=1)
parser.add_argument('-b', '--batch-shape', type=int, nargs=2, default=[4, 32])
parser.add_argument('-nl', '--non-linearity', type=str, default='relu')
parser.add_argument('-f', '--flow', type=str, default='coupling')
parser.add_argument('-d', '--distance', type=str, default='kld')
parser.add_argument('-nh', '--n-hidden', type=int, default=2)
parser.add_argument('-nf', '--n-flow', type=int, default=2)
parser.add_argument('-nb', '--num-bins', type=int, default=None)
parser.add_argument(
"-s", "--sampling-param", type=int, nargs=2,
help="Sampling parameters [n_warm, n_iter]",
default=[400, 100]
)
parser.add_argument('-np', '--preconditon_iter', type=int, default=400)
parser.add_argument('-s1', '--init_step_size', type=float, default=None)
parser.add_argument('-s2', '--p_init_step_size', type=float, default=None)
args = parser.parse_args()
main(args)