Skip to content

Commit

Permalink
All tests pass.
Browse files Browse the repository at this point in the history
Fixing plotting, and copying static files for all domains that have an example script
  • Loading branch information
anandtrex committed Oct 19, 2017
1 parent 71d934d commit b19692e
Show file tree
Hide file tree
Showing 12 changed files with 47 additions and 19 deletions.
5 changes: 3 additions & 2 deletions rlpy/Domains/Bicycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,10 +182,11 @@ def showDomain(self, a=0, s=None):
ax.set_xlabel("Days")
for i in range(n):
handles[i].set_ydata(self.episode_data[i])
ax = handles[i].get_axes()
ax = handles[i].axes
ax.relim()
ax.autoscale_view()
plt.draw()
plt.figure("Domain").canvas.draw()
plt.figure("Domain").canvas.flush_events()


class BicycleRiding(BicycleBalancing):
Expand Down
4 changes: 3 additions & 1 deletion rlpy/Domains/BlocksWorld.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def showDomain(self, a=0):
# Put it in the back of the list
undrawn_blocks = np.hstack((undrawn_blocks, [A]))
if self.domain_fig is None:
plt.figure("Domain")
self.domain_fig = plt.imshow(
world,
cmap='BlocksWorld',
Expand All @@ -121,7 +122,8 @@ def showDomain(self, a=0):
plt.show()
else:
self.domain_fig.set_data(world)
plt.draw()
plt.figure("Domain").canvas.draw()
plt.figure("Domain").canvas.flush_events()

def showLearning(self, representation):
pass # cant show 6 dimensional value function
Expand Down
4 changes: 3 additions & 1 deletion rlpy/Domains/FiftyChain.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ def showDomain(self, a=0):
s = self.state
# Draw the environment
if self.circles is None:
plt.figure("Domain")
self.domain_fig = plt.subplot(3, 1, 1)
plt.figure(1, (self.chainSize * 2 / 10.0, 2))
self.domain_fig.set_xlim(0, self.chainSize * 2 / 10.0)
Expand All @@ -193,7 +194,8 @@ def showDomain(self, a=0):
for p in self.GOAL_STATES:
self.circles[p].set_facecolor('g')
self.circles[s].set_facecolor('k')
plt.draw()
plt.figure("Domain").canvas.draw()
plt.figure("Domain").canvas.flush_events()

def showLearning(self, representation):
allStates = np.arange(0, self.chainSize)
Expand Down
3 changes: 2 additions & 1 deletion rlpy/Domains/GridWorld.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,8 @@ def showDomain(self, a=0, s=None):
s[0],
'k>',
markersize=20.0 - self.COLS)
plt.draw()
plt.figure("Domain").canvas.draw()
plt.figure("Domain").canvas.flush_events()

def showLearning(self, representation):
if self.valueFunction_fig is None:
Expand Down
5 changes: 3 additions & 2 deletions rlpy/Domains/HIVTreatment.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,11 @@ def showDomain(self, a=0, s=None):
ax.set_xlabel("Days")
for i in range(n):
handles[i].set_ydata(self.episode_data[i])
ax = handles[i].get_axes()
ax = handles[i].axes
ax.relim()
ax.autoscale_view()
plt.draw()
plt.figure("Domain").canvas.draw()
plt.figure("Domain").canvas.flush_events()


def dsdt(s, t, eps1, eps2):
Expand Down
4 changes: 3 additions & 1 deletion rlpy/Domains/IntruderMonitoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ def showDomain(self, a):
s = self.state
# Draw the environment
if self.domain_fig is None:
plt.figure("Domain")
self.domain_fig = plt.imshow(
self.map,
cmap='IntruderMonitoring',
Expand Down Expand Up @@ -264,4 +265,5 @@ def showDomain(self, a):
alpha=.7,
markeredgecolor='k',
markeredgewidth=2)
plt.draw()
plt.figure("Domain").canvas.draw()
plt.figure("Domain").canvas.flush_events()
4 changes: 3 additions & 1 deletion rlpy/Domains/PST.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ def __init__(self, NUM_UAV=3):
def showDomain(self, a=0):
s = self.state
if self.domain_fig is None:
plt.figure("Domain")
self.domain_fig = plt.figure(
1, (UAVLocation.SIZE * self.dist_between_locations + 1, self.NUM_UAV + 1))
plt.show()
Expand Down Expand Up @@ -382,7 +383,8 @@ def showDomain(self, a=0):
if numHealthySurveil > 0:
self.location_rect_vis[
len(self.location_rect_vis) - 1].set_color('green')
plt.draw()
plt.figure("Domain").canvas.draw()
plt.figure("Domain").canvas.flush_events()
sleep(0.5)

