diff --git a/tests/test_environments.py b/tests/test_environments.py index c3329a126684ef9462257f51276f7bae94675b28..66e6bef404939ebefa9596d0661348da899ee16d 100644 --- a/tests/test_environments.py +++ b/tests/test_environments.py @@ -34,7 +34,6 @@ def test_rail_environment_single_agent(): # Simple turn not in the base transitions ? south_east_turn = int('0100000000000010', 2) south_west_turn = transitions.rotate_transition(south_east_turn, 90) - # print(bytes(south_west_turn)) north_east_turn = transitions.rotate_transition(south_east_turn, 270) north_west_turn = transitions.rotate_transition(south_east_turn, 180) @@ -77,7 +76,7 @@ def test_rail_environment_single_agent(): valid_active_actions_done += 1 # After 6 movements on this railway network, the train should be back - # to its original position. + # to its original height on the map. assert(initial_pos[0] == rail_env.agents_position[0][0]) # We check that the train always attains its target after some time @@ -92,3 +91,97 @@ def test_rail_environment_single_agent(): _, _, dones, _ = rail_env.step({0: action}) done = dones['__all__'] + + +def test_dead_end(): + + transitions = Grid4Transitions([]) + + straight_vertical = int('1000000000100000', 2) # Case 1 - straight + straight_horizontal = transitions.rotate_transition(straight_vertical, + 90) + + dead_end_from_south = int('0010000000000000', 2) # Case 7 - dead end + + # We instantiate the following railway + # O->-- where > is the train and O the target. After 6 steps, + # the train should be done. + + rail_map = np.array( + [[transitions.rotate_transition(dead_end_from_south, 270)] + + [straight_horizontal] * 3 + + [transitions.rotate_transition(dead_end_from_south, 90)]], + dtype=np.uint16) + + rail = GridTransitionMap(width=rail_map.shape[1], + height=rail_map.shape[0], + transitions=transitions) + + rail.grid = rail_map + rail_env = RailEnv(rail, number_of_agents=1) + + def check_consistency(rail_env): + # We run step to check that trains do not move anymore + # after being done. + for i in range(7): + prev_pos = rail_env.agents_position[0] + + # The train cannot turn, so we check that when it tries, + # it stays where it is. + _ = rail_env.step({0: 1}) + _ = rail_env.step({0: 3}) + assert (rail_env.agents_position[0] == prev_pos) + + _, _, dones, _ = rail_env.step({0: 2}) + + if i < 5: + assert (not dones[0] and not dones['__all__']) + else: + assert (dones[0] and dones['__all__']) + + # We try the configuration in the 4 directions: + rail_env.reset() + rail_env.agents_target[0] = [0, 0] + rail_env.agents_position[0] = [0, 2] + rail_env.agents_direction[0] = 1 + check_consistency(rail_env) + + rail_env.reset() + rail_env.agents_target[0] = [0, 4] + rail_env.agents_position[0] = [0, 2] + rail_env.agents_direction[0] = 3 + check_consistency(rail_env) + + # In the vertical configuration: + + rail_map = np.array( + [[dead_end_from_south]] + [[straight_vertical]] * 3 + + [[transitions.rotate_transition(dead_end_from_south, 180)]], + dtype=np.uint16) + + rail = GridTransitionMap(width=rail_map.shape[1], + height=rail_map.shape[0], + transitions=transitions) + + rail.grid = rail_map + rail_env = RailEnv(rail, number_of_agents=1) + + rail_env.reset() + rail_env.agents_target[0] = [0, 0] + rail_env.agents_position[0] = [2, 0] + rail_env.agents_direction[0] = 2 + check_consistency(rail_env) + + rail_env.reset() + rail_env.agents_target[0] = [4, 0] + rail_env.agents_position[0] = [2, 0] + rail_env.agents_direction[0] = 0 + check_consistency(rail_env) + + + + + + +test_dead_end() +