Skip to content
Snippets Groups Projects
Commit ce24706b authored by Erik Nygren's avatar Erik Nygren
Browse files

included warning message if predictor or observation builder does not provide cells to render.

parent 525f1464
No related branches found
No related tags found
No related merge requests found
...@@ -84,7 +84,7 @@ for trials in range(1, n_trials + 1): ...@@ -84,7 +84,7 @@ for trials in range(1, n_trials + 1):
# Environment step which returns the observations for all agents, their corresponding # Environment step which returns the observations for all agents, their corresponding
# reward and whether their are done # reward and whether their are done
next_obs, all_rewards, done, _ = env.step(action_dict) next_obs, all_rewards, done, _ = env.step(action_dict)
env_renderer.render_env(show=True, show_observations=False, show_predictions=True) env_renderer.render_env(show=True, show_observations=True, show_predictions=True)
# Update replay buffer and train agent # Update replay buffer and train agent
for a in range(env.get_num_agents()): for a in range(env.get_num_agents()):
......
...@@ -163,7 +163,7 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder): ...@@ -163,7 +163,7 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
# prediction is ready # prediction is ready
prediction[index] = [index, *new_position, new_direction, 0] prediction[index] = [index, *new_position, new_direction, 0]
visited.add((new_position[0], new_position[1], new_direction)) visited.add((new_position[0], new_position[1], new_direction))
# self.env.dev_pred_dict[agent.handle] = visited self.env.dev_pred_dict[agent.handle] = visited
prediction_dict[agent.handle] = prediction prediction_dict[agent.handle] = prediction
# cleanup: reset initial position # cleanup: reset initial position
......
...@@ -280,7 +280,9 @@ class RenderTool(object): ...@@ -280,7 +280,9 @@ class RenderTool(object):
# Check if the observation builder provides an observation # Check if the observation builder provides an observation
if len(observation_dict) < 1: if len(observation_dict) < 1:
warnings.warn("Observation Builder did not provide an observation_dict of all observed cells.") warnings.warn(
"Predictor did not provide any predicted cells to render. \
Observaiton builder needs to populate: env.dev_obs_dict")
else: else:
for agent in agent_handles: for agent in agent_handles:
color = self.gl.get_agent_color(agent) color = self.gl.get_agent_color(agent)
...@@ -299,7 +301,13 @@ class RenderTool(object): ...@@ -299,7 +301,13 @@ class RenderTool(object):
""" """
rt = self.__class__ rt = self.__class__
if len(prediction_dict) < 1: if len(prediction_dict) < 1:
warnings.warn("Predictor did not provide any predicted cells to render.") warnings.warn(
"Predictor did not provide any predicted cells to render. \
Predictors builder needs to populate: env.dev_pred_dict")
else: else:
for agent in agent_handles: for agent in agent_handles:
color = self.gl.get_agent_color(agent) color = self.gl.get_agent_color(agent)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment