Commit 4f997269 authored by u214892's avatar u214892
Browse files

helper methods for valid directions

parent 96b8f258
Pipeline #2185 passed with stages
in 63 minutes and 29 seconds
from enum import IntEnum
from typing import Type
from typing import Type, List
import numpy as np
......@@ -238,5 +238,6 @@ class Grid4Transitions(Transitions):
cell_transition &= cell_transition & (~self.maskDeadEnds) & 0xffff
return cell_transition
def get_entry_directions(self, cell_transition):
@staticmethod
def get_entry_directions(cell_transition) -> List[int]:
return [(cell_transition >> ((3 - orientation) * 4)) & 15 > 0 for orientation in range(4)]
......@@ -14,7 +14,7 @@ from flatland.core.grid.rail_env_grid import RailEnvTransitions
from flatland.core.transitions import Transitions
from flatland.utils.ordered_set import OrderedSet
# TODO are these general classes or for grid4 only?
class TransitionMap:
"""
Base TransitionMap class.
......
......@@ -12,7 +12,7 @@ import numpy as np
from flatland.core.env import Environment
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.core.grid.grid4 import Grid4TransitionsEnum, Grid4Transitions
from flatland.core.grid.grid4_utils import get_new_position
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent
......@@ -592,6 +592,9 @@ class RailEnv(Environment):
self.obs_dict = self.obs_builder.get_many(list(range(self.get_num_agents())))
return self.obs_dict
def get_valid_directions_on_grid(self, row: int, col: int) -> List[int]:
return Grid4Transitions.get_entry_directions(self.rail.get_full_transitions(row,col))
def get_full_state_msg(self):
grid_data = self.rail.grid.tolist()
agent_static_data = [agent.to_list() for agent in self.agents_static]
......
......@@ -174,7 +174,7 @@ def test_get_entry_directions():
north_west_turn = transitions.rotate_transition(south_east_turn, 180)
def _assert(transition, expected):
actual = transitions.get_entry_directions(transition)
actual = Grid4Transitions.get_entry_directions(transition)
assert actual == expected, "Found {}, expected {}.".format(actual, expected)
_assert(south_east_turn, [True, False, False, True])
......
......@@ -6,11 +6,13 @@ from flatland.core.grid.rail_env_grid import RailEnvTransitions
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.agent_utils import EnvAgentStatic
from flatland.envs.observations import GlobalObsForRailEnv
from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import complex_rail_generator
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.schedule_generators import random_schedule_generator, complex_schedule_generator
from flatland.utils.simple_rail import make_simple_rail
"""Tests for `flatland` package."""
......@@ -212,3 +214,36 @@ def test_dead_end():
rail_env.reset()
rail_env.agents = [EnvAgent(position=(2, 0), direction=0, target=(4, 0), moving=False)]
def test_get_entry_directions():
rail, rail_map = make_simple_rail()
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
)
def _assert(position, expected):
actual = env.get_valid_directions_on_grid(*position)
assert actual == expected, "[{},{}] actual={}, expected={}".format(*position, actual, expected)
# north dead end
_assert((0, 3), [True, False, False, False])
# west dead end
_assert((3, 0), [False, False, False, True])
# switch
_assert((3, 3), [False, True, True, True])
# horizontal
_assert((3, 2), [False, True, False, True])
# vertical
_assert((2, 3), [True, False, True, False])
# nowhere
_assert((0, 0), [False, False, False, False])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment