From da777473d58c0baaaeb17e3f9303e190cbadf61e Mon Sep 17 00:00:00 2001 From: u214892 <u214892@sbb.ch> Date: Thu, 5 Sep 2019 11:47:27 +0200 Subject: [PATCH] #162 stochasticity tests --- flatland/envs/rail_env.py | 14 +- tests/test_flatland_malfunction.py | 34 +++ tests/test_multi_speed.py | 464 +++++++++++++++++++++++++---- 3 files changed, 451 insertions(+), 61 deletions(-) diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 0ca62f93..cc115c72 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -305,6 +305,7 @@ class RailEnv(Environment): return True return False + # TODO refactor to decrease length of this method! def step(self, action_dict_): self._elapsed_steps += 1 @@ -344,7 +345,7 @@ class RailEnv(Environment): action = RailEnvActions.DO_NOTHING # Check if agent breaks at this step - malfunction = self._agent_malfunction(i_agent, action) + new_malfunction = self._agent_malfunction(i_agent, action) # Is the agent at the beginning of the cell? Then, it can take an action # Design choice (Erik+Christian): @@ -397,11 +398,11 @@ class RailEnv(Environment): else: agent.speed_data['transition_action_on_cellexit'] = action - # if we're broken, nothing else to do - if malfunction: + # if we've just broken in this step, nothing else to do + if new_malfunction: continue - # The train is broken + # The train was broken before... if agent.malfunction_data['malfunction'] > 0: # Last step of malfunction --> Agent starts moving again after getting fixed @@ -424,11 +425,9 @@ class RailEnv(Environment): # If agent.moving, increment the position_fraction by the speed of the agent # If the new position fraction is >= 1, reset to 0, and perform the stored # transition_action_on_cellexit if the cell is free. - if agent.moving: agent.speed_data['position_fraction'] += agent.speed_data['speed'] - if agent.speed_data['position_fraction'] >= 1.0: # Perform stored action to transition to the next cell as soon as cell is free # Notice that we've already check new_cell_valid and transition valid when we stored the action, @@ -441,7 +440,8 @@ class RailEnv(Environment): cell_free, new_cell_valid, new_direction, new_position, transition_valid = self._check_action_on_agent( agent.speed_data['transition_action_on_cellexit'], agent) if not cell_free == all([cell_free, new_cell_valid, transition_valid]): - warnings.warn("Inconsistent state: cell or transition not valid although checked when we stored transition_action_on_cellexit!") + warnings.warn( + "Inconsistent state: cell or transition not valid although checked when we stored transition_action_on_cellexit!") if cell_free: agent.position = new_position agent.direction = new_direction diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index eaf782df..e60386c9 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -110,3 +110,37 @@ def test_malfunction_process(): # Check that malfunctioning data was standing around assert total_down_time > 0 + + +def test_malfunction_process_statistically(): + """Tests hat malfunctions are produced by stochastic_data!""" + # Set fixed malfunction duration for this test + stochastic_data = {'prop_malfunction': 1., + 'malfunction_rate': 2, + 'min_duration': 3, + 'max_duration': 3} + np.random.seed(5) + + env = RailEnv(width=20, + height=20, + rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999, + seed=0), + schedule_generator=complex_schedule_generator(), + number_of_agents=2, + obs_builder_object=SingleAgentNavigationObs(), + stochastic_data=stochastic_data) + + env.reset() + nb_malfunction = 0 + for step in range(100): + action_dict = {} + for agent in env.agents: + if agent.malfunction_data['malfunction'] > 0: + nb_malfunction += 1 + # We randomly select an action + action_dict[agent.handle] = np.random.randint(4) + + env.step(action_dict) + + # check that generation of malfunctions works as expected + assert nb_malfunction == 156 diff --git a/tests/test_multi_speed.py b/tests/test_multi_speed.py index b8b1afaf..86edc08c 100644 --- a/tests/test_multi_speed.py +++ b/tests/test_multi_speed.py @@ -97,9 +97,23 @@ def test_multi_speed_init(): old_pos[i_agent] = env.agents[i_agent].position -# TODO test malfunction -# TODO test other agent blocking -def test_multispeed_actions_no_malfunction(rendering=True): +@attrs +class Replay(object): + position = attrib() + direction = attrib() + action = attrib(type=RailEnvActions) + malfunction = attrib(default=0, type=int) + + +@attrs +class TestConfig(object): + replay = attrib(type=List[Replay]) + target = attrib() + speed = attrib(type=float) + + +def test_multispeed_actions_no_malfunction_no_blocking(rendering=True): + """Test that actions are correctly performed on cell exit for a single agent.""" rail, rail_map = make_simple_rail() env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], @@ -112,17 +126,135 @@ def test_multispeed_actions_no_malfunction(rendering=True): # initialize agents_static env.reset() - @attrs - class Replay(object): - position = attrib() - direction = attrib() - action = attrib(type=RailEnvActions) + # reset to set agents from agents_static + env.reset(False, False) + + if rendering: + renderer = RenderTool(env, gl="PILSVG") + + test_config = TestConfig( + replay=[ + Replay( + position=(3, 9), # east dead-end + direction=Grid4TransitionsEnum.EAST, + action=RailEnvActions.MOVE_FORWARD + ), + Replay( + position=(3, 9), + direction=Grid4TransitionsEnum.EAST, + action=None + ), + Replay( + position=(3, 8), + direction=Grid4TransitionsEnum.WEST, + action=RailEnvActions.MOVE_FORWARD + ), + Replay( + position=(3, 8), + direction=Grid4TransitionsEnum.WEST, + action=None + ), + Replay( + position=(3, 7), + direction=Grid4TransitionsEnum.WEST, + action=RailEnvActions.MOVE_FORWARD + ), + Replay( + position=(3, 7), + direction=Grid4TransitionsEnum.WEST, + action=None + ), + Replay( + position=(3, 6), + direction=Grid4TransitionsEnum.WEST, + action=RailEnvActions.MOVE_LEFT + ), + Replay( + position=(3, 6), + direction=Grid4TransitionsEnum.WEST, + action=None + ), + Replay( + position=(4, 6), + direction=Grid4TransitionsEnum.SOUTH, + action=RailEnvActions.STOP_MOVING + ), + Replay( + position=(4, 6), + direction=Grid4TransitionsEnum.SOUTH, + action=RailEnvActions.STOP_MOVING + ), + Replay( + position=(4, 6), + direction=Grid4TransitionsEnum.SOUTH, + action=RailEnvActions.MOVE_FORWARD + ), + Replay( + position=(4, 6), + direction=Grid4TransitionsEnum.SOUTH, + action=None + ), + Replay( + position=(5, 6), + direction=Grid4TransitionsEnum.SOUTH, + action=RailEnvActions.MOVE_FORWARD + ), + + ], + target=(3, 0), # west dead-end + speed=0.5 + ) + + # TODO test penalties! + agentStatic: EnvAgentStatic = env.agents_static[0] + info_dict = { + 'action_required': [True] + } + for i, replay in enumerate(test_config.replay): + if i == 0: + # set the initial position + agentStatic.position = replay.position + agentStatic.direction = replay.direction + agentStatic.target = test_config.target + agentStatic.moving = True + agentStatic.speed_data['speed'] = test_config.speed + + # reset to set agents from agents_static + env.reset(False, False) + + def _assert(actual, expected, msg): + assert actual == expected, "[{}] {}: actual={}, expected={}".format(i, msg, actual, expected) + + agent: EnvAgent = env.agents[0] + + _assert(agent.position, replay.position, 'position') + _assert(agent.direction, replay.direction, 'direction') + + if replay.action: + assert info_dict['action_required'][0] == True, "[{}] expecting action_required={}".format(i, True) + _, _, _, info_dict = env.step({0: replay.action}) + + else: + assert info_dict['action_required'][0] == False, "[{}] expecting action_required={}".format(i, False) + _, _, _, info_dict = env.step({}) + + if rendering: + renderer.render_env(show=True, show_observations=True) - @attrs - class TestConfig(object): - replay = attrib(type=List[Replay]) - target = attrib() - speed = attrib(type=float) + +def test_multispeed_actions_no_malfunction_blocking(rendering=True): + """The second agent blocks the first because it is slower.""" + rail, rail_map = make_simple_rail() + env = RailEnv(width=rail_map.shape[1], + height=rail_map.shape[0], + rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), + number_of_agents=2, + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), + ) + + # initialize agents_static + env.reset() # reset to set agents from agents_static env.reset(False, False) @@ -134,85 +266,156 @@ def test_multispeed_actions_no_malfunction(rendering=True): TestConfig( replay=[ Replay( - position=(3, 9), # east dead-end - direction=Grid4TransitionsEnum.EAST, + position=(3, 8), + direction=Grid4TransitionsEnum.WEST, action=RailEnvActions.MOVE_FORWARD ), Replay( - position=(3, 9), - direction=Grid4TransitionsEnum.EAST, + position=(3, 8), + direction=Grid4TransitionsEnum.WEST, action=None ), Replay( position=(3, 8), direction=Grid4TransitionsEnum.WEST, + action=None + ), + + Replay( + position=(3, 7), + direction=Grid4TransitionsEnum.WEST, action=RailEnvActions.MOVE_FORWARD ), Replay( - position=(3, 8), + position=(3, 7), direction=Grid4TransitionsEnum.WEST, action=None ), Replay( position=(3, 7), direction=Grid4TransitionsEnum.WEST, + action=None + ), + + Replay( + position=(3, 6), + direction=Grid4TransitionsEnum.WEST, action=RailEnvActions.MOVE_FORWARD ), Replay( - position=(3, 7), + position=(3, 6), direction=Grid4TransitionsEnum.WEST, action=None ), Replay( position=(3, 6), direction=Grid4TransitionsEnum.WEST, - action=RailEnvActions.MOVE_LEFT + action=None ), + Replay( - position=(3, 6), + position=(3, 5), + direction=Grid4TransitionsEnum.WEST, + action=RailEnvActions.MOVE_FORWARD + ), + Replay( + position=(3, 5), direction=Grid4TransitionsEnum.WEST, action=None ), Replay( - position=(4, 6), - direction=Grid4TransitionsEnum.SOUTH, - action=RailEnvActions.STOP_MOVING + position=(3, 5), + direction=Grid4TransitionsEnum.WEST, + action=None + ) + ], + target=(3, 0), # west dead-end + speed=1 / 3), + TestConfig( + replay=[ + Replay( + position=(3, 9), # east dead-end + direction=Grid4TransitionsEnum.EAST, + action=RailEnvActions.MOVE_FORWARD ), Replay( - position=(4, 6), - direction=Grid4TransitionsEnum.SOUTH, - action=RailEnvActions.STOP_MOVING + position=(3, 9), + direction=Grid4TransitionsEnum.EAST, + action=None ), + # blocked although fraction >= 1.0 Replay( - position=(4, 6), - direction=Grid4TransitionsEnum.SOUTH, + position=(3, 9), + direction=Grid4TransitionsEnum.EAST, + action=None + ), + + Replay( + position=(3, 8), + direction=Grid4TransitionsEnum.WEST, action=RailEnvActions.MOVE_FORWARD ), Replay( - position=(4, 6), - direction=Grid4TransitionsEnum.SOUTH, + position=(3, 8), + direction=Grid4TransitionsEnum.WEST, action=None ), + # blocked although fraction >= 1.0 Replay( - position=(5, 6), - direction=Grid4TransitionsEnum.SOUTH, + position=(3, 8), + direction=Grid4TransitionsEnum.WEST, + action=None + ), + + Replay( + position=(3, 7), + direction=Grid4TransitionsEnum.WEST, action=RailEnvActions.MOVE_FORWARD ), + Replay( + position=(3, 7), + direction=Grid4TransitionsEnum.WEST, + action=None + ), + # blocked although fraction >= 1.0 + Replay( + position=(3, 7), + direction=Grid4TransitionsEnum.WEST, + action=None + ), + Replay( + position=(3, 6), + direction=Grid4TransitionsEnum.WEST, + action=RailEnvActions.MOVE_LEFT + ), + Replay( + position=(3, 6), + direction=Grid4TransitionsEnum.WEST, + action=None + ), + # not blocked, action required! + Replay( + position=(4, 6), + direction=Grid4TransitionsEnum.SOUTH, + action=RailEnvActions.MOVE_FORWARD + ), ], target=(3, 0), # west dead-end speed=0.5 ) + ] # TODO test penalties! - agentStatic: EnvAgentStatic = env.agents_static[0] - for test_config in test_configs: - info_dict = { - 'action_required': [True] - } - for i, replay in enumerate(test_config.replay): - if i == 0: + info_dict = { + 'action_required': [True for _ in test_configs] + } + for step in range(len(test_configs[0].replay)): + if step == 0: + for a, test_config in enumerate(test_configs): + agentStatic: EnvAgentStatic = env.agents_static[a] + replay = test_config.replay[0] # set the initial position agentStatic.position = replay.position agentStatic.direction = replay.direction @@ -220,24 +423,177 @@ def test_multispeed_actions_no_malfunction(rendering=True): agentStatic.moving = True agentStatic.speed_data['speed'] = test_config.speed - # reset to set agents from agents_static - env.reset(False, False) + # reset to set agents from agents_static + env.reset(False, False) - def _assert(actual, expected, msg): - assert actual == expected, "[{}] {}: actual={}, expected={}".format(i, msg, actual, expected) + def _assert(a, actual, expected, msg): + assert actual == expected, "[{}] {} {}: actual={}, expected={}".format(step, a, msg, actual, expected) - agent: EnvAgent = env.agents[0] + action_dict = {} + + for a, test_config in enumerate(test_configs): + agent: EnvAgent = env.agents[a] + replay = test_config.replay[step] + + _assert(a, agent.position, replay.position, 'position') + _assert(a, agent.direction, replay.direction, 'direction') - _assert(agent.position, replay.position, 'position') - _assert(agent.direction, replay.direction, 'direction') - if replay.action: - assert info_dict['action_required'][0] == True, "[{}] expecting action_required={}".format(i, True) - _, _, _, info_dict = env.step({0: replay.action}) + if replay.action: + assert info_dict['action_required'][a] == True, "[{}] agent {} expecting action_required={}".format(step, a, True) + action_dict[a] = replay.action else: - assert info_dict['action_required'][0] == False, "[{}] expecting action_required={}".format(i, False) - _, _, _, info_dict = env.step({}) + assert info_dict['action_required'][a] == False, "[{}] agent {} expecting action_required={}".format(step, a, False) + _, _, _, info_dict = env.step(action_dict) + + if rendering: + renderer.render_env(show=True, show_observations=True) + - if rendering: - renderer.render_env(show=True, show_observations=True) +def test_multispeed_actions_malfunction_no_blocking(rendering=True): + """Test on a single agent whether action on cell exit work correctly despite malfunction.""" + rail, rail_map = make_simple_rail() + env = RailEnv(width=rail_map.shape[1], + height=rail_map.shape[0], + rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), + number_of_agents=1, + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), + ) + + # initialize agents_static + env.reset() + + # reset to set agents from agents_static + env.reset(False, False) + + if rendering: + renderer = RenderTool(env, gl="PILSVG") + + test_config = TestConfig( + replay=[ + Replay( + position=(3, 9), # east dead-end + direction=Grid4TransitionsEnum.EAST, + action=RailEnvActions.MOVE_FORWARD + ), + Replay( + position=(3, 9), + direction=Grid4TransitionsEnum.EAST, + action=None + ), + Replay( + position=(3, 8), + direction=Grid4TransitionsEnum.WEST, + action=RailEnvActions.MOVE_FORWARD + ), + # add additional step in the cell + Replay( + position=(3, 8), + direction=Grid4TransitionsEnum.WEST, + action=None, + malfunction=2 # recovers in two steps from now! + ), + # agent recovers in this step + Replay( + position=(3, 8), + direction=Grid4TransitionsEnum.WEST, + action=None + ), + Replay( + position=(3, 7), + direction=Grid4TransitionsEnum.WEST, + action=RailEnvActions.MOVE_FORWARD + ), + Replay( + position=(3, 7), + direction=Grid4TransitionsEnum.WEST, + action=None + ), + Replay( + position=(3, 6), + direction=Grid4TransitionsEnum.WEST, + action=RailEnvActions.MOVE_FORWARD, + malfunction=2 # recovers in two steps from now! + ), + # agent recovers in this step; since we're at the beginning, we provide a different action although we're broken! + Replay( + position=(3, 6), + direction=Grid4TransitionsEnum.WEST, + action=RailEnvActions.MOVE_LEFT, + ), + Replay( + position=(3, 6), + direction=Grid4TransitionsEnum.WEST, + action=None + ), + Replay( + position=(4, 6), + direction=Grid4TransitionsEnum.SOUTH, + action=RailEnvActions.STOP_MOVING + ), + Replay( + position=(4, 6), + direction=Grid4TransitionsEnum.SOUTH, + action=RailEnvActions.STOP_MOVING + ), + Replay( + position=(4, 6), + direction=Grid4TransitionsEnum.SOUTH, + action=RailEnvActions.MOVE_FORWARD + ), + Replay( + position=(4, 6), + direction=Grid4TransitionsEnum.SOUTH, + action=None + ), + Replay( + position=(5, 6), + direction=Grid4TransitionsEnum.SOUTH, + action=RailEnvActions.MOVE_FORWARD + ), + + ], + target=(3, 0), # west dead-end + speed=0.5 + ) + + # TODO test penalties! + agentStatic: EnvAgentStatic = env.agents_static[0] + info_dict = { + 'action_required': [True] + } + for i, replay in enumerate(test_config.replay): + if i == 0: + # set the initial position + agentStatic.position = replay.position + agentStatic.direction = replay.direction + agentStatic.target = test_config.target + agentStatic.moving = True + agentStatic.speed_data['speed'] = test_config.speed + + # reset to set agents from agents_static + env.reset(False, False) + + def _assert(actual, expected, msg): + assert actual == expected, "[{}] {}: actual={}, expected={}".format(i, msg, actual, expected) + + agent: EnvAgent = env.agents[0] + + _assert(agent.position, replay.position, 'position') + _assert(agent.direction, replay.direction, 'direction') + + if replay.malfunction: + agent.malfunction_data['malfunction'] = 2 + + if replay.action: + assert info_dict['action_required'][0] == True, "[{}] expecting action_required={}".format(i, True) + _, _, _, info_dict = env.step({0: replay.action}) + + else: + assert info_dict['action_required'][0] == False, "[{}] expecting action_required={}".format(i, False) + _, _, _, info_dict = env.step({}) + + if rendering: + renderer.render_env(show=True, show_observations=True) -- GitLab