diff --git a/flatland/core/transitions.py b/flatland/core/transitions.py index 467ad67a117b992bb6ab02b6f70bc83da3557081..dc4ffe82b6153573ba8528a8017b205dbd8da20d 100644 --- a/flatland/core/transitions.py +++ b/flatland/core/transitions.py @@ -438,14 +438,15 @@ class RailEnvTransitions(GridTransitions): transitions available as a function of the agent's orientation (north, east, south, west) """ - transition_list = [int('0000000000000000', 2), - int('1000000000100000', 2), - int('1001001000100000', 2), - int('1000010000100001', 2), - int('1001011000100001', 2), - int('1100110000110011', 2), - int('0101001000000010', 2), - int('0000000000100000', 2)] + + transition_list = [int('0000000000000000', 2), # empty cell - Case 0 + int('1000000000100000', 2), # Case 1 - straight + int('1001001000000000', 2), # Case 2 - simple switch + int('1000010000100001', 2), # Case 3 - diamond drossing + int('1001011000100001', 2), # Case 4 - single slip switch + int('1100110000110011', 2), # Case 5 - double slip switch + int('0101001000000010', 2), # Case 6 - symmetrical switch + int('0010000000000000', 2)] # Case 7 - dead end def __init__(self): super(RailEnvTransitions, self).__init__( diff --git a/tests/test_environments.py b/tests/test_environments.py index ae4d5249e143769f41e607ccd477737ef0e440b9..e95c398c15994ce9ff1481839a481377b9a5d262 100644 --- a/tests/test_environments.py +++ b/tests/test_environments.py @@ -1,11 +1,51 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- +from flatland.core.env import RailEnv +from flatland.core.transitions import GridTransitions +import numpy as np + """Tests for `flatland` package.""" -def test_base_environment(): - """Test example Transition.""" - a = True - b = True - assert a == b + +def test_rail_environment(): + cells = [int('0000000000000000', 2), # empty cell - Case 0 + int('1000000000100000', 2), # Case 1 - straight + int('1001001000000000', 2), # Case 2 - simple switch + int('1000010000100001', 2), # Case 3 - diamond drossing + int('1001011000100001', 2), # Case 4 - single slip switch + int('1100110000110011', 2), # Case 5 - double slip switch + int('0101001000000010', 2), # Case 6 - symmetrical switch + int('0010000000000000', 2)] # Case 7 - dead end + + # We instantiate the following map on a 3x3 grid + # _ _ + # / \/ \ + # | | | + # \_/\_/ + + transitions = GridTransitions([], False) + vertical_line = cells[1] + south_symmetrical_switch = cells[6] + north_symmetrical_switch = transitions.rotate_transition(south_symmetrical_switch, 180) + south_east_turn = int('0100000000100000', 2) # Simple turn not in the base transitions ? + south_west_turn = transitions.rotate_transition(south_east_turn, 90) + north_east_turn = transitions.rotate_transition(south_east_turn, 270) + north_west_turn = transitions.rotate_transition(south_east_turn, 180) + + rail_map = np.array([[south_east_turn, south_symmetrical_switch, south_west_turn], + [vertical_line, vertical_line, vertical_line], + [north_east_turn, north_symmetrical_switch, north_west_turn]], + dtype=np.uint16) + + rail_env = RailEnv(rail_map, number_of_agents=1) + + # Check that trains are always initialized at a consistent position / direction. + # They should always be able to go somewhere. + for _ in range(1000): + obs = rail_env.reset() + assert(transitions.get_transitions_from_orientation( + rail_map[rail_env.agents_position[0]], + rail_env.agents_direction[0]) != (0, 0, 0, 0)) +