class PredictionBuilder:
    """
    PredictionBuilder base class.

    An environment that supports predictions attaches itself to the builder
    through `_set_env()`. Derived classes customize `reset()` and `get()`.
    """

    def __init__(self, max_depth: int = 20):
        """
        Parameters
        -------
        max_depth : int (optional)
            Stored as `self.max_depth` for use by derived classes.
        """
        self.max_depth = max_depth
        # Declare the attribute up front: a builder that has not been attached
        # to an environment yet exposes `env is None` instead of lacking the
        # attribute altogether.
        self.env = None

    def _set_env(self, env):
        """
        Attach the environment this builder reads its data from.
        """
        self.env = env

    def reset(self):
        """
        Called after each environment reset.

        The base implementation does nothing.
        """

    def get(self, handle=0):
        """
        Compute a prediction; must be overridden by derived classes.

        Parameters
        -------
        handle : int (optional)
            Handle of the agent for which to compute the prediction.

        Returns
        -------
        object
            A prediction structure, specific to the corresponding environment.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError()
def get(self, handle=None):
    """
    Predict the upcoming cells of one agent or of all agents.

    The environment is queried through `_check_action_on_agent`; of its
    results only the validity of the new cell and of the transition are
    considered, the "cell is free" flag is ignored. At every step the actions
    are tried in the fixed order forward, left, right, and the first valid
    one is taken.

    Parameters
    -------
    handle : int (optional)
        Handle of the agent for which to compute the prediction.
        With `None` (the default) all agents of the environment are predicted.

    Returns
    -------
    dict
        Indexed by agent handle. Each value is an array of shape
        (max_depth, 5) whose rows hold:
        - time_offset
        - position axis 0
        - position axis 1
        - direction
        - action taken to come here
        Row 0 describes the agent's current state. A row for which no action
        was possible is left filled with zeros.
    """
    # A handle of 0 is a legitimate agent index, so test against None rather
    # than relying on truthiness.
    if handle is None:
        agents = self.env.agents
    else:
        agents = [self.env.agents[handle]]

    # 0: do nothing
    # 1: turn left and move to the next cell
    # 2: move to the next cell in front of the agent
    # 3: turn right and move to the next cell
    action_priorities = (2, 1, 3)

    prediction_dict = {}
    for agent in agents:
        initial_position = agent.position
        initial_direction = agent.direction
        prediction = np.zeros(shape=(self.max_depth, 5))
        prediction[0] = [0, initial_position[0], initial_position[1], initial_direction, 0]
        try:
            for index in range(1, self.max_depth):
                for action in action_priorities:
                    (_, new_cell_isValid, new_direction,
                     new_position, transition_isValid) = self.env._check_action_on_agent(action, agent)
                    if new_cell_isValid and transition_isValid:
                        # The rollout temporarily moves the real agent so that
                        # the next query starts from the predicted cell.
                        agent.position = new_position
                        agent.direction = new_direction
                        prediction[index] = [index, new_position[0], new_position[1], new_direction, action]
                        break
                else:
                    # None of the candidate actions was valid at this step.
                    print("Cannot move further.")
        finally:
            # Always put the agent back, even when the environment query
            # raises part-way through the rollout.
            agent.position = initial_position
            agent.direction = initial_direction
        prediction_dict[agent.handle] = prediction
    return prediction_dict
prediction_builder_object=None + ): """ Environment init. @@ -94,6 +96,11 @@ class RailEnv(Environment): self.obs_builder = obs_builder_object self.obs_builder._set_env(self) + self.prediction_builder = prediction_builder_object + if self.prediction_builder: + self.prediction_builder._set_env(self) + + self.action_space = [1] self.observation_space = self.obs_builder.observation_space # updated on resets? @@ -212,51 +219,9 @@ class RailEnv(Environment): return if action > 0: - # pos = agent.position # self.agents_position[i] - # direction = agent.direction # self.agents_direction[i] - - # compute number of possible transitions in the current - # cell used to check for invalid actions - - new_direction, transition_isValid = self.check_action(agent, action) - - new_position = get_new_position(agent.position, new_direction) - # Is it a legal move? - # 1) transition allows the new_direction in the cell, - # 2) the new cell is not empty (case 0), - # 3) the cell is free, i.e., no agent is currently in that cell - - # if ( - # new_position[1] >= self.width or - # new_position[0] >= self.height or - # new_position[0] < 0 or new_position[1] < 0): - # new_cell_isValid = False - - # if self.rail.get_transitions(new_position) == 0: - # new_cell_isValid = False - - new_cell_isValid = ( - np.array_equal( # Check the new position is still in the grid - new_position, - np.clip(new_position, [0, 0], [self.height - 1, self.width - 1])) - and # check the new position has some transitions (ie is not an empty cell) - self.rail.get_transitions(new_position) > 0) - - # If transition validity hasn't been checked yet. 
def _check_action_on_agent(self, action, agent):
    """
    Work out where `action` would take `agent` and whether that move is legal.

    This method itself does not assign to the agent; it only reports.

    Parameters
    -------
    action : int
        Action to evaluate, forwarded to `check_action`.
    agent : object
        Agent whose `position` and `direction` are read.

    Returns
    -------
    tuple
        (cell_isFree, new_cell_isValid, new_direction, new_position,
        transition_isValid), where
        - cell_isFree: no agent of the environment stands on new_position,
        - new_cell_isValid: new_position is inside the grid and the cell has
          at least one transition,
        - transition_isValid: the rail allows leaving the current cell in
          new_direction.
    """
    # NOTE(review): the call site in step() passes a third positional argument
    # (transition_isValid) that this signature does not accept - reconcile the
    # two before relying on step().
    new_direction, transition_isValid = self.check_action(agent, action)
    new_position = get_new_position(agent.position, new_direction)

    # Clipping a position to the grid bounds leaves it unchanged exactly when
    # it already lies inside the grid; only then is the rail queried.
    grid_max = [self.height - 1, self.width - 1]
    inside_grid = np.array_equal(new_position, np.clip(new_position, [0, 0], grid_max))
    new_cell_isValid = inside_grid and self.rail.get_transitions(new_position) > 0

    # check_action may leave the transition check open (None); do it here.
    if transition_isValid is None:
        origin = (*agent.position, agent.direction)
        transition_isValid = self.rail.get_transition(origin, new_direction)

    # Compare against the positions of all agents (including the moving one,
    # for simplicity).
    occupied = np.equal(new_position, [other.position for other in self.agents]).all(1)
    cell_isFree = not np.any(occupied)

    return cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid
