diff --git a/examples/custom_observation_example.py b/examples/custom_observation_example.py index 3c35b30a3be4e0c1563b09c1c86359295226079e..723bb1102092c7d48bd938bbd60d0c5213ffecf6 100644 --- a/examples/custom_observation_example.py +++ b/examples/custom_observation_example.py @@ -1,10 +1,13 @@ import random import time + import numpy as np -from flatland.envs.observations import TreeObsForRailEnv from flatland.core.env_observation_builder import ObservationBuilder +from flatland.core.grid.grid_utils import coordinate_to_position from flatland.envs.generators import random_rail_generator, complex_rail_generator +from flatland.envs.observations import TreeObsForRailEnv +from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv from flatland.utils.rendertools import RenderTool @@ -86,20 +89,11 @@ class SingleAgentNavigationObs(TreeObsForRailEnv): env = RailEnv(width=7, height=7, - rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0), - number_of_agents=2, + rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999, seed=0), + number_of_agents=1, obs_builder_object=SingleAgentNavigationObs()) -obs, all_rewards, done, _ = env.step({0: 0, 1: 1}) -for i in range(env.get_num_agents()): - print(obs[i]) - -env = RailEnv(width=50, - height=50, - rail_generator=random_rail_generator(), - number_of_agents=1, - obs_builder_object=SingleAgentNavigationObs()) -obs, all_rewards, done, _ = env.step({0: 0}) +obs = env.reset() env_renderer = RenderTool(env, gl="PILSVG") env_renderer.render_env(show=True, frames=True, show_observations=True) for step in range(100): @@ -108,5 +102,119 @@ for step in range(100): print("Rewards: ", all_rewards, " [done=", done, "]") env_renderer.render_env(show=True, frames=True, show_observations=True) time.sleep(0.1) + if done["__all__"]: + break +env_renderer.close_window() + + +class ObservePredictions(TreeObsForRailEnv): + """ + We use the provided ShortestPathPredictor to illustrate the usage of predictors in your custom observation. + + We derive our observation builder from TreeObsForRailEnv, to exploit the existing implementation to compute + the minimum distances from each grid node to each agent's target. + This is necessary so that we can pass the distance map to the ShortestPathPredictor + Here we also want to highlight how you can visualize your observation + """ + + def __init__(self, predictor): + super().__init__(max_depth=0) + self.observation_space = [10] + self.predictor = predictor + + def reset(self): + # Recompute the distance map, if the environment has changed. + super().reset() + + def get_many(self, handles=None): + ''' + Because we do not want to call the predictor seperately for every agent we implement the get_many function + Here we can call the predictor just ones for all the agents and use the predictions to generate our observations + :param handles: + :return: + ''' + + self.predictions = self.predictor.get(custom_args={'distance_map': self.distance_map}) + + self.predicted_pos = {} + for t in range(len(self.predictions[0])): + pos_list = [] + for a in handles: + pos_list.append(self.predictions[a][t][1:3]) + # We transform (x,y) coodrinates to a single integer number for simpler comparison + self.predicted_pos.update({t: coordinate_to_position(self.env.width, pos_list)}) + observations = {} + + # Collect all the different observation for all the agents + for h in handles: + observations[h] = self.get(h) + return observations + + def get(self, handle): + ''' + Lets write a simple observation which just indicates whether or not the own predicted path + overlaps with other predicted paths at any time. This is useless for the task of navigation but might + help when looking for conflicts. A more complex implementation can be found in the TreeObsForRailEnv class + + Each agent recieves an observation of length 10, where each element represents a prediction step and its value + is: + - 0 if no overlap is happening + - 1 where n i the number of other paths crossing the predicted cell + + :param handle: handeled as an index of an agent + :return: Observation of handle + ''' + + observation = np.zeros(10) + + # We are going to track what cells where considered while building the obervation and make them accesible + # For rendering + + visited = set() + for _idx in range(10): + # Check if any of the other prediction overlap with agents own predictions + x_coord = self.predictions[handle][_idx][1] + y_coord = self.predictions[handle][_idx][2] + + # We add every observed cell to the observation rendering + visited.add((x_coord, y_coord)) + if self.predicted_pos[_idx][handle] in np.delete(self.predicted_pos[_idx], handle, 0): + # We detect if another agent is predicting to pass through the same cell at the same predicted time + observation[handle] = 1 + + # This variable will be access by the renderer to visualize the observation + self.env.dev_obs_dict[handle] = visited + + return observation + + +# Initiate the Predictor +CustomPredictor = ShortestPathPredictorForRailEnv(10) + +# Pass the Predictor to the observation builder +CustomObsBuilder = ObservePredictions(CustomPredictor) + +# Initiate Environment +env = RailEnv(width=10, + height=10, + rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=1, min_dist=8, max_dist=99999, seed=0), + number_of_agents=3, + obs_builder_object=CustomObsBuilder) + +obs = env.reset() +env_renderer = RenderTool(env, gl="PILSVG") + +# We render the initial step and show the obsered cells as colored boxes +env_renderer.render_env(show=True, frames=True, show_observations=True, show_predictions=False) + +action_dict = {} +for step in range(100): + for a in range(env.get_num_agents()): + action = np.random.randint(0, 5) + action_dict[a] = action + obs, all_rewards, done, _ = env.step(action_dict) + print("Rewards: ", all_rewards, " [done=", done, "]") + env_renderer.render_env(show=True, frames=True, show_observations=True, show_predictions=False) + time.sleep(0.5) diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index e6646cb9b7f4325e73496c04dd8bdf837c34fe30..2c9a747372ae74bb2f9e286d0d1fc200260d7f01 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -308,6 +308,7 @@ class TreeObsForRailEnv(ObservationBuilder): We walk along the branch and collect the information documented in the get() function. If there is a branching point a new node is created and each possible branch is explored. """ + # [Recursive branch opened] if depth >= self.max_depth + 1: return [], [] diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py index 1d825ff16554ef859339f0a60f1dbbce28dc62ef..88a79ea72606e7c1b46f92f6d73429979d67a4e6 100644 --- a/flatland/envs/predictions.py +++ b/flatland/envs/predictions.py @@ -86,6 +86,10 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder): The prediction acts as if no other agent is in the environment and always takes the forward action. """ + def __init__(self, max_depth=20): + # Initialize with depth 20 + self.max_depth = max_depth + def get(self, custom_args=None, handle=None): """ Called whenever get_many in the observation build is called. diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py index 65fffadf5309ebbcc02cd6a2b39427f2469fec04..8974126ad69e35edbd621ba634727120720c9506 100644 --- a/flatland/utils/rendertools.py +++ b/flatland/utils/rendertools.py @@ -304,10 +304,6 @@ class RenderTool(object): warnings.warn( "Predictor did not provide any predicted cells to render. \ Predictors builder needs to populate: env.dev_pred_dict") - - - - else: for agent in agent_handles: color = self.gl.get_agent_color(agent)