Skip to content
Snippets Groups Projects
Commit f9b16819 authored by Erik Nygren's avatar Erik Nygren :bullettrain_front:
Browse files

Merge branch '132_catch_exception_observation_rendering' into 'master'

132 catch exception observation rendering

See merge request flatland/flatland!133
parents 9b5f7626 ce24706b
No related branches found
Tags v0.3.5
No related merge requests found
......@@ -17,7 +17,7 @@ LocalGridObs = LocalObsForRailEnv(view_height=10, view_width=2, center=2)
env = RailEnv(width=50,
height=50,
rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0),
obs_builder_object=LocalGridObs,
obs_builder_object=TreeObservation,
number_of_agents=5)
env_renderer = RenderTool(env, gl="PILSVG", )
......@@ -84,7 +84,7 @@ for trials in range(1, n_trials + 1):
# Environment step which returns the observations for all agents, their corresponding
# reward and whether their are done
next_obs, all_rewards, done, _ = env.step(action_dict)
env_renderer.render_env(show=True, show_observations=True, show_predictions=False)
env_renderer.render_env(show=True, show_observations=True, show_predictions=True)
# Update replay buffer and train agent
for a in range(env.get_num_agents()):
......
import time
import warnings
from collections import deque
from enum import IntEnum
......@@ -276,12 +277,19 @@ class RenderTool(object):
"""
rt = self.__class__
for agent in agent_handles:
color = self.gl.get_agent_color(agent)
for visited_cell in observation_dict[agent]:
cell_coord = array(visited_cell[:2])
cell_coord_trans = np.matmul(cell_coord, rt.row_col_to_xy) + rt.x_y_half
self._draw_square(cell_coord_trans, 1 / (agent + 1.1), color, layer=1, opacity=100)
# Check if the observation builder provides an observation
if len(observation_dict) < 1:
warnings.warn(
"Predictor did not provide any predicted cells to render. \
Observaiton builder needs to populate: env.dev_obs_dict")
else:
for agent in agent_handles:
color = self.gl.get_agent_color(agent)
for visited_cell in observation_dict[agent]:
cell_coord = array(visited_cell[:2])
cell_coord_trans = np.matmul(cell_coord, rt.row_col_to_xy) + rt.x_y_half
self._draw_square(cell_coord_trans, 1 / (agent + 1.1), color, layer=1, opacity=100)
def render_prediction(self, agent_handles, prediction_dict):
"""
......@@ -292,19 +300,28 @@ class RenderTool(object):
"""
rt = self.__class__
for agent in agent_handles:
color = self.gl.get_agent_color(agent)
for visited_cell in prediction_dict[agent]:
cell_coord = array(visited_cell[:2])
if type(self.gl) is PILSVG:
# TODO : Track highlighting (Adrian)
r = cell_coord[0]
c = cell_coord[1]
transitions = self.env.rail.grid[r, c]
self.gl.set_predicion_path_at(r, c, transitions, agent_rail_color=color)
else:
cell_coord_trans = np.matmul(cell_coord, rt.row_col_to_xy) + rt.x_y_half
self._draw_square(cell_coord_trans, 1 / (agent + 1.1), color, layer=1, opacity=100)
if len(prediction_dict) < 1:
warnings.warn(
"Predictor did not provide any predicted cells to render. \
Predictors builder needs to populate: env.dev_pred_dict")
else:
for agent in agent_handles:
color = self.gl.get_agent_color(agent)
for visited_cell in prediction_dict[agent]:
cell_coord = array(visited_cell[:2])
if type(self.gl) is PILSVG:
# TODO : Track highlighting (Adrian)
r = cell_coord[0]
c = cell_coord[1]
transitions = self.env.rail.grid[r, c]
self.gl.set_predicion_path_at(r, c, transitions, agent_rail_color=color)
else:
cell_coord_trans = np.matmul(cell_coord, rt.row_col_to_xy) + rt.x_y_half
self._draw_square(cell_coord_trans, 1 / (agent + 1.1), color, layer=1, opacity=100)
def render_rail(self, spacing=False, rail_color="gray", curves=True, arrows=False):
......@@ -558,7 +575,7 @@ class RenderTool(object):
if show_observations:
self.render_observation(range(env.get_num_agents()), env.dev_obs_dict)
if show_predictions and len(env.dev_pred_dict) > 0:
if show_predictions:
self.render_prediction(range(env.get_num_agents()), env.dev_pred_dict)
if show:
self.gl.show()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment