From da777473d58c0baaaeb17e3f9303e190cbadf61e Mon Sep 17 00:00:00 2001
From: u214892 <u214892@sbb.ch>
Date: Thu, 5 Sep 2019 11:47:27 +0200
Subject: [PATCH] #162 stochasticity tests

---
 flatland/envs/rail_env.py          |  14 +-
 tests/test_flatland_malfunction.py |  34 +++
 tests/test_multi_speed.py          | 464 +++++++++++++++++++++++++----
 3 files changed, 451 insertions(+), 61 deletions(-)

diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 0ca62f93..cc115c72 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -305,6 +305,7 @@ class RailEnv(Environment):
             return True
         return False
 
+    # TODO refactor to decrease length of this method!
     def step(self, action_dict_):
         self._elapsed_steps += 1
 
@@ -344,7 +345,7 @@ class RailEnv(Environment):
                 action = RailEnvActions.DO_NOTHING
 
             # Check if agent breaks at this step
-            malfunction = self._agent_malfunction(i_agent, action)
+            new_malfunction = self._agent_malfunction(i_agent, action)
 
             # Is the agent at the beginning of the cell? Then, it can take an action
             # Design choice (Erik+Christian):
@@ -397,11 +398,11 @@ class RailEnv(Environment):
                 else:
                     agent.speed_data['transition_action_on_cellexit'] = action
 
-            # if we're broken, nothing else to do
-            if malfunction:
+            # if we've just broken in this step, nothing else to do
+            if new_malfunction:
                 continue
 
-            # The train is broken
+            # The train was broken before...
             if agent.malfunction_data['malfunction'] > 0:
 
                 # Last step of malfunction --> Agent starts moving again after getting fixed
@@ -424,11 +425,9 @@ class RailEnv(Environment):
             # If agent.moving, increment the position_fraction by the speed of the agent
             # If the new position fraction is >= 1, reset to 0, and perform the stored
             #   transition_action_on_cellexit if the cell is free.
-
             if agent.moving:
 
                 agent.speed_data['position_fraction'] += agent.speed_data['speed']
-
                 if agent.speed_data['position_fraction'] >= 1.0:
                     # Perform stored action to transition to the next cell as soon as cell is free
                     # Notice that we've already check new_cell_valid and transition valid when we stored the action,
@@ -441,7 +440,8 @@ class RailEnv(Environment):
                         cell_free, new_cell_valid, new_direction, new_position, transition_valid = self._check_action_on_agent(
                             agent.speed_data['transition_action_on_cellexit'], agent)
                         if not cell_free == all([cell_free, new_cell_valid, transition_valid]):
-                            warnings.warn("Inconsistent state: cell or transition not valid although checked when we stored transition_action_on_cellexit!")
+                            warnings.warn(
+                                "Inconsistent state: cell or transition not valid although checked when we stored transition_action_on_cellexit!")
                         if cell_free:
                             agent.position = new_position
                             agent.direction = new_direction
diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py
index eaf782df..e60386c9 100644
--- a/tests/test_flatland_malfunction.py
+++ b/tests/test_flatland_malfunction.py
@@ -110,3 +110,37 @@ def test_malfunction_process():
 
     # Check that malfunctioning data was standing around
     assert total_down_time > 0
+
+
+def test_malfunction_process_statistically():
+    """Tests that malfunctions are produced by stochastic_data!"""
+    # Set fixed malfunction duration for this test (min_duration == max_duration == 3)
+    stochastic_data = {'prop_malfunction': 1.,
+                       'malfunction_rate': 2,
+                       'min_duration': 3,
+                       'max_duration': 3}
+    np.random.seed(5)  # fixed seed: the exact count asserted at the end relies on it
+
+    env = RailEnv(width=20,
+                  height=20,
+                  rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999,
+                                                        seed=0),
+                  schedule_generator=complex_schedule_generator(),
+                  number_of_agents=2,
+                  obs_builder_object=SingleAgentNavigationObs(),
+                  stochastic_data=stochastic_data)
+
+    env.reset()
+    nb_malfunction = 0  # counts (agent, step) pairs in which the agent is in malfunction
+    for step in range(100):
+        action_dict = {}
+        for agent in env.agents:
+            if agent.malfunction_data['malfunction'] > 0:
+                nb_malfunction += 1
+            # We randomly select an action
+            action_dict[agent.handle] = np.random.randint(4)
+
+        env.step(action_dict)
+
+    # check that generation of malfunctions works as expected (exact count for the seeds fixed above)
+    assert nb_malfunction == 156
diff --git a/tests/test_multi_speed.py b/tests/test_multi_speed.py
index b8b1afaf..86edc08c 100644
--- a/tests/test_multi_speed.py
+++ b/tests/test_multi_speed.py
@@ -97,9 +97,23 @@ def test_multi_speed_init():
                 old_pos[i_agent] = env.agents[i_agent].position
 
 
-# TODO test malfunction
-# TODO test other agent blocking
-def test_multispeed_actions_no_malfunction(rendering=True):
+@attrs
+class Replay(object):
+    position = attrib()  # position the agent is expected to have before this step is executed
+    direction = attrib()  # direction the agent is expected to have before this step is executed
+    action = attrib(type=RailEnvActions)  # action passed to env.step(); None means no action is sent
+    malfunction = attrib(default=0, type=int)  # non-zero: the replay loop puts the agent into malfunction first
+
+
+@attrs
+class TestConfig(object):
+    replay = attrib(type=List[Replay])  # per-step expectations and actions, consumed in list order
+    target = attrib()  # assigned to the static agent's target before the replay starts
+    speed = attrib(type=float)  # assigned to the static agent's speed_data['speed']
+
+
+def test_multispeed_actions_no_malfunction_no_blocking(rendering=True):
+    """Test that actions are correctly performed on cell exit for a single agent."""
     rail, rail_map = make_simple_rail()
     env = RailEnv(width=rail_map.shape[1],
                   height=rail_map.shape[0],
@@ -112,17 +126,135 @@ def test_multispeed_actions_no_malfunction(rendering=True):
     # initialize agents_static
     env.reset()
 
-    @attrs
-    class Replay(object):
-        position = attrib()
-        direction = attrib()
-        action = attrib(type=RailEnvActions)
+    # reset to set agents from agents_static
+    env.reset(False, False)
+
+    if rendering:
+        renderer = RenderTool(env, gl="PILSVG")
+
+    test_config = TestConfig(
+        replay=[
+            Replay(
+                position=(3, 9),  # east dead-end
+                direction=Grid4TransitionsEnum.EAST,
+                action=RailEnvActions.MOVE_FORWARD
+            ),
+            Replay(
+                position=(3, 9),
+                direction=Grid4TransitionsEnum.EAST,
+                action=None
+            ),
+            Replay(
+                position=(3, 8),
+                direction=Grid4TransitionsEnum.WEST,
+                action=RailEnvActions.MOVE_FORWARD
+            ),
+            Replay(
+                position=(3, 8),
+                direction=Grid4TransitionsEnum.WEST,
+                action=None
+            ),
+            Replay(
+                position=(3, 7),
+                direction=Grid4TransitionsEnum.WEST,
+                action=RailEnvActions.MOVE_FORWARD
+            ),
+            Replay(
+                position=(3, 7),
+                direction=Grid4TransitionsEnum.WEST,
+                action=None
+            ),
+            Replay(
+                position=(3, 6),
+                direction=Grid4TransitionsEnum.WEST,
+                action=RailEnvActions.MOVE_LEFT
+            ),
+            Replay(
+                position=(3, 6),
+                direction=Grid4TransitionsEnum.WEST,
+                action=None
+            ),
+            Replay(
+                position=(4, 6),
+                direction=Grid4TransitionsEnum.SOUTH,
+                action=RailEnvActions.STOP_MOVING
+            ),
+            Replay(
+                position=(4, 6),
+                direction=Grid4TransitionsEnum.SOUTH,
+                action=RailEnvActions.STOP_MOVING
+            ),
+            Replay(
+                position=(4, 6),
+                direction=Grid4TransitionsEnum.SOUTH,
+                action=RailEnvActions.MOVE_FORWARD
+            ),
+            Replay(
+                position=(4, 6),
+                direction=Grid4TransitionsEnum.SOUTH,
+                action=None
+            ),
+            Replay(
+                position=(5, 6),
+                direction=Grid4TransitionsEnum.SOUTH,
+                action=RailEnvActions.MOVE_FORWARD
+            ),
+
+        ],
+        target=(3, 0),  # west dead-end
+        speed=0.5
+    )
+
+    # TODO test penalties!
+    agentStatic: EnvAgentStatic = env.agents_static[0]
+    info_dict = {
+        'action_required': [True]
+    }
+    for i, replay in enumerate(test_config.replay):
+        if i == 0:
+            # set the initial position
+            agentStatic.position = replay.position
+            agentStatic.direction = replay.direction
+            agentStatic.target = test_config.target
+            agentStatic.moving = True
+            agentStatic.speed_data['speed'] = test_config.speed
+
+            # reset to set agents from agents_static
+            env.reset(False, False)
+
+        def _assert(actual, expected, msg):
+            assert actual == expected, "[{}] {}:  actual={}, expected={}".format(i, msg, actual, expected)
+
+        agent: EnvAgent = env.agents[0]
+
+        _assert(agent.position, replay.position, 'position')
+        _assert(agent.direction, replay.direction, 'direction')
+
+        if replay.action:
+            assert info_dict['action_required'][0] == True, "[{}] expecting action_required={}".format(i, True)
+            _, _, _, info_dict = env.step({0: replay.action})
+
+        else:
+            assert info_dict['action_required'][0] == False, "[{}] expecting action_required={}".format(i, False)
+            _, _, _, info_dict = env.step({})
+
+        if rendering:
+            renderer.render_env(show=True, show_observations=True)
 
-    @attrs
-    class TestConfig(object):
-        replay = attrib(type=List[Replay])
-        target = attrib()
-        speed = attrib(type=float)
+
+def test_multispeed_actions_no_malfunction_blocking(rendering=True):
+    """The second agent blocks the first because it is slower."""
+    rail, rail_map = make_simple_rail()
+    env = RailEnv(width=rail_map.shape[1],
+                  height=rail_map.shape[0],
+                  rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(),
+                  number_of_agents=2,
+                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
+                  )
+
+    # initialize agents_static
+    env.reset()
 
     # reset to set agents from agents_static
     env.reset(False, False)
@@ -134,85 +266,156 @@ def test_multispeed_actions_no_malfunction(rendering=True):
         TestConfig(
             replay=[
                 Replay(
-                    position=(3, 9),  # east dead-end
-                    direction=Grid4TransitionsEnum.EAST,
+                    position=(3, 8),
+                    direction=Grid4TransitionsEnum.WEST,
                     action=RailEnvActions.MOVE_FORWARD
                 ),
                 Replay(
-                    position=(3, 9),
-                    direction=Grid4TransitionsEnum.EAST,
+                    position=(3, 8),
+                    direction=Grid4TransitionsEnum.WEST,
                     action=None
                 ),
                 Replay(
                     position=(3, 8),
                     direction=Grid4TransitionsEnum.WEST,
+                    action=None
+                ),
+
+                Replay(
+                    position=(3, 7),
+                    direction=Grid4TransitionsEnum.WEST,
                     action=RailEnvActions.MOVE_FORWARD
                 ),
                 Replay(
-                    position=(3, 8),
+                    position=(3, 7),
                     direction=Grid4TransitionsEnum.WEST,
                     action=None
                 ),
                 Replay(
                     position=(3, 7),
                     direction=Grid4TransitionsEnum.WEST,
+                    action=None
+                ),
+
+                Replay(
+                    position=(3, 6),
+                    direction=Grid4TransitionsEnum.WEST,
                     action=RailEnvActions.MOVE_FORWARD
                 ),
                 Replay(
-                    position=(3, 7),
+                    position=(3, 6),
                     direction=Grid4TransitionsEnum.WEST,
                     action=None
                 ),
                 Replay(
                     position=(3, 6),
                     direction=Grid4TransitionsEnum.WEST,
-                    action=RailEnvActions.MOVE_LEFT
+                    action=None
                 ),
+
                 Replay(
-                    position=(3, 6),
+                    position=(3, 5),
+                    direction=Grid4TransitionsEnum.WEST,
+                    action=RailEnvActions.MOVE_FORWARD
+                ),
+                Replay(
+                    position=(3, 5),
                     direction=Grid4TransitionsEnum.WEST,
                     action=None
                 ),
                 Replay(
-                    position=(4, 6),
-                    direction=Grid4TransitionsEnum.SOUTH,
-                    action=RailEnvActions.STOP_MOVING
+                    position=(3, 5),
+                    direction=Grid4TransitionsEnum.WEST,
+                    action=None
+                )
+            ],
+            target=(3, 0),  # west dead-end
+            speed=1 / 3),
+        TestConfig(
+            replay=[
+                Replay(
+                    position=(3, 9),  # east dead-end
+                    direction=Grid4TransitionsEnum.EAST,
+                    action=RailEnvActions.MOVE_FORWARD
                 ),
                 Replay(
-                    position=(4, 6),
-                    direction=Grid4TransitionsEnum.SOUTH,
-                    action=RailEnvActions.STOP_MOVING
+                    position=(3, 9),
+                    direction=Grid4TransitionsEnum.EAST,
+                    action=None
                 ),
+                # blocked although fraction >= 1.0
                 Replay(
-                    position=(4, 6),
-                    direction=Grid4TransitionsEnum.SOUTH,
+                    position=(3, 9),
+                    direction=Grid4TransitionsEnum.EAST,
+                    action=None
+                ),
+
+                Replay(
+                    position=(3, 8),
+                    direction=Grid4TransitionsEnum.WEST,
                     action=RailEnvActions.MOVE_FORWARD
                 ),
                 Replay(
-                    position=(4, 6),
-                    direction=Grid4TransitionsEnum.SOUTH,
+                    position=(3, 8),
+                    direction=Grid4TransitionsEnum.WEST,
                     action=None
                 ),
+                # blocked although fraction >= 1.0
                 Replay(
-                    position=(5, 6),
-                    direction=Grid4TransitionsEnum.SOUTH,
+                    position=(3, 8),
+                    direction=Grid4TransitionsEnum.WEST,
+                    action=None
+                ),
+
+                Replay(
+                    position=(3, 7),
+                    direction=Grid4TransitionsEnum.WEST,
                     action=RailEnvActions.MOVE_FORWARD
                 ),
+                Replay(
+                    position=(3, 7),
+                    direction=Grid4TransitionsEnum.WEST,
+                    action=None
+                ),
+                # blocked although fraction >= 1.0
+                Replay(
+                    position=(3, 7),
+                    direction=Grid4TransitionsEnum.WEST,
+                    action=None
+                ),
 
+                Replay(
+                    position=(3, 6),
+                    direction=Grid4TransitionsEnum.WEST,
+                    action=RailEnvActions.MOVE_LEFT
+                ),
+                Replay(
+                    position=(3, 6),
+                    direction=Grid4TransitionsEnum.WEST,
+                    action=None
+                ),
+                # not blocked, action required!
+                Replay(
+                    position=(4, 6),
+                    direction=Grid4TransitionsEnum.SOUTH,
+                    action=RailEnvActions.MOVE_FORWARD
+                ),
             ],
             target=(3, 0),  # west dead-end
             speed=0.5
         )
+
     ]
 
     # TODO test penalties!
-    agentStatic: EnvAgentStatic = env.agents_static[0]
-    for test_config in test_configs:
-        info_dict = {
-            'action_required': [True]
-        }
-        for i, replay in enumerate(test_config.replay):
-            if i == 0:
+    info_dict = {
+        'action_required': [True for _ in test_configs]
+    }
+    for step in range(len(test_configs[0].replay)):
+        if step == 0:
+            for a, test_config in enumerate(test_configs):
+                agentStatic: EnvAgentStatic = env.agents_static[a]
+                replay = test_config.replay[0]
                 # set the initial position
                 agentStatic.position = replay.position
                 agentStatic.direction = replay.direction
@@ -220,24 +423,177 @@ def test_multispeed_actions_no_malfunction(rendering=True):
                 agentStatic.moving = True
                 agentStatic.speed_data['speed'] = test_config.speed
 
-                # reset to set agents from agents_static
-                env.reset(False, False)
+            # reset to set agents from agents_static
+            env.reset(False, False)
 
-            def _assert(actual, expected, msg):
-                assert actual == expected, "[{}] {}:  actual={}, expected={}".format(i, msg, actual, expected)
+        def _assert(a, actual, expected, msg):
+            assert actual == expected, "[{}] {} {}:  actual={}, expected={}".format(step, a, msg, actual, expected)
 
