# Forked from Flatland / Flatland
# (973 commits behind the upstream repository.)
# File: test_flatland_core_transition_map.py (7.35 KiB)
from flatland.core.grid.grid4 import Grid4Transitions, Grid4TransitionsEnum
from flatland.core.grid.grid8 import Grid8Transitions, Grid8TransitionsEnum
from flatland.core.grid.rail_env_grid import RailEnvTransitions
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.schedule_generators import random_schedule_generator
from flatland.utils.rendertools import RenderTool
from flatland.utils.simple_rail import make_simple_rail, make_simple_rail_unconnected
def test_grid4_get_transitions():
    """Toggling single transitions on a Grid4 cell updates both read-back views.

    Only the NORTH-facing orientation of cell (0, 0) is ever modified. After
    each change, all four per-orientation tuples and the packed integer
    returned by ``get_full_transitions`` are checked.
    """
    grid4_map = GridTransitionMap(2, 2, Grid4Transitions([]))
    headings = (
        Grid4TransitionsEnum.NORTH,
        Grid4TransitionsEnum.EAST,
        Grid4TransitionsEnum.SOUTH,
        Grid4TransitionsEnum.WEST,
    )
    closed = (0, 0, 0, 0)

    def expect(north_row, packed):
        # NORTH carries the modified row; EAST, SOUTH and WEST must remain closed.
        wanted_rows = (north_row, closed, closed, closed)
        for heading, wanted in zip(headings, wanted_rows):
            assert grid4_map.get_transitions(0, 0, heading) == wanted
        assert grid4_map.get_full_transitions(0, 0) == packed

    # Fresh cell: nothing is open.
    expect(closed, 0)

    # Open NORTH -> NORTH: only the most significant bit is set.
    grid4_map.set_transition((0, 0, Grid4TransitionsEnum.NORTH), Grid4TransitionsEnum.NORTH, 1)
    expect((1, 0, 0, 0), 1 << 15)

    # Also open NORTH -> WEST: the most significant and the fourth most significant bits are set.
    grid4_map.set_transition((0, 0, Grid4TransitionsEnum.NORTH), Grid4TransitionsEnum.WEST, 1)
    expect((1, 0, 0, 1), (1 << 15) + (1 << 12))

    # Close NORTH -> NORTH again: only the fourth most significant bit remains.
    grid4_map.set_transition((0, 0, Grid4TransitionsEnum.NORTH), Grid4TransitionsEnum.NORTH, 0)
    expect((0, 0, 0, 1), 1 << 12)
def test_grid8_set_transitions():
    """Turning the NORTH -> NORTH bit of a Grid8 cell on and back off is reflected on read-back."""
    grid8_map = GridTransitionMap(2, 2, Grid8Transitions([]))
    heading = Grid8TransitionsEnum.NORTH
    all_closed = (0,) * 8

    assert grid8_map.get_transitions(0, 0, heading) == all_closed

    # Each step writes one bit value, then checks the whole NORTH row.
    steps = (
        (1, (1,) + (0,) * 7),
        (0, all_closed),
    )
    for bit, expected_row in steps:
        grid8_map.set_transition((0, 0, heading), heading, bit)
        assert grid8_map.get_transitions(0, 0, heading) == expected_row
def check_path(env, rail, position, direction, target, expected, rendering=False):
    """Reposition the first agent and assert whether ``target`` is reachable.

    ``env.agents_static[0]`` is mutated in place: position, direction, target
    and ``moving = True``. ``rail.check_path_exists`` is then asked whether
    ``target`` can be reached from that position and direction. The
    environment is not reset here.

    :param env: environment whose ``agents_static[0]`` is repositioned
    :param rail: transition map providing ``check_path_exists``
    :param position: (row, column) start cell
    :param direction: integer start orientation
    :param target: (row, column) cell the path must reach
    :param expected: whether a path is expected to exist
    :param rendering: if True, render the environment and wait for keyboard
        input before checking, so the case can be inspected even when it fails
    :raises AssertionError: if the path-existence result differs from ``expected``;
        the message names the offending case
    """
    agent = env.agents_static[0]
    agent.position = position
    agent.direction = direction
    agent.target = target
    agent.moving = True
    if rendering:
        renderer = RenderTool(env, gl="PILSVG")
        renderer.render_env(show=True, show_observations=False)
        input("Continue?")
    actual = rail.check_path_exists(agent.position, agent.direction, agent.target)
    # Report the specific case; several calls share one test function.
    assert actual == expected, (
        "check_path_exists(position={}, direction={}, target={}) returned {}, expected {}.".format(
            position, direction, target, actual, expected))
def test_path_exists(rendering=False):
    """On the connected simple rail, each listed start state can reach its target.

    :param rendering: forwarded to ``check_path`` so each case can be shown
        (and paused on) before its assertion runs
    """
    rail, rail_map = make_simple_rail()
    env = RailEnv(width=rail_map.shape[1],
                  height=rail_map.shape[0],
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                  )
    # reset to initialize agents_static
    env.reset()

    # (start position, start direction, target). Direction ints follow the
    # labels used throughout this file: 0 = north, 2 = south, 3 = west.
    cases = [
        ((5, 6), 0, (3, 9)),  # north of south dead-end -> east dead-end
        ((6, 6), 2, (3, 9)),  # south dead-end -> east dead-end
        ((3, 0), 3, (0, 3)),  # -> north dead-end
        ((5, 6), 0, (1, 3)),
        ((1, 3), 2, (3, 3)),
        ((1, 3), 0, (3, 3)),
    ]
    for position, direction, target in cases:
        check_path(env, rail, position, direction, target, True, rendering=rendering)
def test_path_not_exists(rendering=False):
    """On the unconnected simple rail, the north dead-end is unreachable from (5, 6) facing north.

    :param rendering: forwarded to ``check_path``, which renders before asserting
        so the map can also be inspected when the check fails
    """
    rail, rail_map = make_simple_rail_unconnected()
    env = RailEnv(width=rail_map.shape[1],
                  height=rail_map.shape[0],
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=1,
                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                  )
    # reset to initialize agents_static
    env.reset()
    check_path(
        env,
        rail,
        (5, 6),
        0,  # north
        (0, 3),  # north dead-end
        False,
        rendering=rendering,
    )
def test_get_entry_directions():
    """Grid4Transitions.get_entry_directions reports the entry headings of turns, lines and switches.

    Turn and switch cells are derived from base cells by rotation. Each
    expectation is a list of four booleans, one per Grid4 heading.
    """
    rail_transitions = RailEnvTransitions()
    base_cells = rail_transitions.transition_list
    rotate = rail_transitions.rotate_transition

    vertical_line = base_cells[1]
    south_symmetrical_switch = base_cells[6]
    north_symmetrical_switch = rotate(south_symmetrical_switch, 180)

    south_east_turn = int('0100000000000010', 2)
    south_west_turn = rotate(south_east_turn, 90)
    north_east_turn = rotate(south_east_turn, 270)
    north_west_turn = rotate(south_east_turn, 180)

    expectations = (
        (south_east_turn, [True, False, False, True]),
        (south_west_turn, [True, True, False, False]),
        (north_east_turn, [False, False, True, True]),
        (north_west_turn, [False, True, True, False]),
        (vertical_line, [True, False, True, False]),
        (south_symmetrical_switch, [True, True, False, True]),
        (north_symmetrical_switch, [False, True, True, True]),
    )
    for cell, expected in expectations:
        actual = Grid4Transitions.get_entry_directions(cell)
        assert actual == expected, "Found {}, expected {}.".format(actual, expected)