diff --git a/examples/training_example.py b/examples/training_example.py index a05f7c727ac9a56453951de453ea184cab1ea4fc..cfed6c92cc74c45445c436a65d15c9eb8292fe32 100644 --- a/examples/training_example.py +++ b/examples/training_example.py @@ -17,7 +17,7 @@ LocalGridObs = LocalObsForRailEnv(view_height=10, view_width=2, center=2) env = RailEnv(width=50, height=50, rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0), - obs_builder_object=LocalGridObs, + obs_builder_object=TreeObservation, number_of_agents=5) env_renderer = RenderTool(env, gl="PILSVG", ) @@ -84,7 +84,7 @@ for trials in range(1, n_trials + 1): # Environment step which returns the observations for all agents, their corresponding # reward and whether their are done next_obs, all_rewards, done, _ = env.step(action_dict) - env_renderer.render_env(show=True, show_observations=True, show_predictions=False) + env_renderer.render_env(show=True, show_observations=True, show_predictions=True) # Update replay buffer and train agent for a in range(env.get_num_agents()): diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py index 984357a166d10d5498c1f4cecd6d0def1f577f1d..65fffadf5309ebbcc02cd6a2b39427f2469fec04 100644 --- a/flatland/utils/rendertools.py +++ b/flatland/utils/rendertools.py @@ -1,4 +1,5 @@ import time +import warnings from collections import deque from enum import IntEnum @@ -276,12 +277,19 @@ class RenderTool(object): """ rt = self.__class__ - for agent in agent_handles: - color = self.gl.get_agent_color(agent) - for visited_cell in observation_dict[agent]: - cell_coord = array(visited_cell[:2]) - cell_coord_trans = np.matmul(cell_coord, rt.row_col_to_xy) + rt.x_y_half - self._draw_square(cell_coord_trans, 1 / (agent + 1.1), color, layer=1, opacity=100) + + # Check if the observation builder provides an observation + if len(observation_dict) < 1: + warnings.warn( + "Predictor did not provide any predicted cells to render. \ + Observaiton builder needs to populate: env.dev_obs_dict") + else: + for agent in agent_handles: + color = self.gl.get_agent_color(agent) + for visited_cell in observation_dict[agent]: + cell_coord = array(visited_cell[:2]) + cell_coord_trans = np.matmul(cell_coord, rt.row_col_to_xy) + rt.x_y_half + self._draw_square(cell_coord_trans, 1 / (agent + 1.1), color, layer=1, opacity=100) def render_prediction(self, agent_handles, prediction_dict): """ @@ -292,19 +300,28 @@ class RenderTool(object): """ rt = self.__class__ - for agent in agent_handles: - color = self.gl.get_agent_color(agent) - for visited_cell in prediction_dict[agent]: - cell_coord = array(visited_cell[:2]) - if type(self.gl) is PILSVG: - # TODO : Track highlighting (Adrian) - r = cell_coord[0] - c = cell_coord[1] - transitions = self.env.rail.grid[r, c] - self.gl.set_predicion_path_at(r, c, transitions, agent_rail_color=color) - else: - cell_coord_trans = np.matmul(cell_coord, rt.row_col_to_xy) + rt.x_y_half - self._draw_square(cell_coord_trans, 1 / (agent + 1.1), color, layer=1, opacity=100) + if len(prediction_dict) < 1: + warnings.warn( + "Predictor did not provide any predicted cells to render. \ + Predictors builder needs to populate: env.dev_pred_dict") + + + + + else: + for agent in agent_handles: + color = self.gl.get_agent_color(agent) + for visited_cell in prediction_dict[agent]: + cell_coord = array(visited_cell[:2]) + if type(self.gl) is PILSVG: + # TODO : Track highlighting (Adrian) + r = cell_coord[0] + c = cell_coord[1] + transitions = self.env.rail.grid[r, c] + self.gl.set_predicion_path_at(r, c, transitions, agent_rail_color=color) + else: + cell_coord_trans = np.matmul(cell_coord, rt.row_col_to_xy) + rt.x_y_half + self._draw_square(cell_coord_trans, 1 / (agent + 1.1), color, layer=1, opacity=100) def render_rail(self, spacing=False, rail_color="gray", curves=True, arrows=False): @@ -558,7 +575,7 @@ class RenderTool(object): if show_observations: self.render_observation(range(env.get_num_agents()), env.dev_obs_dict) - if show_predictions and len(env.dev_pred_dict) > 0: + if show_predictions: self.render_prediction(range(env.get_num_agents()), env.dev_pred_dict) if show: self.gl.show()