# dqn.py
# Command-line script: evolve a GRN as a layer in a DQN model for solving tasks.
from pygrn import grns, problems, evolution
import os
import argparse
import numpy as np
from datetime import datetime
# Command-line interface of the script.
parser = argparse.ArgumentParser(
    description='Evolve a GRN as a layer in a DQN model for solving tasks')

# Both switches are on by default; passing the flag stores False.
parser.add_argument('--no-learn', dest='learn', action='store_false',
                    help='Turn off learning')
parser.add_argument('--no-evo', dest='evo', action='store_false',
                    help='Turn off evolution')

# The run id is concatenated into the log file name right after parsing,
# which fails with a TypeError when the value is None. Requiring it lets
# argparse report the omission with a usage message instead.
parser.add_argument('--id', type=str, required=True,
                    help='Run id for logging (required)')

parser.add_argument('--seed', type=int, default=0,
                    help='Random seed')
parser.add_argument('--problem', type=str, default='Gym',
                    help='Problem')
parser.add_argument('--env', type=str, default='CartPole-v0',
                    help='Gym environment')
parser.add_argument('--steps', type=int, default=10000,
                    help='Number of training steps')
parser.add_argument('--warmup', type=int, default=200,
                    help='Number of warmup steps')
parser.add_argument('--gens', type=int, default=50,
                    help='Number of generations')

# An empty --grn_file means "no file given"; the script tests its truthiness.
parser.add_argument('--grn_file', type=str, default='',
                    help='Experts from GRN file')
parser.add_argument('--grn_dir', type=str, default='grns',
                    help='Directory for storing GRNS')
parser.add_argument('--log_dir', type=str, default='logs',
                    help='Log directory')
# Parse the command line and build the shared experiment objects.
args = parser.parse_args()

# Fitness log for this run: <log_dir>/fits_<id>.log
log_file = os.path.join(args.log_dir, 'fits_' + args.id + '.log')

# Resolve the problem class by attribute lookup on the `problems` module.
# NOTE(review): this replaces eval("problems." + args.problem), which
# executed arbitrary command-line text as Python. Plain class names behave
# the same (an unknown name still raises AttributeError); anything that is
# not a single attribute name is no longer evaluated.
pclass = getattr(problems, args.problem)

# Seed NumPy's global RNG before the problem instance is constructed.
np.random.seed(args.seed)
p = pclass(log_file, args.seed, args.learn, args.env, args.steps, args.warmup)


def newgrn():
    """Return a new ``grns.DiffGRN`` built with no constructor arguments."""
    return grns.DiffGRN()
# Number of random GRNs evaluated when evolution is off and no GRN file
# is supplied.
N_RANDOM_GRNS = 20


def _log_fitness(generation, pop_size, fit, grn):
    """Append one 'G' record for an evaluated GRN to ``log_file``.

    The record is a comma-separated line holding: the literal ``G``, the
    current ISO timestamp, ``generation``, ``pop_size``, ``fit``,
    ``grn.size()``, ``fit`` again and the constant ``0.0``.
    NOTE(review): the column layout presumably mirrors the records that
    the evolution run writes itself -- confirm against pygrn.
    """
    with open(log_file, 'a') as log:
        log.write('G,%s,%d,%d,%f,%d,%f,%f\n' % (
            datetime.now().isoformat(),
            generation, pop_size,
            fit, grn.size(), fit, 0.0))


if args.evo:
    # Evolutionary run, optionally seeded with GRNs read from a file.
    grneat = evolution.Evolution(p, newgrn, args.id, grn_dir=args.grn_dir,
                                 log_dir=args.log_dir)
    if args.grn_file:
        with open(args.grn_file, 'r') as grn_fh:
            # zip() stops at the shorter input: one line per offspring
            # slot, extra lines are ignored, and reading stops as soon as
            # every slot has been filled.
            slots = range(len(grneat.population.offspring))
            for slot, line in zip(slots, grn_fh):
                grn = newgrn()
                grn.from_str(line)
                grneat.population.offspring[slot] = evolution.Individual(grn)
    grneat.run(args.gens)
elif args.grn_file:
    # No evolution: evaluate every GRN stored in the file, one per line,
    # treating the line index as the generation number.
    with open(args.grn_file, 'r') as grn_fh:
        lines = grn_fh.readlines()
    # The input file is closed before evaluation starts; only the log
    # file is opened inside the loop.
    for gen, line in enumerate(lines):
        grn = newgrn()
        grn.from_str(line)
        p.generation = gen
        _log_fitness(gen, len(lines), p.eval(grn), grn)
else:
    # No evolution and no file: evaluate a fixed number of random GRNs.
    for i in range(N_RANDOM_GRNS):
        grn = newgrn()
        grn.random(p.nin, p.nout, 1)
        p.generation = i
        _log_fitness(i, N_RANDOM_GRNS, p.eval(grn), grn)