From b19692ee3deea344ea91160ed70842cfde5c6199 Mon Sep 17 00:00:00 2001 From: Anand Date: Thu, 19 Oct 2017 22:47:28 +0200 Subject: [PATCH] All tests pass. Fixing plotting, and copying static files for all domains that have an example script --- rlpy/Domains/Bicycle.py | 5 +++-- rlpy/Domains/BlocksWorld.py | 4 +++- rlpy/Domains/FiftyChain.py | 4 +++- rlpy/Domains/GridWorld.py | 3 ++- rlpy/Domains/HIVTreatment.py | 5 +++-- rlpy/Domains/IntruderMonitoring.py | 4 +++- rlpy/Domains/PST.py | 4 +++- rlpy/Domains/PuddleWorld.py | 6 ++++-- rlpy/Domains/Swimmer.py | 3 ++- rlpy/Domains/SystemAdministrator.py | 7 +++++-- setup.py | 15 +++++++++++++-- tests/test_cartpole.py | 6 +++--- 12 files changed, 47 insertions(+), 19 deletions(-) diff --git a/rlpy/Domains/Bicycle.py b/rlpy/Domains/Bicycle.py index a7bcd434..4b8428fc 100644 --- a/rlpy/Domains/Bicycle.py +++ b/rlpy/Domains/Bicycle.py @@ -182,10 +182,11 @@ def showDomain(self, a=0, s=None): ax.set_xlabel("Days") for i in range(n): handles[i].set_ydata(self.episode_data[i]) - ax = handles[i].get_axes() + ax = handles[i].axes ax.relim() ax.autoscale_view() - plt.draw() + plt.figure("Domain").canvas.draw() + plt.figure("Domain").canvas.flush_events() class BicycleRiding(BicycleBalancing): diff --git a/rlpy/Domains/BlocksWorld.py b/rlpy/Domains/BlocksWorld.py index 4ef3c799..89143954 100644 --- a/rlpy/Domains/BlocksWorld.py +++ b/rlpy/Domains/BlocksWorld.py @@ -109,6 +109,7 @@ def showDomain(self, a=0): # Put it in the back of the list undrawn_blocks = np.hstack((undrawn_blocks, [A])) if self.domain_fig is None: + plt.figure("Domain") self.domain_fig = plt.imshow( world, cmap='BlocksWorld', @@ -121,7 +122,8 @@ def showDomain(self, a=0): plt.show() else: self.domain_fig.set_data(world) - plt.draw() + plt.figure("Domain").canvas.draw() + plt.figure("Domain").canvas.flush_events() def showLearning(self, representation): pass # cant show 6 dimensional value function diff --git a/rlpy/Domains/FiftyChain.py b/rlpy/Domains/FiftyChain.py index e5ccce68..803b9fad 100644 --- a/rlpy/Domains/FiftyChain.py +++ b/rlpy/Domains/FiftyChain.py @@ -172,6 +172,7 @@ def showDomain(self, a=0): s = self.state # Draw the environment if self.circles is None: + plt.figure("Domain") self.domain_fig = plt.subplot(3, 1, 1) plt.figure(1, (self.chainSize * 2 / 10.0, 2)) self.domain_fig.set_xlim(0, self.chainSize * 2 / 10.0) @@ -193,7 +194,8 @@ def showDomain(self, a=0): for p in self.GOAL_STATES: self.circles[p].set_facecolor('g') self.circles[s].set_facecolor('k') - plt.draw() + plt.figure("Domain").canvas.draw() + plt.figure("Domain").canvas.flush_events() def showLearning(self, representation): allStates = np.arange(0, self.chainSize) diff --git a/rlpy/Domains/GridWorld.py b/rlpy/Domains/GridWorld.py index e430a006..bdb31ef8 100644 --- a/rlpy/Domains/GridWorld.py +++ b/rlpy/Domains/GridWorld.py @@ -130,7 +130,8 @@ def showDomain(self, a=0, s=None): s[0], 'k>', markersize=20.0 - self.COLS) - plt.draw() + plt.figure("Domain").canvas.draw() + plt.figure("Domain").canvas.flush_events() def showLearning(self, representation): if self.valueFunction_fig is None: diff --git a/rlpy/Domains/HIVTreatment.py b/rlpy/Domains/HIVTreatment.py index fc2fb110..da4e9015 100644 --- a/rlpy/Domains/HIVTreatment.py +++ b/rlpy/Domains/HIVTreatment.py @@ -130,10 +130,11 @@ def showDomain(self, a=0, s=None): ax.set_xlabel("Days") for i in range(n): handles[i].set_ydata(self.episode_data[i]) - ax = handles[i].get_axes() + ax = handles[i].axes ax.relim() ax.autoscale_view() - plt.draw() + plt.figure("Domain").canvas.draw() + plt.figure("Domain").canvas.flush_events() def dsdt(s, t, eps1, eps2): diff --git a/rlpy/Domains/IntruderMonitoring.py b/rlpy/Domains/IntruderMonitoring.py index fc8415d1..8773008b 100644 --- a/rlpy/Domains/IntruderMonitoring.py +++ b/rlpy/Domains/IntruderMonitoring.py @@ -228,6 +228,7 @@ def showDomain(self, a): s = self.state # Draw the environment if self.domain_fig is None: + plt.figure("Domain") self.domain_fig = plt.imshow( self.map, cmap='IntruderMonitoring', @@ -264,4 +265,5 @@ def showDomain(self, a): alpha=.7, markeredgecolor='k', markeredgewidth=2) - plt.draw() + plt.figure("Domain").canvas.draw() + plt.figure("Domain").canvas.flush_events() diff --git a/rlpy/Domains/PST.py b/rlpy/Domains/PST.py index d6a149f8..e7833110 100644 --- a/rlpy/Domains/PST.py +++ b/rlpy/Domains/PST.py @@ -205,6 +205,7 @@ def __init__(self, NUM_UAV=3): def showDomain(self, a=0): s = self.state if self.domain_fig is None: + plt.figure("Domain") self.domain_fig = plt.figure( 1, (UAVLocation.SIZE * self.dist_between_locations + 1, self.NUM_UAV + 1)) plt.show() @@ -382,7 +383,8 @@ def showDomain(self, a=0): if numHealthySurveil > 0: self.location_rect_vis[ len(self.location_rect_vis) - 1].set_color('green') - plt.draw() + plt.figure("Domain").canvas.draw() + plt.figure("Domain").canvas.flush_events() sleep(0.5) def showLearning(self, representation): diff --git a/rlpy/Domains/PuddleWorld.py b/rlpy/Domains/PuddleWorld.py index ff7f3cf8..faf5feeb 100644 --- a/rlpy/Domains/PuddleWorld.py +++ b/rlpy/Domains/PuddleWorld.py @@ -112,11 +112,13 @@ def showDomain(self, a=None): self.reward_im = plt.imshow(self.reward_map, extent=(0, 1, 0, 1), origin="lower") self.state_mark = plt.plot(s[0], s[1], 'kd', markersize=20) - plt.draw() + plt.figure("Domain").canvas.draw() + plt.figure("Domain").canvas.flush_events() else: self.domain_fig = plt.figure("Domain") self.state_mark[0].set_data([s[0]], [s[1]]) - plt.draw() + plt.figure("Domain").canvas.draw() + plt.figure("Domain").canvas.flush_events() def showLearning(self, representation): a = np.zeros((2)) diff --git a/rlpy/Domains/Swimmer.py b/rlpy/Domains/Swimmer.py index ec12a7cb..a1d9c8ba 100644 --- a/rlpy/Domains/Swimmer.py +++ b/rlpy/Domains/Swimmer.py @@ -130,7 +130,8 @@ def showDomain(self, a=None): else: self.swimmer_lines.set_data(Rx, Ry) self.action_text.set_text(str(a)) - plt.draw() + plt.figure("Swimmer Domain").canvas.draw() + plt.figure("Swimmer Domain").canvas.flush_events() def showLearning(self, representation): good_pol = SwimmerPolicy( diff --git a/rlpy/Domains/SystemAdministrator.py b/rlpy/Domains/SystemAdministrator.py index 140c7516..66350ec4 100644 --- a/rlpy/Domains/SystemAdministrator.py +++ b/rlpy/Domains/SystemAdministrator.py @@ -130,7 +130,7 @@ def loadNetwork(self, path): """ _Neighbors = [] - f = open(path, 'rb') + f = open(path, 'r') reader = csv.reader(f, delimiter=',') self.computers_num = 0 for row in reader: @@ -142,6 +142,8 @@ def loadNetwork(self, path): def showDomain(self, a=0): s = self.state + plt.figure("Domain") + if self.networkGraph is None: # or self.networkPos is None: self.networkGraph = nx.Graph() # enumerate all computer_ids, simulatenously iterating through @@ -215,7 +217,8 @@ def showDomain(self, a=0): width=2, style='dotted') nx.draw_networkx_labels(self.networkGraph, self.networkPos) - plt.draw() + plt.figure("Domain").canvas.draw() + plt.figure("Domain").canvas.flush_events() def step(self, a): # ns = s[:] # make copy of state so as not to affect original mid-step diff --git a/setup.py b/setup.py index 88279341..7a30f5cf 100644 --- a/setup.py +++ b/setup.py @@ -6,6 +6,7 @@ BSD license. """ import sys +import glob import multiprocessing try: from setuptools import setup, Command, find_packages @@ -238,8 +239,17 @@ def no_cythonize(extensions, **_ignore): zip_safe=False, cmdclass=cmdclass, long_description=open('README.rst').read(), - packages=find_packages(exclude=['tests', 'tests.*']), include_package_data=True, + data_files=[ + ('rlpy/Domains/GridWorldMaps', glob.glob('./rlpy/Domains/GridWorldMaps/*.txt')), + ('rlpy/Domains/IntruderMonitoringMaps', glob.glob('./rlpy/Domains/IntruderMonitoringMaps/*.txt')), + ('rlpy/Domains/SystemAdministratorMaps', glob.glob('./rlpy/Domains/SystemAdministratorMaps/*.txt')), + ('rlpy/Domains/PinballConfigs', glob.glob('./rlpy/Domains/PinballConfigs/*.cfg')), + ('rlpy/Domains/PacmanPackage/layouts', glob.glob('./rlpy/Domains/PacmanPackage/layouts/*.lay')), + ('rlpy/Policies', glob.glob('./rlpy/Policies/*.mat')) + ], + + packages=find_packages(exclude=['tests', 'tests.*']), install_requires=[ 'numpy >= 1.7', 'scipy', @@ -248,7 +258,8 @@ def no_cythonize(extensions, **_ignore): 'scikit-learn', 'joblib', 'hyperopt', - 'pymongo' + 'pymongo', + 'cairocffi' ], setup_requires=['numpy >= 1.7'], ext_modules=extensions, diff --git a/tests/test_cartpole.py b/tests/test_cartpole.py index e41657b7..50f9944c 100644 --- a/tests/test_cartpole.py +++ b/tests/test_cartpole.py @@ -30,8 +30,8 @@ def test_cartpole(): def check_traj(domain_class, filename): - with open(filename) as f: - traj = pickle.load(f) + with open(filename, 'rb') as f: + traj = pickle.load(f, encoding='latin1') traj_now = sample_random_trajectory(domain_class) for i, e1, e2 in zip(list(range(len(traj_now))), traj_now, traj): print(i) @@ -48,7 +48,7 @@ def check_traj(domain_class, filename): def save_trajectory(domain_class, filename): traj = sample_random_trajectory(domain_class) - with open(filename, "w") as f: + with open(filename, 'wb') as f: pickle.dump(traj, f)