-            agent: EnvAgent = env.agents[0]
+        action_dict = {}
+
+        for a, test_config in enumerate(test_configs):
+            agent: EnvAgent = env.agents[a]
+            replay = test_config.replay[step]
+
+            _assert(a, agent.position, replay.position, 'position')
+            _assert(a, agent.direction, replay.direction, 'direction')
 
-            _assert(agent.position, replay.position, 'position')
-            _assert(agent.direction, replay.direction, 'direction')
 
-            if replay.action:
-                assert info_dict['action_required'][0] == True, "[{}] expecting action_required={}".format(i, True)
-                _, _, _, info_dict = env.step({0: replay.action})
 
+            if replay.action:
+                assert info_dict['action_required'][a] == True, "[{}] agent {} expecting action_required={}".format(step, a, True)
+                action_dict[a] = replay.action
             else:
-                assert info_dict['action_required'][0] == False, "[{}] expecting action_required={}".format(i, False)
-                _, _, _, info_dict = env.step({})
+                assert info_dict['action_required'][a] == False, "[{}] agent {} expecting action_required={}".format(step, a, False)
+        _, _, _, info_dict = env.step(action_dict)
+
+        if rendering:
+            renderer.render_env(show=True, show_observations=True)
+
 
-            if rendering:
-                renderer.render_env(show=True, show_observations=True)
+def test_multispeed_actions_malfunction_no_blocking(rendering=True):
+    """Test on a single agent whether actions on cell exit work correctly despite malfunctions."""
+    rail, rail_map = make_simple_rail()
+    env = RailEnv(width=rail_map.shape[1],
+                  height=rail_map.shape[0],
+                  rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(),
+                  number_of_agents=1,
+                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
+                  )
+
+    # initialize agents_static
+    env.reset()
+
+    # reset to set agents from agents_static
+    env.reset(False, False)
+
+    if rendering:
+        renderer = RenderTool(env, gl="PILSVG")
+
+    test_config = TestConfig(
+        replay=[
+            Replay(
+                position=(3, 9),  # east dead-end
+                direction=Grid4TransitionsEnum.EAST,
+                action=RailEnvActions.MOVE_FORWARD
+            ),
+            Replay(
+                position=(3, 9),
+                direction=Grid4TransitionsEnum.EAST,
+                action=None
+            ),
+            Replay(
+                position=(3, 8),
+                direction=Grid4TransitionsEnum.WEST,
+                action=RailEnvActions.MOVE_FORWARD
+            ),
+            # add additional step in the cell
+            Replay(
+                position=(3, 8),
+                direction=Grid4TransitionsEnum.WEST,
+                action=None,
+                malfunction=2  # recovers in two steps from now!
+            ),
+            # agent recovers in this step
+            Replay(
+                position=(3, 8),
+                direction=Grid4TransitionsEnum.WEST,
+                action=None
+            ),
+            Replay(
+                position=(3, 7),
+                direction=Grid4TransitionsEnum.WEST,
+                action=RailEnvActions.MOVE_FORWARD
+            ),
+            Replay(
+                position=(3, 7),
+                direction=Grid4TransitionsEnum.WEST,
+                action=None
+            ),
+            Replay(
+                position=(3, 6),
+                direction=Grid4TransitionsEnum.WEST,
+                action=RailEnvActions.MOVE_FORWARD,
+                malfunction=2  # recovers in two steps from now!
+            ),
+            # agent recovers in this step; since we're at the beginning, we provide a different action although we're broken!
+            Replay(
+                position=(3, 6),
+                direction=Grid4TransitionsEnum.WEST,
+                action=RailEnvActions.MOVE_LEFT,
+            ),
+            Replay(
+                position=(3, 6),
+                direction=Grid4TransitionsEnum.WEST,
+                action=None
+            ),
+            Replay(
+                position=(4, 6),
+                direction=Grid4TransitionsEnum.SOUTH,
+                action=RailEnvActions.STOP_MOVING
+            ),
+            Replay(
+                position=(4, 6),
+                direction=Grid4TransitionsEnum.SOUTH,
+                action=RailEnvActions.STOP_MOVING
+            ),
+            Replay(
+                position=(4, 6),
+                direction=Grid4TransitionsEnum.SOUTH,
+                action=RailEnvActions.MOVE_FORWARD
+            ),
+            Replay(
+                position=(4, 6),
+                direction=Grid4TransitionsEnum.SOUTH,
+                action=None
+            ),
+            Replay(
+                position=(5, 6),
+                direction=Grid4TransitionsEnum.SOUTH,
+                action=RailEnvActions.MOVE_FORWARD
+            ),
+
+        ],
+        target=(3, 0),  # west dead-end
+        speed=0.5
+    )
+
+    # TODO test penalties!
+    agentStatic: EnvAgentStatic = env.agents_static[0]
+    info_dict = {
+        'action_required': [True]
+    }
+    for i, replay in enumerate(test_config.replay):
+        if i == 0:
+            # set the initial position
+            agentStatic.position = replay.position
+            agentStatic.direction = replay.direction
+            agentStatic.target = test_config.target
+            agentStatic.moving = True
+            agentStatic.speed_data['speed'] = test_config.speed
+
+            # reset to set agents from agents_static
+            env.reset(False, False)
+
+        def _assert(actual, expected, msg):
+            assert actual == expected, "[{}] {}:  actual={}, expected={}".format(i, msg, actual, expected)
+
+        agent: EnvAgent = env.agents[0]
+
+        _assert(agent.position, replay.position, 'position')
+        _assert(agent.direction, replay.direction, 'direction')
+
+        if replay.malfunction:
+            agent.malfunction_data['malfunction'] = replay.malfunction  # use the configured duration, not a constant
+
+        if replay.action:
+            assert info_dict['action_required'][0] == True, "[{}] expecting action_required={}".format(i, True)
+            _, _, _, info_dict = env.step({0: replay.action})
+
+        else:
+            assert info_dict['action_required'][0] == False, "[{}] expecting action_required={}".format(i, False)
+            _, _, _, info_dict = env.step({})
+
+        if rendering:
+            renderer.render_env(show=True, show_observations=True)
-- 
GitLab