Skip to content
Snippets Groups Projects 5.55 KiB
Newer Older
import getopt
import random
import sys
import time
from typing import Optional, List, Dict
from flatland.core.env import Environment
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.grid.grid_utils import coordinate_to_position
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import complex_rail_generator
from flatland.envs.schedule_generators import complex_schedule_generator
u214892's avatar
u214892 committed
from flatland.utils.misc import str2bool
u214892's avatar
u214892 committed
from flatland.utils.ordered_set import OrderedSet
from flatland.utils.rendertools import RenderTool


    We use the provided ShortestPathPredictor to illustrate the usage of predictors in your custom observation.

    def __init__(self, predictor):
        """
        Store the predictor that supplies the predictions used to build the observations.

        :param predictor: object whose ``get()`` method is called in ``get_many`` to obtain the predictions
        """
        self.predictor = predictor

    def reset(self):
    def get_many(self, handles: Optional[List[int]] = None) -> Dict[int, np.ndarray]:
        Because we do not want to call the predictor separately for every agent, we implement the get_many function.
        Here we can call the predictor just once for all the agents and use the predictions to generate our observations.
        :param handles:

        self.predictions = self.predictor.get()

        if handles is None:
            handles = []

        for t in range(len(self.predictions[0])):
            pos_list = []
            for a in handles:
            # We transform (x,y) coordinates to a single integer number for simpler comparison
            self.predicted_pos.update({t: coordinate_to_position(self.env.width, pos_list)})

        observations = super().get_many(handles)

    def get(self, handle: int = 0) -> np.ndarray:
        """
        Build a simple observation which just indicates whether or not the agent's own predicted path
        overlaps with other predicted paths at any prediction step. This is useless for the task of navigation
        but might help when looking for conflicts. A more complex implementation can be found in the
        TreeObsForRailEnv class.

        Each agent receives an observation of length 10, where each element represents a prediction step and
        its value is:
         - 0 if no overlap is happening at that step
         - 1 if another agent is predicted to occupy the same cell at that step

        Reads ``self.predictions`` and ``self.predicted_pos``, which are filled in by ``get_many``.

        :param handle: index of the agent
        :return: observation of the agent ``handle``
        """
        observation = np.zeros(10)

        # We track which cells were considered while building the observation and make them accessible
        # for rendering.
        visited = OrderedSet()
        for _idx in range(10):
            # Entries 1 and 2 of the prediction for step _idx are used as the cell coordinates.
            x_coord = self.predictions[handle][_idx][1]
            y_coord = self.predictions[handle][_idx][2]

            # We add every observed cell to the observation rendering
            visited.add((x_coord, y_coord))

            # Compare the agent's own encoded position at this step against the positions of all other agents
            # (np.delete returns a copy of the per-step positions without the agent's own entry).
            if self.predicted_pos[_idx][handle] in np.delete(self.predicted_pos[_idx], handle, 0):
                # Another agent is predicted to pass through the same cell at the same predicted time.
                # The flag is stored at the prediction step, not at the agent index, so that each element
                # of the observation corresponds to one prediction step.
                observation[_idx] = 1

        # This variable will be accessed by the renderer to visualize the observation
        self.env.dev_obs_dict[handle] = visited

        return observation

u229589's avatar
u229589 committed
    def set_env(self, env: Environment):
u229589's avatar
u229589 committed
        opts, args = getopt.getopt(args, "", ["sleep-for-animation=", ""])
    except getopt.GetoptError as err:
        print(str(err))  # will print something like "option -a not recognized"
        if o in ("--sleep-for-animation"):
u214892's avatar
u214892 committed
            sleep_for_animation = str2bool(a)
            assert False, "unhandled option"

    # Initiate the predictor.
    # NOTE(review): the argument 10 is presumably the prediction depth and should match the 10 prediction
    # steps read by the observation builder's get() — confirm against ShortestPathPredictorForRailEnv.
    custom_predictor = ShortestPathPredictorForRailEnv(10)

    # Pass the predictor to the observation builder
    custom_obs_builder = ObservePredictions(custom_predictor)

    # Initiate the environment: a 10x10 grid with 3 agents, using the custom observation builder
    env = RailEnv(width=10, height=10,
                  rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=1, min_dist=8, max_dist=99999,
                                                        seed=1), schedule_generator=complex_schedule_generator(),
                  number_of_agents=3, obs_builder_object=custom_obs_builder)
    obs, info = env.reset()
    env_renderer = RenderTool(env, gl="PILSVG")

    # We render the initial step and show the observed cells as colored boxes
    env_renderer.render_env(show=True, frames=True, show_observations=True, show_predictions=False)

    # Run 100 steps; in every step each agent gets a random integer action drawn from [0, 5)
    action_dict = {}
    for step in range(100):
        for a in range(env.get_num_agents()):
            action = np.random.randint(0, 5)
            action_dict[a] = action
        obs, all_rewards, done, _ = env.step(action_dict)
        print("Rewards: ", all_rewards, "  [done=", done, "]")
        # Re-render after every step so the observed cells are shown for the new state
        env_renderer.render_env(show=True, frames=True, show_observations=True, show_predictions=False)

if __name__ == '__main__':
    if 'argv' in globals():