Commit da777473 authored by u214892's avatar u214892
Browse files

#162 stochasticity tests

parent 3fb4cb97
Pipeline #1938 failed with stages
in 11 minutes and 25 seconds
......@@ -305,6 +305,7 @@ class RailEnv(Environment):
return True
return False
# TODO refactor to decrease length of this method!
def step(self, action_dict_):
self._elapsed_steps += 1
......@@ -344,7 +345,7 @@ class RailEnv(Environment):
action = RailEnvActions.DO_NOTHING
# Check if agent breaks at this step
malfunction = self._agent_malfunction(i_agent, action)
new_malfunction = self._agent_malfunction(i_agent, action)
# Is the agent at the beginning of the cell? Then, it can take an action
# Design choice (Erik+Christian):
......@@ -397,11 +398,11 @@ class RailEnv(Environment):
else:
agent.speed_data['transition_action_on_cellexit'] = action
# if we're broken, nothing else to do
if malfunction:
# if we've just broken in this step, nothing else to do
if new_malfunction:
continue
# The train is broken
# The train was broken before...
if agent.malfunction_data['malfunction'] > 0:
# Last step of malfunction --> Agent starts moving again after getting fixed
......@@ -424,11 +425,9 @@ class RailEnv(Environment):
# If agent.moving, increment the position_fraction by the speed of the agent
# If the new position fraction is >= 1, reset to 0, and perform the stored
# transition_action_on_cellexit if the cell is free.
if agent.moving:
agent.speed_data['position_fraction'] += agent.speed_data['speed']
if agent.speed_data['position_fraction'] >= 1.0:
# Perform stored action to transition to the next cell as soon as cell is free
# Notice that we've already checked new_cell_valid and transition valid when we stored the action,
......@@ -441,7 +440,8 @@ class RailEnv(Environment):
cell_free, new_cell_valid, new_direction, new_position, transition_valid = self._check_action_on_agent(
agent.speed_data['transition_action_on_cellexit'], agent)
if not cell_free == all([cell_free, new_cell_valid, transition_valid]):
warnings.warn("Inconsistent state: cell or transition not valid although checked when we stored transition_action_on_cellexit!")
warnings.warn(
"Inconsistent state: cell or transition not valid although checked when we stored transition_action_on_cellexit!")
if cell_free:
agent.position = new_position
agent.direction = new_direction
......
......@@ -110,3 +110,37 @@ def test_malfunction_process():
# Check that malfunctioning data was standing around
assert total_down_time > 0
def test_malfunction_process_statistically():
    """Check that malfunctions are generated from the given stochastic_data.

    Two agents are stepped 100 times with randomly drawn actions under a
    fixed NumPy seed. Before every step, each agent whose malfunction
    counter is positive is tallied; the total must equal the recorded
    reference value.
    """
    # Malfunction duration is pinned (min == max) for this test
    malfunction_config = {
        'prop_malfunction': 1.,
        'malfunction_rate': 2,
        'min_duration': 3,
        'max_duration': 3,
    }
    np.random.seed(5)

    rail_generator = complex_rail_generator(
        nr_start_goal=10,
        nr_extra=1,
        min_dist=5,
        max_dist=99999,
        seed=0,
    )
    env = RailEnv(
        width=20,
        height=20,
        rail_generator=rail_generator,
        schedule_generator=complex_schedule_generator(),
        number_of_agents=2,
        obs_builder_object=SingleAgentNavigationObs(),
        stochastic_data=malfunction_config,
    )
    env.reset()

    malfunction_count = 0
    for _ in range(100):
        # Tally the agents that are broken before this step is executed
        malfunction_count += sum(
            1 for agent in env.agents
            if agent.malfunction_data['malfunction'] > 0
        )
        # One random action per agent, drawn in agent order
        actions = {agent.handle: np.random.randint(4) for agent in env.agents}
        env.step(actions)

    # The seeded run must reproduce the reference number of malfunction observations
    assert malfunction_count == 156
