diff --git a/flatland/core/grid/grid4.py b/flatland/core/grid/grid4.py index 2d7284dc74cf2411ee2f0466d0d653bffb2132dc..819fc8da9102b688138cb06f59833e718188a411 100644 --- a/flatland/core/grid/grid4.py +++ b/flatland/core/grid/grid4.py @@ -1,5 +1,5 @@ from enum import IntEnum -from typing import Type +from typing import Type, List import numpy as np @@ -238,5 +238,6 @@ class Grid4Transitions(Transitions): cell_transition &= cell_transition & (~self.maskDeadEnds) & 0xffff return cell_transition - def get_entry_directions(self, cell_transition): + @staticmethod + def get_entry_directions(cell_transition) -> List[int]: return [(cell_transition >> ((3 - orientation) * 4)) & 15 > 0 for orientation in range(4)] diff --git a/flatland/core/transition_map.py b/flatland/core/transition_map.py index 995e7d67ca87f306be7c87dbba685f46cde240d2..cf77cb580f2af7e4657140517e58825902be47e8 100644 --- a/flatland/core/transition_map.py +++ b/flatland/core/transition_map.py @@ -14,7 +14,7 @@ from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.core.transitions import Transitions from flatland.utils.ordered_set import OrderedSet - +# TODO are these general classes or for grid4 only? class TransitionMap: """ Base TransitionMap class. diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 901f318c0add611c4d0ebb0846c7a4c0c30f39dd..887f97aae7c74af8d986454f0c50dcaace59f0dc 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -12,7 +12,7 @@ import numpy as np from flatland.core.env import Environment from flatland.core.env_observation_builder import ObservationBuilder -from flatland.core.grid.grid4 import Grid4TransitionsEnum +from flatland.core.grid.grid4 import Grid4TransitionsEnum, Grid4Transitions from flatland.core.grid.grid4_utils import get_new_position from flatland.core.transition_map import GridTransitionMap from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent @@ -593,6 +593,9 @@ class RailEnv(Environment): self.obs_dict = self.obs_builder.get_many(list(range(self.get_num_agents()))) return self.obs_dict + def get_valid_directions_on_grid(self, row: int, col: int) -> List[int]: + return Grid4Transitions.get_entry_directions(self.rail.get_full_transitions(row,col)) + def get_full_state_msg(self): grid_data = self.rail.grid.tolist() agent_static_data = [agent.to_list() for agent in self.agents_static] diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index cc516de2174d262fdc5c5086b01921c183323031..0ec268dc1d3910967541ba0bcc357c8cc4638ec4 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -896,25 +896,20 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, Function to fix all transition elements in environment """ # Fix all nodes with illegal transition maps - empty_to_fix = [] - rails_to_fix = [] + rails_to_fix = np.zeros(2 * grid_map.height * grid_map.width * 2, dtype='int') + rails_to_fix_cnt = 0 for cell in city_cells: check = grid_map.cell_neighbours_valid(cell, True) if grid_map.grid[cell] == int('1000010000100001', 2): grid_map.fix_transitions(cell) if not check: - if grid_map.grid[cell] == 0: - empty_to_fix.append(cell) - else: - rails_to_fix.append(cell) - - # Fix empty cells first to avoid cutting the network - for cell in empty_to_fix: - grid_map.fix_transitions(cell) + rails_to_fix[2 * rails_to_fix_cnt] = cell[0] + rails_to_fix[2 * rails_to_fix_cnt + 1] = cell[1] + rails_to_fix_cnt += 1 # Fix all other cells - for cell in rails_to_fix: - grid_map.fix_transitions(cell) + for cell in range(rails_to_fix_cnt): + grid_map.fix_transitions((rails_to_fix[2 * cell], rails_to_fix[2 * cell + 1])) def _closest_neigh_in_direction(current_node, node_positions): """ diff --git a/tests/test_flatland_core_transition_map.py b/tests/test_flatland_core_transition_map.py index 930cc24fb4a9be817c14f2cca149747ac6ca370b..0913e45959d08230a815c33d98fb6de8eb99d956 100644 --- a/tests/test_flatland_core_transition_map.py +++ b/tests/test_flatland_core_transition_map.py @@ -174,7 +174,7 @@ def test_get_entry_directions(): north_west_turn = transitions.rotate_transition(south_east_turn, 180) def _assert(transition, expected): - actual = transitions.get_entry_directions(transition) + actual = Grid4Transitions.get_entry_directions(transition) assert actual == expected, "Found {}, expected {}.".format(actual, expected) _assert(south_east_turn, [True, False, False, True]) diff --git a/tests/test_flatland_envs_rail_env.py b/tests/test_flatland_envs_rail_env.py index 0114730a2ac1d0df618eea773dfcf1cd7175dee2..0fefd3e212ddb5f084c1e219f4063079e03dabdf 100644 --- a/tests/test_flatland_envs_rail_env.py +++ b/tests/test_flatland_envs_rail_env.py @@ -6,11 +6,13 @@ from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.core.transition_map import GridTransitionMap from flatland.envs.agent_utils import EnvAgent from flatland.envs.agent_utils import EnvAgentStatic -from flatland.envs.observations import GlobalObsForRailEnv +from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv +from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import complex_rail_generator from flatland.envs.rail_generators import rail_from_grid_transition_map from flatland.envs.schedule_generators import random_schedule_generator, complex_schedule_generator +from flatland.utils.simple_rail import make_simple_rail """Tests for `flatland` package.""" @@ -212,3 +214,36 @@ def test_dead_end(): rail_env.reset() rail_env.agents = [EnvAgent(position=(2, 0), direction=0, target=(4, 0), moving=False)] + + +def test_get_entry_directions(): + rail, rail_map = make_simple_rail() + env = RailEnv(width=rail_map.shape[1], + height=rail_map.shape[0], + rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), + number_of_agents=1, + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), + ) + + def _assert(position, expected): + actual = env.get_valid_directions_on_grid(*position) + assert actual == expected, "[{},{}] actual={}, expected={}".format(*position, actual, expected) + + # north dead end + _assert((0, 3), [True, False, False, False]) + + # west dead end + _assert((3, 0), [False, False, False, True]) + + # switch + _assert((3, 3), [False, True, True, True]) + + # horizontal + _assert((3, 2), [False, True, False, True]) + + # vertical + _assert((2, 3), [True, False, True, False]) + + # nowhere + _assert((0, 0), [False, False, False, False])