def showLearning(self, representation):
Expand Down
6 changes: 4 additions & 2 deletions rlpy/Domains/PuddleWorld.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,13 @@ def showDomain(self, a=None):
self.reward_im = plt.imshow(self.reward_map, extent=(0, 1, 0, 1),
origin="lower")
self.state_mark = plt.plot(s[0], s[1], 'kd', markersize=20)
plt.draw()
plt.figure("Domain").canvas.draw()
plt.figure("Domain").canvas.flush_events()
else:
self.domain_fig = plt.figure("Domain")
self.state_mark[0].set_data([s[0]], [s[1]])
plt.draw()
plt.figure("Domain").canvas.draw()
plt.figure("Domain").canvas.flush_events()

def showLearning(self, representation):
a = np.zeros((2))
Expand Down
3 changes: 2 additions & 1 deletion rlpy/Domains/Swimmer.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,8 @@ def showDomain(self, a=None):
else:
self.swimmer_lines.set_data(Rx, Ry)
self.action_text.set_text(str(a))
plt.draw()
plt.figure("Swimmer Domain").canvas.draw()
plt.figure("Swimmer Domain").canvas.flush_events()

def showLearning(self, representation):
good_pol = SwimmerPolicy(
Expand Down
7 changes: 5 additions & 2 deletions rlpy/Domains/SystemAdministrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def loadNetwork(self, path):
"""
_Neighbors = []
f = open(path, 'rb')
f = open(path, 'r')
reader = csv.reader(f, delimiter=',')
self.computers_num = 0
for row in reader:
Expand All @@ -142,6 +142,8 @@ def loadNetwork(self, path):

def showDomain(self, a=0):
s = self.state
plt.figure("Domain")

if self.networkGraph is None: # or self.networkPos is None:
self.networkGraph = nx.Graph()
# enumerate all computer_ids, simulatenously iterating through
Expand Down Expand Up @@ -215,7 +217,8 @@ def showDomain(self, a=0):
width=2,
style='dotted')
nx.draw_networkx_labels(self.networkGraph, self.networkPos)
plt.draw()
plt.figure("Domain").canvas.draw()
plt.figure("Domain").canvas.flush_events()

def step(self, a):
# ns = s[:] # make copy of state so as not to affect original mid-step
Expand Down
15 changes: 13 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
BSD license.
"""
import sys
import glob
import multiprocessing
try:
from setuptools import setup, Command, find_packages
Expand Down Expand Up @@ -238,8 +239,17 @@ def no_cythonize(extensions, **_ignore):
zip_safe=False,
cmdclass=cmdclass,
long_description=open('README.rst').read(),
packages=find_packages(exclude=['tests', 'tests.*']),
include_package_data=True,
data_files=[
('rlpy/Domains/GridWorldMaps', glob.glob('./rlpy/Domains/GridWorldMaps/*.txt')),
('rlpy/Domains/IntruderMonitoringMaps', glob.glob('./rlpy/Domains/IntruderMonitoringMaps/*.txt')),
('rlpy/Domains/SystemAdministratorMaps', glob.glob('./rlpy/Domains/SystemAdministratorMaps/*.txt')),
('rlpy/Domains/PinballConfigs', glob.glob('./rlpy/Domains/PinballConfigs/*.cfg')),
('rlpy/Domains/PacmanPackage/layouts', glob.glob('./rlpy/Domains/PacmanPackage/layouts/*.lay')),
('rlpy/Policies', glob.glob('./rlpy/Policies/*.mat'))
],

packages=find_packages(exclude=['tests', 'tests.*']),
install_requires=[
'numpy >= 1.7',
'scipy',
Expand All @@ -248,7 +258,8 @@ def no_cythonize(extensions, **_ignore):
'scikit-learn',
'joblib',
'hyperopt',
'pymongo'
'pymongo',
'cairocffi'
],
setup_requires=['numpy >= 1.7'],
ext_modules=extensions,
Expand Down
6 changes: 3 additions & 3 deletions tests/test_cartpole.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ def test_cartpole():


def check_traj(domain_class, filename):
with open(filename) as f:
traj = pickle.load(f)
with open(filename, 'rb') as f:
traj = pickle.load(f, encoding='latin1')
traj_now = sample_random_trajectory(domain_class)
for i, e1, e2 in zip(list(range(len(traj_now))), traj_now, traj):
print(i)
Expand All @@ -48,7 +48,7 @@ def check_traj(domain_class, filename):

def save_trajectory(domain_class, filename):
traj = sample_random_trajectory(domain_class)
with open(filename, "w") as f:
with open(filename, 'wb') as f:
pickle.dump(traj, f)


Expand Down

0 comments on commit b19692e

Please sign in to comment.