......@@ -97,9 +97,23 @@ def test_multi_speed_init():
old_pos[i_agent] = env.agents[i_agent].position
# TODO test malfunction
# TODO test other agent blocking
def test_multispeed_actions_no_malfunction(rendering=True):
@attrs
class Replay(object):
    """One scripted step of an agent run, as consumed by the multi-speed tests."""

    # Expected agent position; compared against agent.position before the step is executed
    position = attrib()
    # Expected agent direction; compared against agent.direction before the step is executed
    direction = attrib()
    # Action submitted to env.step for this step; None means no action is submitted
    # and the tests then expect action_required to be False for the agent
    action = attrib(type=RailEnvActions)
    # NOTE(review): presumably the malfunction duration to inject at this step (0 = none);
    # the code consuming this field is not visible here -- confirm
    malfunction = attrib(default=0, type=int)
@attrs
class TestConfig(object):
    """Scripted scenario for one agent: its replay steps plus static agent settings."""

    # The name matches pytest's default ``Test*`` class pattern and attrs generates
    # an ``__init__``; opt out of collection explicitly to avoid a collection warning.
    # This is a plain class attribute, not an ``attrib()``, so it is not an attrs field.
    __test__ = False

    # Ordered list of expected states / actions, one entry per environment step
    replay = attrib(type=List[Replay])
    # Target cell assigned to the agent's static data before the environment is reset
    target = attrib()
    # Value written to the agent's speed_data['speed']
    speed = attrib(type=float)
