From 4f99726971c27388055aa9481850f4a4a8177727 Mon Sep 17 00:00:00 2001 From: u214892 <u214892@sbb.ch> Date: Thu, 26 Sep 2019 16:56:00 +0200 Subject: [PATCH] helper methods for valid directions --- flatland/core/grid/grid4.py | 5 +-- flatland/core/transition_map.py | 2 +- flatland/envs/rail_env.py | 5 ++- tests/test_flatland_core_transition_map.py | 2 +- tests/test_flatland_envs_rail_env.py | 37 +++++++++++++++++++++- 5 files changed, 45 insertions(+), 6 deletions(-) diff --git a/flatland/core/grid/grid4.py b/flatland/core/grid/grid4.py index 2d7284dc..819fc8da 100644 --- a/flatland/core/grid/grid4.py +++ b/flatland/core/grid/grid4.py @@ -1,5 +1,5 @@ from enum import IntEnum -from typing import Type +from typing import Type, List import numpy as np @@ -238,5 +238,6 @@ class Grid4Transitions(Transitions): cell_transition &= cell_transition & (~self.maskDeadEnds) & 0xffff return cell_transition - def get_entry_directions(self, cell_transition): + @staticmethod + def get_entry_directions(cell_transition) -> List[int]: return [(cell_transition >> ((3 - orientation) * 4)) & 15 > 0 for orientation in range(4)] diff --git a/flatland/core/transition_map.py b/flatland/core/transition_map.py index 07678add..9db7f3c7 100644 --- a/flatland/core/transition_map.py +++ b/flatland/core/transition_map.py @@ -14,7 +14,7 @@ from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.core.transitions import Transitions from flatland.utils.ordered_set import OrderedSet - +# TODO are these general classes or for grid4 only? class TransitionMap: """ Base TransitionMap class. diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index b5bc44f2..6e021183 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -12,7 +12,7 @@ import numpy as np from flatland.core.env import Environment from flatland.core.env_observation_builder import ObservationBuilder -from flatland.core.grid.grid4 import Grid4TransitionsEnum +from flatland.core.grid.grid4 import Grid4TransitionsEnum, Grid4Transitions from flatland.core.grid.grid4_utils import get_new_position from flatland.core.transition_map import GridTransitionMap from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent @@ -592,6 +592,9 @@ class RailEnv(Environment): self.obs_dict = self.obs_builder.get_many(list(range(self.get_num_agents()))) return self.obs_dict + def get_valid_directions_on_grid(self, row: int, col: int) -> List[int]: + return Grid4Transitions.get_entry_directions(self.rail.get_full_transitions(row,col)) + def get_full_state_msg(self): grid_data = self.rail.grid.tolist() agent_static_data = [agent.to_list() for agent in self.agents_static] diff --git a/tests/test_flatland_core_transition_map.py b/tests/test_flatland_core_transition_map.py index 930cc24f..0913e459 100644 --- a/tests/test_flatland_core_transition_map.py +++ b/tests/test_flatland_core_transition_map.py @@ -174,7 +174,7 @@ def test_get_entry_directions(): north_west_turn = transitions.rotate_transition(south_east_turn, 180) def _assert(transition, expected): - actual = transitions.get_entry_directions(transition) + actual = Grid4Transitions.get_entry_directions(transition) assert actual == expected, "Found {}, expected {}.".format(actual, expected) _assert(south_east_turn, [True, False, False, True]) diff --git a/tests/test_flatland_envs_rail_env.py b/tests/test_flatland_envs_rail_env.py index 0114730a..0fefd3e2 100644 --- a/tests/test_flatland_envs_rail_env.py +++ b/tests/test_flatland_envs_rail_env.py @@ -6,11 +6,13 @@ from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.core.transition_map import GridTransitionMap from flatland.envs.agent_utils import EnvAgent from flatland.envs.agent_utils import EnvAgentStatic -from flatland.envs.observations import GlobalObsForRailEnv +from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv +from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import complex_rail_generator from flatland.envs.rail_generators import rail_from_grid_transition_map from flatland.envs.schedule_generators import random_schedule_generator, complex_schedule_generator +from flatland.utils.simple_rail import make_simple_rail """Tests for `flatland` package.""" @@ -212,3 +214,36 @@ def test_dead_end(): rail_env.reset() rail_env.agents = [EnvAgent(position=(2, 0), direction=0, target=(4, 0), moving=False)] + + +def test_get_entry_directions(): + rail, rail_map = make_simple_rail() + env = RailEnv(width=rail_map.shape[1], + height=rail_map.shape[0], + rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), + number_of_agents=1, + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), + ) + + def _assert(position, expected): + actual = env.get_valid_directions_on_grid(*position) + assert actual == expected, "[{},{}] actual={}, expected={}".format(*position, actual, expected) + + # north dead end + _assert((0, 3), [True, False, False, False]) + + # west dead end + _assert((3, 0), [False, False, False, True]) + + # switch + _assert((3, 3), [False, True, True, True]) + + # horizontal + _assert((3, 2), [False, True, False, True]) + + # vertical + _assert((2, 3), [True, False, True, False]) + + # nowhere + _assert((0, 0), [False, False, False, False]) -- GitLab