Skip to content
Snippets Groups Projects
Commit 525f1464 authored by Erik Nygren's avatar Erik Nygren
Browse files

included warning message if predictor or observation builder does not provide cells to render.

parent 9b5f7626
No related branches found
No related tags found
No related merge requests found
......@@ -17,7 +17,7 @@ LocalGridObs = LocalObsForRailEnv(view_height=10, view_width=2, center=2)
env = RailEnv(width=50,
height=50,
rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0),
obs_builder_object=LocalGridObs,
obs_builder_object=TreeObservation,
number_of_agents=5)
env_renderer = RenderTool(env, gl="PILSVG", )
......@@ -84,7 +84,7 @@ for trials in range(1, n_trials + 1):
# Environment step which returns the observations for all agents, their corresponding
# reward and whether their are done
next_obs, all_rewards, done, _ = env.step(action_dict)
env_renderer.render_env(show=True, show_observations=True, show_predictions=False)
env_renderer.render_env(show=True, show_observations=False, show_predictions=True)
# Update replay buffer and train agent
for a in range(env.get_num_agents()):
......
......@@ -163,7 +163,7 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
# prediction is ready
prediction[index] = [index, *new_position, new_direction, 0]
visited.add((new_position[0], new_position[1], new_direction))
self.env.dev_pred_dict[agent.handle] = visited
# self.env.dev_pred_dict[agent.handle] = visited
prediction_dict[agent.handle] = prediction
# cleanup: reset initial position
......
import time
import warnings
from collections import deque
from enum import IntEnum
......@@ -276,12 +277,17 @@ class RenderTool(object):
"""
rt = self.__class__
for agent in agent_handles:
color = self.gl.get_agent_color(agent)
for visited_cell in observation_dict[agent]:
cell_coord = array(visited_cell[:2])
cell_coord_trans = np.matmul(cell_coord, rt.row_col_to_xy) + rt.x_y_half
self._draw_square(cell_coord_trans, 1 / (agent + 1.1), color, layer=1, opacity=100)
# Check if the observation builder provides an observation
if len(observation_dict) < 1:
warnings.warn("Observation Builder did not provide an observation_dict of all observed cells.")
else:
for agent in agent_handles:
color = self.gl.get_agent_color(agent)
for visited_cell in observation_dict[agent]:
cell_coord = array(visited_cell[:2])
cell_coord_trans = np.matmul(cell_coord, rt.row_col_to_xy) + rt.x_y_half
self._draw_square(cell_coord_trans, 1 / (agent + 1.1), color, layer=1, opacity=100)
def render_prediction(self, agent_handles, prediction_dict):
"""
......@@ -292,19 +298,22 @@ class RenderTool(object):
"""
rt = self.__class__
for agent in agent_handles:
color = self.gl.get_agent_color(agent)
for visited_cell in prediction_dict[agent]:
cell_coord = array(visited_cell[:2])
if type(self.gl) is PILSVG:
# TODO : Track highlighting (Adrian)
r = cell_coord[0]
c = cell_coord[1]
transitions = self.env.rail.grid[r, c]
self.gl.set_predicion_path_at(r, c, transitions, agent_rail_color=color)
else:
cell_coord_trans = np.matmul(cell_coord, rt.row_col_to_xy) + rt.x_y_half
self._draw_square(cell_coord_trans, 1 / (agent + 1.1), color, layer=1, opacity=100)
if len(prediction_dict) < 1:
warnings.warn("Predictor did not provide any predicted cells to render.")
else:
for agent in agent_handles:
color = self.gl.get_agent_color(agent)
for visited_cell in prediction_dict[agent]:
cell_coord = array(visited_cell[:2])
if type(self.gl) is PILSVG:
# TODO : Track highlighting (Adrian)
r = cell_coord[0]
c = cell_coord[1]
transitions = self.env.rail.grid[r, c]
self.gl.set_predicion_path_at(r, c, transitions, agent_rail_color=color)
else:
cell_coord_trans = np.matmul(cell_coord, rt.row_col_to_xy) + rt.x_y_half
self._draw_square(cell_coord_trans, 1 / (agent + 1.1), color, layer=1, opacity=100)
def render_rail(self, spacing=False, rail_color="gray", curves=True, arrows=False):
......@@ -558,7 +567,7 @@ class RenderTool(object):
if show_observations:
self.render_observation(range(env.get_num_agents()), env.dev_obs_dict)
if show_predictions and len(env.dev_pred_dict) > 0:
if show_predictions:
self.render_prediction(range(env.get_num_agents()), env.dev_pred_dict)
if show:
self.gl.show()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment