From 576715cbe2dd4dbd9f93237390ac3a1542ebb9a4 Mon Sep 17 00:00:00 2001 From: Erik Nygren <erik.nygren@sbb.ch> Date: Tue, 23 Apr 2019 09:20:22 +0200 Subject: [PATCH] updated rendering method for debugging --- examples/training_navigation.py | 15 +++------------ flatland/utils/rendertools.py | 9 ++++++--- 2 files changed, 9 insertions(+), 15 deletions(-) diff --git a/examples/training_navigation.py b/examples/training_navigation.py index 7b35cb21..60dc1adb 100644 --- a/examples/training_navigation.py +++ b/examples/training_navigation.py @@ -8,16 +8,7 @@ import torch,random random.seed(1) np.random.seed(1) -""" -transition_probability = [1.0, # empty cell - Case 0 - 3.0, # Case 1 - straight - 1.0, # Case 2 - simple switch - 3.0, # Case 3 - diamond drossing - 2.0, # Case 4 - single slip - 1.0, # Case 5 - double slip - 1.0, # Case 6 - symmetrical - 1.0] # Case 7 - dead end -""" + # Example generate a rail given a manual specification, # a map of tuples (cell_type, rotation) transition_probability = [1.0, # empty cell - Case 0 @@ -63,6 +54,7 @@ for trials in range(1, n_trials + 1): # Run episode for step in range(100): + #env_renderer.renderEnv(show=True) # Action for a in range(env.number_of_agents): @@ -73,7 +65,6 @@ for trials in range(1, n_trials + 1): next_obs, all_rewards, done, _ = env.step(action_dict) - # Update replay buffer and train agent for a in range(env.number_of_agents): agent.step(obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a]) @@ -81,7 +72,7 @@ for trials in range(1, n_trials + 1): obs = next_obs.copy() - if all(done): + if done['__all__']: env_done = 1 break # Epsioln decay diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py index 5365800b..24f36582 100644 --- a/flatland/utils/rendertools.py +++ b/flatland/utils/rendertools.py @@ -403,10 +403,9 @@ class RenderTool(object): # cell_size is a bit pointless with matplotlib - it does not relate to pixels, # so for now I've changed it to 1 (from 10) cell_size = 1 - + plt.clf() # if oFigure is None: # oFigure = plt.figure() - def drawTrans(oFrom, oTo, sColor="gray"): plt.plot( [oFrom[0], oTo[0]], # x @@ -551,7 +550,11 @@ class RenderTool(object): plt.xlim([0, env.width * cell_size]) plt.ylim([-env.height * cell_size, 0]) if show: - plt.show() + plt.show(block=False) + plt.pause(0.00001) + return + + def _draw_square(self, center, size, color): x0 = center[0]-size/2 -- GitLab