def test_multispeed_actions_no_malfunction_no_blocking(rendering=True):
"""Test that actions are correctly performed on cell exit for a single agent."""
rail, rail_map = make_simple_rail()
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
......@@ -112,17 +126,135 @@ def test_multispeed_actions_no_malfunction(rendering=True):
# initialize agents_static
env.reset()
@attrs
class Replay(object):
position = attrib()
direction = attrib()
action = attrib(type=RailEnvActions)
# reset to set agents from agents_static
env.reset(False, False)
if rendering:
renderer = RenderTool(env, gl="PILSVG")
test_config = TestConfig(
replay=[
Replay(
position=(3, 9), # east dead-end
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD
),
Replay(
position=(3, 9),
direction=Grid4TransitionsEnum.EAST,
action=None
),
Replay(
position=(3, 8),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_FORWARD
),
Replay(
position=(3, 8),
direction=Grid4TransitionsEnum.WEST,
action=None
),
Replay(
position=(3, 7),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_FORWARD
),
Replay(
position=(3, 7),
direction=Grid4TransitionsEnum.WEST,
action=None
),
Replay(
position=(3, 6),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_LEFT
),
Replay(
position=(3, 6),
direction=Grid4TransitionsEnum.WEST,
action=None
),
Replay(
position=(4, 6),
direction=Grid4TransitionsEnum.SOUTH,
action=RailEnvActions.STOP_MOVING
),
Replay(
position=(4, 6),
direction=Grid4TransitionsEnum.SOUTH,
action=RailEnvActions.STOP_MOVING
),
Replay(
position=(4, 6),
direction=Grid4TransitionsEnum.SOUTH,
action=RailEnvActions.MOVE_FORWARD
),
Replay(
position=(4, 6),
direction=Grid4TransitionsEnum.SOUTH,
action=None
),
Replay(
position=(5, 6),
direction=Grid4TransitionsEnum.SOUTH,
action=RailEnvActions.MOVE_FORWARD
),
],
target=(3, 0), # west dead-end
speed=0.5
)
# TODO test penalties!
agentStatic: EnvAgentStatic = env.agents_static[0]
info_dict = {
'action_required': [True]
}
for i, replay in enumerate(test_config.replay):
if i == 0:
# set the initial position
agentStatic.position = replay.position
agentStatic.direction = replay.direction
agentStatic.target = test_config.target
agentStatic.moving = True
agentStatic.speed_data['speed'] = test_config.speed
# reset to set agents from agents_static
env.reset(False, False)
def _assert(actual, expected, msg):
assert actual == expected, "[{}] {}: actual={}, expected={}".format(i, msg, actual, expected)
agent: EnvAgent = env.agents[0]
_assert(agent.position, replay.position, 'position')
_assert(agent.direction, replay.direction, 'direction')
if replay.action:
assert info_dict['action_required'][0] == True, "[{}] expecting action_required={}".format(i, True)
_, _, _, info_dict = env.step({0: replay.action})
else:
assert info_dict['action_required'][0] == False, "[{}] expecting action_required={}".format(i, False)
_, _, _, info_dict = env.step({})
if rendering:
renderer.render_env(show=True, show_observations=True)
@attrs
class TestConfig(object):
replay = attrib(type=List[Replay])
target = attrib()
speed = attrib(type=float)
def test_multispeed_actions_no_malfunction_blocking(rendering=True):
"""The second agent blocks the first because it is slower."""
rail, rail_map = make_simple_rail()
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(),
number_of_agents=2,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
)
# initialize agents_static
env.reset()
# reset to set agents from agents_static
env.reset(False, False)
......@@ -134,85 +266,156 @@ def test_multispeed_actions_no_malfunction(rendering=True):
TestConfig(
replay=[
Replay(
position=(3, 9), # east dead-end
direction=Grid4TransitionsEnum.EAST,
position=(3, 8),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_FORWARD
),
Replay(
position=(3, 9),
direction=Grid4TransitionsEnum.EAST,
position=(3, 8),
direction=Grid4TransitionsEnum.WEST,
action=None
),
Replay(
position=(3, 8),
direction=Grid4TransitionsEnum.WEST,
action=None
),
Replay(
position=(3, 7),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_FORWARD
),
Replay(
position=(3, 8),
position=(3, 7),
direction=Grid4TransitionsEnum.WEST,
action=None
),
Replay(
position=(3, 7),
direction=Grid4TransitionsEnum.WEST,
action=None
),
Replay(
position=(3, 6),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_FORWARD
),
Replay(
position=(3, 7),
position=(3, 6),
direction=Grid4TransitionsEnum.WEST,
action=None
),
Replay(
position=(3, 6),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_LEFT
action=None
),
Replay(
position=(3, 6),
position=(3, 5),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_FORWARD
),
Replay(
position=(3, 5),
direction=Grid4TransitionsEnum.WEST,
action=None
),
Replay(
position=(4, 6),
direction=Grid4TransitionsEnum.SOUTH,
action=RailEnvActions.STOP_MOVING
position=(3, 5),
direction=Grid4TransitionsEnum.WEST,
action=None
)
],
target=(3, 0), # west dead-end
speed=1 / 3),
TestConfig(
replay=[
Replay(
position=(3, 9), # east dead-end
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD
),
Replay(
position=(4, 6),
direction=Grid4TransitionsEnum.SOUTH,
action=RailEnvActions.STOP_MOVING
position=(3, 9),
direction=Grid4TransitionsEnum.EAST,
action=None
),
# blocked although fraction >= 1.0
Replay(
position=(4, 6),
direction=Grid4TransitionsEnum.SOUTH,
position=(3, 9),
direction=Grid4TransitionsEnum.EAST,
action=None
),
Replay(
position=(3, 8),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_FORWARD
),
Replay(
position=(4, 6),
direction=Grid4TransitionsEnum.SOUTH,
position=(3, 8),
direction=Grid4TransitionsEnum.WEST,
action=None
),
# blocked although fraction >= 1.0
Replay(
position=(5, 6),
direction=Grid4TransitionsEnum.SOUTH,
position=(3, 8),
direction=Grid4TransitionsEnum.WEST,
action=None
),
Replay(
position=(3, 7),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_FORWARD
),
Replay(
position=(3, 7),
direction=Grid4TransitionsEnum.WEST,
action=None
),
# blocked although fraction >= 1.0
Replay(
position=(3, 7),
direction=Grid4TransitionsEnum.WEST,
action=None
),
Replay(
position=(3, 6),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_LEFT
),
Replay(
position=(3, 6),
direction=Grid4TransitionsEnum.WEST,
action=None
),
# not blocked, action required!
Replay(
position=(4, 6),
direction=Grid4TransitionsEnum.SOUTH,
action=RailEnvActions.MOVE_FORWARD
),
],
target=(3, 0), # west dead-end
speed=0.5
)
]
# TODO test penalties!
agentStatic: EnvAgentStatic = env.agents_static[0]
for test_config in test_configs:
info_dict = {
'action_required': [True]
}
for i, replay in enumerate(test_config.replay):
if i == 0:
info_dict = {
'action_required': [True for _ in test_configs]
}
for step in range(len(test_configs[0].replay)):
if step == 0:
for a, test_config in enumerate(test_configs):
agentStatic: EnvAgentStatic = env.agents_static[a]
replay = test_config.replay[0]
# set the initial position
agentStatic.position = replay.position
agentStatic.direction = replay.direction
......@@ -220,24 +423,177 @@ def test_multispeed_actions_no_malfunction(rendering=True):
agentStatic.moving = True
agentStatic.speed_data['speed'] = test_config.speed
# reset to set agents from agents_static
env.reset(False, False)
# reset to set agents from agents_static
env.reset(False, False)
def _assert(actual, expected, msg):
assert actual == expected, "[{}] {}: actual={}, expected={}".format(i, msg, actual, expected)
def _assert(a, actual, expected, msg):
assert actual == expected, "[{}] {} {}: actual={}, expected={}".format(step, a, msg, actual, expected)
agent: EnvAgent = env.agents[0]
action_dict = {}
for a, test_config in enumerate(test_configs):
agent: EnvAgent = env.agents[a]
replay = test_config.replay[step]
_assert(a, agent.position, replay.position, 'position')
_assert(a, agent.direction, replay.direction, 'direction')
_assert(agent.position, replay.position, 'position')
_assert(agent.direction, replay.direction, 'direction')
if replay.action:
assert info_dict['action_required'][0] == True, "[{}] expecting action_required={}".format(i, True)
_, _, _, info_dict = env.step({0: replay.action})
if replay.action:
assert info_dict['action_required'][a] == True, "[{}] agent {} expecting action_required={}".format(step, a, True)
action_dict[a] = replay.action
else:
assert info_dict['action_required'][0] == False, "[{}] expecting action_required={}".format(i, False)
_, _, _, info_dict = env.step({})
assert info_dict['action_required'][a] == False, "[{}] agent {} expecting action_required={}".format(step, a, False)
_, _, _, info_dict = env.step(action_dict)
if rendering:
renderer.render_env(show=True, show_observations=True)
if rendering:
renderer.render_env(show=True, show_observations=True)
def test_multispeed_actions_malfunction_no_blocking(rendering=True):
"""Test on a single agent whether action on cell exit work correctly despite malfunction."""
rail, rail_map = make_simple_rail()
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
)
# initialize agents_static
env.reset()
# reset to set agents from agents_static
env.reset(False, False)
if rendering:
renderer = RenderTool(env, gl="PILSVG")
test_config = TestConfig(
replay=[
Replay(
position=(3, 9), # east dead-end
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD
),
Replay(
position=(3, 9),
direction=Grid4TransitionsEnum.EAST,
action=None
),
Replay(
position=(3, 8),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_FORWARD
),
# add additional step in the cell
Replay(
position=(3, 8),
direction=Grid4TransitionsEnum.WEST,
action=None,
malfunction=2 # recovers in two steps from now!
),
# agent recovers in this step
Replay(
position=(3, 8),
direction=Grid4TransitionsEnum.WEST,
action=None
),
Replay(
position=(3, 7),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_FORWARD
),
Replay(
position=(3, 7),
direction=Grid4TransitionsEnum.WEST,
action=None
),
Replay(
position=(3, 6),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_FORWARD,
malfunction=2 # recovers in two steps from now!
),
# agent recovers in this step; since we're at the beginning, we provide a different action although we're broken!
Replay(
position=(3, 6),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_LEFT,
),
Replay(
position=(3, 6),
direction=Grid4TransitionsEnum.WEST,
action=None
),
Replay(
position=(4, 6),
direction=Grid4TransitionsEnum.SOUTH,
action=RailEnvActions.STOP_MOVING
),
Replay(
position=(4, 6),
direction=Grid4TransitionsEnum.SOUTH,
action=RailEnvActions.STOP_MOVING
),
Replay(
position=(4, 6),
direction=Grid4TransitionsEnum.SOUTH,
action=RailEnvActions.MOVE_FORWARD
),
Replay(
position=(4, 6),
direction=Grid4TransitionsEnum.SOUTH,
action=None
),
Replay(
position=(5, 6),
direction=Grid4TransitionsEnum.SOUTH,
action=RailEnvActions.MOVE_FORWARD
),
],
target=(3, 0), # west dead-end
speed=0.5
)
# TODO test penalties!
agentStatic: EnvAgentStatic = env.agents_static[0]
info_dict = {
'action_required': [True]
}
for i, replay in enumerate(test_config.replay):
if i == 0:
# set the initial position