diff --git a/flatland/core/env.py b/flatland/core/env.py index 3c25beea236945b1728959e02ea07f6c0ba7a6ac..32691f507f4cb5586f10b5645cc22ece718edc21 100644 --- a/flatland/core/env.py +++ b/flatland/core/env.py @@ -84,6 +84,21 @@ class Environment: """ raise NotImplementedError() + def predict(self): + """ + Prediction step. + + Returns predictions for the agents. + The return value is a dict mapping from agent_id strings to values. + + Returns + ------- + predictions : dict + New predictions for each ready agent. + + """ + raise NotImplementedError() + def render(self): """ Perform rendering of the environment. diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py index 3cef545c1658e6bfe2a292ee26c3e665ce6a5abc..f85afee4b625e59374c6cce266bf55b21e7fdb84 100644 --- a/flatland/core/env_observation_builder.py +++ b/flatland/core/env_observation_builder.py @@ -19,7 +19,6 @@ class ObservationBuilder: def __init__(self): self.observation_space = () - pass def _set_env(self, env): self.env = env diff --git a/flatland/core/env_prediction_builder.py b/flatland/core/env_prediction_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..321ede959e94ced2cde17f99a02796d0d2abe277 --- /dev/null +++ b/flatland/core/env_prediction_builder.py @@ -0,0 +1,45 @@ +""" +PredictionBuilder objects are objects that can be passed to environments designed for customizability. +The PredictionBuilder-derived custom classes implement 2 functions, reset() and get([handle]). +If predictions are not required in every step or not for all agents, then + ++ reset() is called after each environment reset, to allow for pre-computing relevant data. + ++ get() is called whenever a prediction has to be computed, potentially for each agent independently in +case of multi-agent environments. +""" + + +class PredictionBuilder: + """ + PredictionBuilder base class. 
+ """ + + def __init__(self): + pass + + def _set_env(self, env): + self.env = env + + def reset(self): + """ + Called after each environment reset. + """ + raise NotImplementedError() + + def get(self, handle=0): + """ + Called whenever a prediction has to be computed for the `env' environment, possibly + for each agent independently (agent id `handle'). + + Parameters + ---------- + handle : int (optional) + Handle of the agent for which to compute the prediction. + + Returns + ------- + function + A prediction structure, specific to the corresponding environment. + """ + raise NotImplementedError() diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py new file mode 100644 index 0000000000000000000000000000000000000000..e57bf7534fa92576978416e72b3cecb851a746b8 --- /dev/null +++ b/flatland/envs/predictions.py @@ -0,0 +1,23 @@ +""" +Collection of environment-specific PredictionBuilder classes. +""" + +from flatland.core.env_prediction_builder import PredictionBuilder + + +class DummyPredictorForRailEnv(PredictionBuilder): + """ + DummyPredictorForRailEnv object. + + This object returns predictions for agents in the RailEnv environment. + The prediction acts as if no other agent is in the environment and always takes the forward action. + """ + + def __init__(self): + pass + + def reset(self): + pass + + def get(self, handle=0): + return {} # NOTE(review): stub; always returns an empty dict, unlike the class docstring diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index a23b6ac71651e8e424bb90a23cbcfd472ce89a12..6ac31878ff4c793d02ab78ec31e71285bd41e84d 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -51,7 +51,9 @@ class RailEnv(Environment): height, rail_generator=random_rail_generator(), number_of_agents=1, - obs_builder_object=TreeObsForRailEnv(max_depth=2)): + obs_builder_object=TreeObsForRailEnv(max_depth=2), + prediction_builder_object=None + ): """ Environment init. 
@@ -94,6 +96,11 @@ class RailEnv(Environment): self.obs_builder = obs_builder_object self.obs_builder._set_env(self) + self.prediction_builder = prediction_builder_object + if self.prediction_builder: # the predictor is optional (defaults to None) + self.prediction_builder._set_env(self) + + self.action_space = [1] self.observation_space = self.obs_builder.observation_space # updated on resets? @@ -296,6 +303,12 @@ class RailEnv(Environment): self.actions = [0] * self.get_num_agents() return self._get_observations(), self.rewards_dict, self.dones, {} + def predict(self): + if not self.prediction_builder: # no predictor configured: nothing to predict + return {} + return self.prediction_builder.get() + + def check_action(self, agent, action): transition_isValid = None possible_transitions = self.rail.get_transitions((*agent.position, agent.direction)) @@ -332,6 +345,11 @@ class RailEnv(Environment): self.obs_dict[iAgent] = self.obs_builder.get(iAgent) return self.obs_dict + def _get_predictions(self): + if not self.prediction_builder: + return {} + return {} # NOTE(review): stub; returns {} even when a prediction builder is set + def render(self): # TODO: pass diff --git a/tests/test_env_prediction_builder.py b/tests/test_env_prediction_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..a43f7883549474b34377c176d293910be4e0a635 --- /dev/null +++ b/tests/test_env_prediction_builder.py @@ -0,0 +1,81 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import numpy as np + +from flatland.envs.observations import GlobalObsForRailEnv +from flatland.core.transition_map import GridTransitionMap, Grid4Transitions +from flatland.envs.predictions import DummyPredictorForRailEnv +from flatland.envs.rail_env import RailEnv +from flatland.envs.generators import rail_from_GridTransitionMap_generator + +"""Tests for `flatland` package.""" + + +def test_predictions(): + # We instantiate a very simple rail network on a 7x10 grid: + # | + # | + # | + # _ _ _ /_\ _ _ _ _ _ _ + # \ / + # | + # | + # | + + cells = [int('0000000000000000', 2), # empty cell - Case 0 + int('1000000000100000', 2), # Case 1 - straight + 
int('1001001000100000', 2), # Case 2 - simple switch + int('1000010000100001', 2), # Case 3 - diamond crossing + int('1001011000100001', 2), # Case 4 - single slip switch + int('1100110000110011', 2), # Case 5 - double slip switch + int('0101001000000010', 2), # Case 6 - symmetrical switch + int('0010000000000000', 2)] # Case 7 - dead end + + transitions = Grid4Transitions([]) + empty = cells[0] + + dead_end_from_south = cells[7] + dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90) + dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180) + dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270) + + vertical_straight = cells[1] + horizontal_straight = transitions.rotate_transition(vertical_straight, 90) + + double_switch_south_horizontal_straight = horizontal_straight + cells[6] + double_switch_north_horizontal_straight = transitions.rotate_transition( + double_switch_south_horizontal_straight, 180) + + rail_map = np.array( + [[empty] * 3 + [dead_end_from_south] + [empty] * 6] + + [[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 + + [[dead_end_from_east] + [horizontal_straight] * 2 + + [double_switch_north_horizontal_straight] + + [horizontal_straight] * 2 + [double_switch_south_horizontal_straight] + + [horizontal_straight] * 2 + [dead_end_from_west]] + + [[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 + + [[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16) + + rail = GridTransitionMap(width=rail_map.shape[1], + height=rail_map.shape[0], transitions=transitions) + rail.grid = rail_map + env = RailEnv(width=rail_map.shape[1], + height=rail_map.shape[0], + rail_generator=rail_from_GridTransitionMap_generator(rail), + number_of_agents=1, + obs_builder_object=GlobalObsForRailEnv(), + prediction_builder_object=DummyPredictorForRailEnv() + ) + + env.reset() + + predictions = env.predict() # NOTE(review): result is never asserted; smoke test only + + +def main(): + test_predictions() + + +if __name__ == "__main__": + main()