Skip to content
Snippets Groups Projects
Commit bfcd9ab5 authored by gmollard's avatar gmollard
Browse files

test for dead ends

parent d70a3855
No related branches found
No related tags found
No related merge requests found
......@@ -34,7 +34,6 @@ def test_rail_environment_single_agent():
# Simple turn not in the base transitions ?
south_east_turn = int('0100000000000010', 2)
south_west_turn = transitions.rotate_transition(south_east_turn, 90)
# print(bytes(south_west_turn))
north_east_turn = transitions.rotate_transition(south_east_turn, 270)
north_west_turn = transitions.rotate_transition(south_east_turn, 180)
......@@ -77,7 +76,7 @@ def test_rail_environment_single_agent():
valid_active_actions_done += 1
# After 6 movements on this railway network, the train should be back
# to its original position.
# to its original height on the map.
assert(initial_pos[0] == rail_env.agents_position[0][0])
# We check that the train always attains its target after some time
......@@ -92,3 +91,97 @@ def test_rail_environment_single_agent():
_, _, dones, _ = rail_env.step({0: action})
done = dones['__all__']
def test_dead_end():
transitions = Grid4Transitions([])
straight_vertical = int('1000000000100000', 2) # Case 1 - straight
straight_horizontal = transitions.rotate_transition(straight_vertical,
90)
dead_end_from_south = int('0010000000000000', 2) # Case 7 - dead end
# We instantiate the following railway
# O->-- where > is the train and O the target. After 6 steps,
# the train should be done.
rail_map = np.array(
[[transitions.rotate_transition(dead_end_from_south, 270)] +
[straight_horizontal] * 3 +
[transitions.rotate_transition(dead_end_from_south, 90)]],
dtype=np.uint16)
rail = GridTransitionMap(width=rail_map.shape[1],
height=rail_map.shape[0],
transitions=transitions)
rail.grid = rail_map
rail_env = RailEnv(rail, number_of_agents=1)
def check_consistency(rail_env):
# We run step to check that trains do not move anymore
# after being done.
for i in range(7):
prev_pos = rail_env.agents_position[0]
# The train cannot turn, so we check that when it tries,
# it stays where it is.
_ = rail_env.step({0: 1})
_ = rail_env.step({0: 3})
assert (rail_env.agents_position[0] == prev_pos)
_, _, dones, _ = rail_env.step({0: 2})
if i < 5:
assert (not dones[0] and not dones['__all__'])
else:
assert (dones[0] and dones['__all__'])
# We try the configuration in the 4 directions:
rail_env.reset()
rail_env.agents_target[0] = [0, 0]
rail_env.agents_position[0] = [0, 2]
rail_env.agents_direction[0] = 1
check_consistency(rail_env)
rail_env.reset()
rail_env.agents_target[0] = [0, 4]
rail_env.agents_position[0] = [0, 2]
rail_env.agents_direction[0] = 3
check_consistency(rail_env)
# In the vertical configuration:
rail_map = np.array(
[[dead_end_from_south]] + [[straight_vertical]] * 3 +
[[transitions.rotate_transition(dead_end_from_south, 180)]],
dtype=np.uint16)
rail = GridTransitionMap(width=rail_map.shape[1],
height=rail_map.shape[0],
transitions=transitions)
rail.grid = rail_map
rail_env = RailEnv(rail, number_of_agents=1)
rail_env.reset()
rail_env.agents_target[0] = [0, 0]
rail_env.agents_position[0] = [2, 0]
rail_env.agents_direction[0] = 2
check_consistency(rail_env)
rail_env.reset()
rail_env.agents_target[0] = [4, 0]
rail_env.agents_position[0] = [2, 0]
rail_env.agents_direction[0] = 0
check_consistency(rail_env)
test_dead_end()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment