From 960361f1434461bfc3e5f443157d74ed2673ee1f Mon Sep 17 00:00:00 2001 From: u214892 <u214892@sbb.ch> Date: Tue, 4 Jun 2019 11:38:20 +0200 Subject: [PATCH] 25 predictor draft --- flatland/core/env_prediction_builder.py | 11 ++- flatland/envs/predictions.py | 61 ++++++++++++-- flatland/envs/rail_env.py | 88 ++++++++++---------- tests/test_env_prediction_builder.py | 102 +++++++++++++++++++++++- 4 files changed, 201 insertions(+), 61 deletions(-) diff --git a/flatland/core/env_prediction_builder.py b/flatland/core/env_prediction_builder.py index 321ede95..9f5e4dc5 100644 --- a/flatland/core/env_prediction_builder.py +++ b/flatland/core/env_prediction_builder.py @@ -15,8 +15,8 @@ class PredictionBuilder: PredictionBuilder base class. """ - def __init__(self): - pass + def __init__(self, max_depth: int = 20): + self.max_depth = max_depth def _set_env(self, env): self.env = env @@ -25,12 +25,11 @@ class PredictionBuilder: """ Called after each environment reset. """ - raise NotImplementedError() + pass def get(self, handle=0): """ - Called whenever an observation has to be computed for the `env' environment, possibly - for each agent independently (agent id `handle'). + Called whenever step_prediction is called on the environment. Parameters ------- @@ -40,6 +39,6 @@ class PredictionBuilder: Returns ------- function - An prediction structure, specific to the corresponding environment. + A prediction structure, specific to the corresponding environment. """ raise NotImplementedError() diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py index e57bf753..95c1a984 100644 --- a/flatland/envs/predictions.py +++ b/flatland/envs/predictions.py @@ -2,6 +2,8 @@ Collection of environment-specific PredictionBuilder. """ +import numpy as np + from flatland.core.env_prediction_builder import PredictionBuilder @@ -13,11 +15,58 @@ class DummyPredictorForRailEnv(PredictionBuilder): The prediction acts as if no other agent is in the environment and always takes the forward action. """ - def __init__(self): - pass + def get(self, handle=None): + """ + Called whenever step_prediction is called on the environment. + + Parameters + ------- + handle : int (optional) + Handle of the agent for which to compute the observation vector. + + Returns + ------- + function + Returns a dictionary index by the agent handle and for each agent a vector of 5 elements: + - time_offset + - position axis 0 + - position axis 1 + - direction + - action taken to come here + """ + agents = self.env.agents + if handle: + agents = [self.env.agents[handle]] + + prediction_dict = {} - def reset(self): - pass + for agent in agents: - def get(self, handle=0): - return {} + # 0: do nothing + # 1: turn left and move to the next cell + # 2: move to the next cell in front of the agent + # 3: turn right and move to the next cell + action_priorities = [2, 1, 3] + _agent_initial_position = agent.position + _agent_initial_direction = agent.direction + prediction = np.zeros(shape=(self.max_depth, 5)) + prediction[0] = [0, _agent_initial_position[0], _agent_initial_position[1], _agent_initial_direction, 0] + for index in range(1, self.max_depth): + action_done = False + for action in action_priorities: + cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = self.env._check_action_on_agent(action, + agent) + if all([new_cell_isValid, transition_isValid]): + # move and change direction to face the new_direction that was + # performed + agent.position = new_position + agent.direction = new_direction + prediction[index] = [index, new_position[0], new_position[1], new_direction, action] + action_done = True + break + if not action_done: + print("Cannot move further.") + prediction_dict[agent.handle] = prediction + agent.position = _agent_initial_position + agent.direction = _agent_initial_direction + return prediction_dict diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 6ac31878..190a14ea 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -219,51 +219,9 @@ class RailEnv(Environment): return if action > 0: - # pos = agent.position # self.agents_position[i] - # direction = agent.direction # self.agents_direction[i] - - # compute number of possible transitions in the current - # cell used to check for invalid actions - - new_direction, transition_isValid = self.check_action(agent, action) - - new_position = get_new_position(agent.position, new_direction) - # Is it a legal move? - # 1) transition allows the new_direction in the cell, - # 2) the new cell is not empty (case 0), - # 3) the cell is free, i.e., no agent is currently in that cell - - # if ( - # new_position[1] >= self.width or - # new_position[0] >= self.height or - # new_position[0] < 0 or new_position[1] < 0): - # new_cell_isValid = False - - # if self.rail.get_transitions(new_position) == 0: - # new_cell_isValid = False - - new_cell_isValid = ( - np.array_equal( # Check the new position is still in the grid - new_position, - np.clip(new_position, [0, 0], [self.height - 1, self.width - 1])) - and # check the new position has some transitions (ie is not an empty cell) - self.rail.get_transitions(new_position) > 0) - - # If transition validity hasn't been checked yet. - if transition_isValid is None: - transition_isValid = self.rail.get_transition( - (*agent.position, agent.direction), - new_direction) - - # cell_isFree = True - # for j in range(self.number_of_agents): - # if self.agents_position[j] == new_position: - # cell_isFree = False - # break - # Check the new position is not the same as any of the existing agent positions - # (including itself, for simplicity, since it is moving) - cell_isFree = not np.any( - np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1)) + cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = self._check_action_on_agent(action, + agent, + transition_isValid) if all([new_cell_isValid, transition_isValid, cell_isFree]): # move and change direction to face the new_direction that was @@ -303,6 +261,46 @@ class RailEnv(Environment): self.actions = [0] * self.get_num_agents() return self._get_observations(), self.rewards_dict, self.dones, {} + def _check_action_on_agent(self, action, agent): + # pos = agent.position # self.agents_position[i] + # direction = agent.direction # self.agents_direction[i] + # compute number of possible transitions in the current + # cell used to check for invalid actions + new_direction, transition_isValid = self.check_action(agent, action) + new_position = get_new_position(agent.position, new_direction) + # Is it a legal move? + # 1) transition allows the new_direction in the cell, + # 2) the new cell is not empty (case 0), + # 3) the cell is free, i.e., no agent is currently in that cell + # if ( + # new_position[1] >= self.width or + # new_position[0] >= self.height or + # new_position[0] < 0 or new_position[1] < 0): + # new_cell_isValid = False + # if self.rail.get_transitions(new_position) == 0: + # new_cell_isValid = False + new_cell_isValid = ( + np.array_equal( # Check the new position is still in the grid + new_position, + np.clip(new_position, [0, 0], [self.height - 1, self.width - 1])) + and # check the new position has some transitions (ie is not an empty cell) + self.rail.get_transitions(new_position) > 0) + # If transition validity hasn't been checked yet. + if transition_isValid is None: + transition_isValid = self.rail.get_transition( + (*agent.position, agent.direction), + new_direction) + # cell_isFree = True + # for j in range(self.number_of_agents): + # if self.agents_position[j] == new_position: + # cell_isFree = False + # break + # Check the new position is not the same as any of the existing agent positions + # (including itself, for simplicity, since it is moving) + cell_isFree = not np.any( + np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1)) + return cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid + def predict(self): if not self.prediction_builder: return {} diff --git a/tests/test_env_prediction_builder.py b/tests/test_env_prediction_builder.py index a43f7883..35a6a27b 100644 --- a/tests/test_env_prediction_builder.py +++ b/tests/test_env_prediction_builder.py @@ -3,13 +3,13 @@ import numpy as np -from flatland.envs.observations import GlobalObsForRailEnv from flatland.core.transition_map import GridTransitionMap, Grid4Transitions +from flatland.envs.generators import rail_from_GridTransitionMap_generator +from flatland.envs.observations import GlobalObsForRailEnv from flatland.envs.predictions import DummyPredictorForRailEnv from flatland.envs.rail_env import RailEnv -from flatland.envs.generators import rail_from_GridTransitionMap_generator -"""Tests for `flatland` package.""" +"""Test predictions for `flatland` package.""" def test_predictions(): @@ -65,12 +65,106 @@ def test_predictions(): rail_generator=rail_from_GridTransitionMap_generator(rail), number_of_agents=1, obs_builder_object=GlobalObsForRailEnv(), - prediction_builder_object=DummyPredictorForRailEnv() + prediction_builder_object=DummyPredictorForRailEnv(max_depth=20) ) env.reset() + # set initial position and direction for testing... + env.agents[0].position = (5, 6) + env.agents[0].direction = 0 + predictions = env.predict() + positions = np.array(list(map(lambda prediction: [prediction[1], prediction[2]], predictions[0]))) + directions = np.array(list(map(lambda prediction: [prediction[3]], predictions[0]))) + time_offsets = np.array(list(map(lambda prediction: [prediction[0]], predictions[0]))) + actions = np.array(list(map(lambda prediction: [prediction[4]], predictions[0]))) + + # compare against expected values + expected_positions = np.array([[5., 6.], + [4., 6.], + [3., 6.], + [3., 5.], + [3., 4.], + [3., 3.], + [3., 2.], + [3., 1.], + [3., 0.], + [3., 1.], + [3., 2.], + [3., 3.], + [3., 4.], + [3., 5.], + [3., 6.], + [3., 7.], + [3., 8.], + [3., 9.], + [3., 8.], + [3., 7.]]) + expected_directions = np.array([[0.], + [0.], + [0.], + [3.], + [3.], + [3.], + [3.], + [3.], + [3.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [3.], + [3.]]) + expected_time_offsets = np.array([[0.], + [1.], + [2.], + [3.], + [4.], + [5.], + [6.], + [7.], + [8.], + [9.], + [10.], + [11.], + [12.], + [13.], + [14.], + [15.], + [16.], + [17.], + [18.], + [19.]]) + expected_actions = np.array([[0.], + [2.], + [2.], + [1.], + [2.], + [2.], + [2.], + [2.], + [2.], + [2.], + [2.], + [2.], + [2.], + [2.], + [2.], + [2.], + [2.], + [2.], + [2.], + [2.]]) + assert np.array_equal(positions, expected_positions) + assert np.array_equal(directions, expected_directions) + assert np.array_equal(time_offsets, expected_time_offsets) + assert np.array_equal(actions, expected_actions) def main(): -- GitLab