def test_predictions():
    """
    Check DummyPredictorForRailEnv on a small hand-built rail network.

    The grid has 7 rows and 10 columns: a horizontal line on row 3 with dead
    ends at both sides, a vertical branch going up from column 3 and a
    vertical branch going down from column 6, each closed by a dead end.
    """
    cells = [int('0000000000000000', 2),  # empty cell - Case 0
             int('1000000000100000', 2),  # Case 1 - straight
             int('1001001000100000', 2),  # Case 2 - simple switch
             int('1000010000100001', 2),  # Case 3 - diamond crossing
             int('1001011000100001', 2),  # Case 4 - single slip switch
             int('1100110000110011', 2),  # Case 5 - double slip switch
             int('0101001000000010', 2),  # Case 6 - symmetrical switch
             int('0010000000000000', 2)]  # Case 7 - dead end

    transitions = Grid4Transitions([])
    rotate = transitions.rotate_transition

    empty = cells[0]
    dead_end_from_south = cells[7]
    dead_end_from_west = rotate(dead_end_from_south, 90)
    dead_end_from_north = rotate(dead_end_from_south, 180)
    dead_end_from_east = rotate(dead_end_from_south, 270)
    vertical_straight = cells[1]
    horizontal_straight = rotate(vertical_straight, 90)
    double_switch_south_horizontal_straight = horizontal_straight + cells[6]
    double_switch_north_horizontal_straight = rotate(double_switch_south_horizontal_straight, 180)

    # One list per distinct grid row, assembled top to bottom below.
    top_row = [empty] * 3 + [dead_end_from_south] + [empty] * 6
    upper_row = [empty] * 3 + [vertical_straight] + [empty] * 6
    middle_row = ([dead_end_from_east] + [horizontal_straight] * 2 +
                  [double_switch_north_horizontal_straight] +
                  [horizontal_straight] * 2 + [double_switch_south_horizontal_straight] +
                  [horizontal_straight] * 2 + [dead_end_from_west])
    lower_row = [empty] * 6 + [vertical_straight] + [empty] * 3
    bottom_row = [empty] * 6 + [dead_end_from_north] + [empty] * 3
    rail_map = np.array(
        [top_row, upper_row, upper_row, middle_row, lower_row, lower_row, bottom_row],
        dtype=np.uint16)

    n_rows, n_cols = rail_map.shape
    rail = GridTransitionMap(width=n_cols, height=n_rows, transitions=transitions)
    rail.grid = rail_map
    env = RailEnv(width=n_cols,
                  height=n_rows,
                  rail_generator=rail_from_GridTransitionMap_generator(rail),
                  number_of_agents=1,
                  obs_builder_object=GlobalObsForRailEnv(),
                  prediction_builder_object=DummyPredictorForRailEnv(max_depth=20)
                  )
    env.reset()

    # set initial position and direction for testing...
    env.agents[0].position = (5, 6)
    env.agents[0].direction = 0

    # Each prediction row is [time_offset, row, column, direction, action].
    rollout = np.array(env.predict()[0])
    time_offsets = rollout[:, 0:1]
    positions = rollout[:, 1:3]
    directions = rollout[:, 3:4]
    actions = rollout[:, 4:5]

    # compare against expected values
    expected_positions = np.array(
        [[5., 6.], [4., 6.], [3., 6.], [3., 5.], [3., 4.],
         [3., 3.], [3., 2.], [3., 1.], [3., 0.], [3., 1.],
         [3., 2.], [3., 3.], [3., 4.], [3., 5.], [3., 6.],
         [3., 7.], [3., 8.], [3., 9.], [3., 8.], [3., 7.]])
    expected_directions = np.array([0.] * 3 + [3.] * 6 + [1.] * 9 + [3.] * 2).reshape(-1, 1)
    expected_time_offsets = np.arange(20.).reshape(-1, 1)
    expected_actions = np.array([0., 2., 2., 1.] + [2.] * 16).reshape(-1, 1)

    assert np.array_equal(positions, expected_positions)
    assert np.array_equal(directions, expected_directions)
    assert np.array_equal(time_offsets, expected_time_offsets)
    assert np.array_equal(actions, expected_actions)