From 9890abf53056a3b117d103a9aa7bddb0a31cebea Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Thu, 13 Jun 2019 23:25:22 +0200
Subject: [PATCH] added new prediction which follows shortest path updated
 test to handle new predictor

---
 examples/training_example.py         |   4 +-
 flatland/envs/observations.py        |   2 +-
 flatland/envs/predictions.py         | 111 ++++++++++++++++++++++++++-
 tests/test_env_prediction_builder.py |   2 +-
 4 files changed, 114 insertions(+), 5 deletions(-)

diff --git a/examples/training_example.py b/examples/training_example.py
index c785804..78336b9 100644
--- a/examples/training_example.py
+++ b/examples/training_example.py
@@ -2,7 +2,7 @@ import numpy as np
 
 from flatland.envs.generators import complex_rail_generator
 from flatland.envs.observations import TreeObsForRailEnv
-from flatland.envs.predictions import DummyPredictorForRailEnv
+from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
 
 np.random.seed(1)
@@ -11,7 +11,7 @@ np.random.seed(1)
 # Training on simple small tasks is the best way to get familiar with the environment
 #
 
-TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv())
+TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
 env = RailEnv(width=20,
               height=20,
               rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0),
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 4fa0e09..885a966 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -177,7 +177,7 @@ class TreeObsForRailEnv(ObservationBuilder):
         if self.predictor:
             self.predicted_pos = {}
             self.predicted_dir = {}
-            self.predictions = self.predictor.get()
+            self.predictions = self.predictor.get(self.distance_map)
 
             for t in range(len(self.predictions[0])):
                 pos_list = []
diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py
index f7dec07..988594e 100644
--- a/flatland/envs/predictions.py
+++ b/flatland/envs/predictions.py
@@ -16,7 +16,7 @@ class DummyPredictorForRailEnv(PredictionBuilder):
     The prediction acts as if no other agent is in the environment and always takes the forward action.
     """
 
-    def get(self, handle=None):
+    def get(self, distancemap=None, handle=None):
         """
         Called whenever predict is called on the environment.
 
@@ -72,3 +72,112 @@ class DummyPredictorForRailEnv(PredictionBuilder):
             agent.position = _agent_initial_position
             agent.direction = _agent_initial_direction
         return prediction_dict
+
+
+class ShortestPathPredictorForRailEnv(PredictionBuilder):
+    """
+    ShortestPathPredictorForRailEnv object.
+
+    This object returns predictions for agents in the RailEnv environment.
+    The prediction acts as if no other agent is in the environment and always follows the shortest path
+    to the agent's target, as given by the distance map.
+    """
+
+    def get(self, distancemap, handle=None):
+        """
+        Called whenever predict is called on the environment.
+
+        Parameters
+        -------
+        distancemap : array-like
+            Distances to the target, indexed as [agent handle, row, column, direction].
+        handle : int (optional)
+            Handle of the agent for which to compute the prediction.
+            If None, predictions are computed for all agents.
+
+        Returns
+        -------
+        dict
+            Returns a dictionary indexed by the agent handle and for each agent an array with one row per
+            time step (max_depth + 1 rows), each row holding 5 elements:
+            - time_offset
+            - position axis 0
+            - position axis 1
+            - direction
+            - action taken to come here
+
+        Raises
+        -------
+        Exception
+            If an agent that has not reached its target has no possible transition.
+        """
+        agents = self.env.agents
+        # compare against None: handle 0 is a valid agent handle
+        if handle is not None:
+            agents = [self.env.agents[handle]]
+
+        prediction_dict = {}
+        for agent in agents:
+            _agent_initial_position = agent.position
+            _agent_initial_direction = agent.direction
+            prediction = np.zeros(shape=(self.max_depth + 1, 5))
+            prediction[0] = [0, _agent_initial_position[0], _agent_initial_position[1], _agent_initial_direction, 0]
+            try:
+                for index in range(1, self.max_depth + 1):
+                    # if we're at the target, stop moving...
+                    if agent.position == agent.target:
+                        prediction[index] = [index, agent.target[0], agent.target[1], agent.direction,
+                                             RailEnvActions.STOP_MOVING]
+                        continue
+
+                    # Take shortest possible path
+                    new_direction = self._shortest_path_direction(distancemap, agent)
+                    new_position = self._new_position(agent.position, new_direction)
+                    agent.position = new_position
+                    agent.direction = new_direction
+                    prediction[index] = [index, new_position[0], new_position[1], new_direction, 0]
+            finally:
+                # the agent is only moved to simulate its path: always put it back
+                agent.position = _agent_initial_position
+                agent.direction = _agent_initial_direction
+            prediction_dict[agent.handle] = prediction
+
+        return prediction_dict
+
+    def _shortest_path_direction(self, distancemap, agent):
+        """
+        Returns the direction of the possible transition that brings the agent closest to its target.
+        """
+        cell_transitions = self.env.rail.get_transitions((*agent.position, agent.direction))
+        if np.sum(cell_transitions) == 1:
+            # only one way to go: no need to consult the distance map
+            return np.argmax(cell_transitions)
+
+        new_direction = None
+        min_dist = np.inf
+        for direction in range(4):
+            if not cell_transitions[direction]:
+                continue
+            # distance to the target once the agent has moved to the neighbouring cell
+            neighbour = self._new_position(agent.position, direction)
+            target_dist = distancemap[agent.handle, neighbour[0], neighbour[1], direction]
+            # keep the first possible direction even if its distance is infinite
+            if new_direction is None or target_dist < min_dist:
+                min_dist = target_dist
+                new_direction = direction
+        if new_direction is None:
+            raise Exception("Cannot move further. Something is wrong")
+        return new_direction
+
+    def _new_position(self, position, movement):
+        """
+        Utility function that converts a compass movement over a 2D grid to new positions (r, c).
+        """
+        if movement == 0:  # NORTH
+            return (position[0] - 1, position[1])
+        elif movement == 1:  # EAST
+            return (position[0], position[1] + 1)
+        elif movement == 2:  # SOUTH
+            return (position[0] + 1, position[1])
+        elif movement == 3:  # WEST
+            return (position[0], position[1] - 1)
diff --git a/tests/test_env_prediction_builder.py b/tests/test_env_prediction_builder.py
index cb7d26d..5f5cea3 100644
--- a/tests/test_env_prediction_builder.py
+++ b/tests/test_env_prediction_builder.py
@@ -74,7 +74,7 @@ def test_predictions():
     env.agents[0].direction = 0
     env.agents[0].target = (3., 0.)
 
-    predictions = env.obs_builder.predictor.get()
+    predictions = env.obs_builder.predictor.get(None)
     positions = np.array(list(map(lambda prediction: [prediction[1], prediction[2]], predictions[0])))
     directions = np.array(list(map(lambda prediction: [prediction[3]], predictions[0])))
     time_offsets = np.array(list(map(lambda prediction: [prediction[0]], predictions[0])))
-- 
GitLab