From 29cc5cda3c7773112b46ab16f0a3e8fd0d6a17d6 Mon Sep 17 00:00:00 2001
From: u214892 <u214892@sbb.ch>
Date: Mon, 16 Sep 2019 19:22:04 +0200
Subject: [PATCH] #178 bugfix initial malfunction

---
 flatland/envs/rail_env.py          | 116 +++++++++++++++++++++--------
 tests/test_flatland_malfunction.py | 111 ++++++++++++++++++++++++++-
 tests/test_multi_speed.py          |  32 +++-----
 tests/test_utils.py                |  21 ++++++
 4 files changed, 223 insertions(+), 57 deletions(-)
 create mode 100644 tests/test_utils.py

diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index e4d69306..b7483b02 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -4,13 +4,14 @@ Definition of the RailEnv environment.
 # TODO:  _ this is a global method --> utils or remove later
 import warnings
 from enum import IntEnum
-from typing import List
+from typing import List, Set, NamedTuple
 
 import msgpack
 import msgpack_numpy as m
 import numpy as np
 
 from flatland.core.env import Environment
+from flatland.core.grid.grid4 import Grid4TransitionsEnum
 from flatland.core.grid.grid4_utils import get_new_position
 from flatland.core.transition_map import GridTransitionMap
 from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent
@@ -39,6 +40,11 @@ class RailEnvActions(IntEnum):
         }[a]
 
 
+RailEnvGridPos = NamedTuple('RailEnvGridPos', [('r', int), ('c', int)])
+RailEnvNextAction = NamedTuple('RailEnvNextAction', [('action', RailEnvActions), ('next_position', RailEnvGridPos),
+                                                     ('next_direction', Grid4TransitionsEnum)])
+
+
 class RailEnv(Environment):
     """
     RailEnv environment class.
@@ -262,7 +268,18 @@ class RailEnv(Environment):
 
             agent.malfunction_data['malfunction'] = 0
 
-            self._agent_new_malfunction(i_agent, RailEnvActions.DO_NOTHING)
+            initial_malfunction = self._agent_new_malfunction(i_agent)
+            if initial_malfunction:
+                valid_actions = set(map(lambda x: x.action, self.get_valid_move_actions(agent)))
+                if RailEnvActions.MOVE_FORWARD in valid_actions:
+                    agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.MOVE_FORWARD
+                elif RailEnvActions.MOVE_LEFT in valid_actions:
+                    agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.MOVE_LEFT
+                elif RailEnvActions.MOVE_RIGHT in valid_actions:
+                    agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.MOVE_RIGHT
+                else:
+                    raise Exception(
+                        "Agent {} cannot move forward/left/right from initial position".format(agent.handle))
 
         self.num_resets += 1
         self._elapsed_steps = 0
@@ -277,7 +294,7 @@ class RailEnv(Environment):
         # Return the new observation vectors for each agent
         return self._get_observations()
 
-    def _agent_new_malfunction(self, i_agent, action) -> bool:
+    def _agent_new_malfunction(self, i_agent) -> bool:
         """
         Returns true if the agent enters into malfunction. (False, if not broken down or already broken down before).
         """
@@ -335,25 +352,25 @@ class RailEnv(Environment):
             agent.old_direction = agent.direction
             agent.old_position = agent.position
 
-            # No action has been supplied for this agent -> set DO_NOTHING as default
-            if i_agent not in action_dict_:
-                action = RailEnvActions.DO_NOTHING
-            else:
-                action = action_dict_[i_agent]
-
-            if action < 0 or action > len(RailEnvActions):
-                print('ERROR: illegal action=', action,
-                      'for agent with index=', i_agent,
-                      '"DO NOTHING" will be executed instead')
-                action = RailEnvActions.DO_NOTHING
-
             # Check if agent breaks at this step
-            new_malfunction = self._agent_new_malfunction(i_agent, action)
+            new_malfunction = self._agent_new_malfunction(i_agent)
 
             # Is the agent at the beginning of the cell? Then, it can take an action
             # Design choice (Erik+Christian):
             #  as long as we're broken down at the beginning of the cell, we can choose other actions!
             if agent.speed_data['position_fraction'] == 0.0:
+                # No action has been supplied for this agent -> set DO_NOTHING as default
+                if i_agent not in action_dict_:
+                    action = RailEnvActions.DO_NOTHING
+                else:
+                    action = action_dict_[i_agent]
+
+                if action < 0 or action > len(RailEnvActions):
+                    print('ERROR: illegal action=', action,
+                          'for agent with index=', i_agent,
+                          '"DO NOTHING" will be executed instead')
+                    action = RailEnvActions.DO_NOTHING
+
                 if action == RailEnvActions.DO_NOTHING and agent.moving:
                     # Keep moving
                     action = RailEnvActions.MOVE_FORWARD
@@ -370,12 +387,14 @@ class RailEnv(Environment):
                     self.rewards_dict[i_agent] += self.start_penalty
 
                 # Store the action
-                if agent.moving and action not in [RailEnvActions.DO_NOTHING, RailEnvActions.STOP_MOVING]:
+                if agent.moving:
+                    _action_stored = False
                     _, new_cell_valid, new_direction, new_position, transition_valid = \
                         self._check_action_on_agent(action, agent)
 
                     if all([new_cell_valid, transition_valid]):
                         agent.speed_data['transition_action_on_cellexit'] = action
+                        _action_stored = True
                     else:
                         # But, if the chosen invalid action was LEFT/RIGHT, and the agent is moving,
                         # try to keep moving forward!
@@ -385,19 +404,14 @@ class RailEnv(Environment):
 
                             if all([new_cell_valid, transition_valid]):
                                 agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.MOVE_FORWARD
-                            else:
-                                # If the agent cannot move due to an invalid transition, we set its state to not moving
-                                self.rewards_dict[i_agent] += self.invalid_action_penalty
-                                self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
-                                self.rewards_dict[i_agent] += self.stop_penalty
-                                agent.moving = False
-
-                        else:
-                            # If the agent cannot move due to an invalid transition, we set its state to not moving
-                            self.rewards_dict[i_agent] += self.invalid_action_penalty
-                            self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
-                            self.rewards_dict[i_agent] += self.stop_penalty
-                            agent.moving = False
+                                _action_stored = True
+
+                    if not _action_stored:
+                        # If the agent cannot move due to an invalid transition, we set its state to not moving
+                        self.rewards_dict[i_agent] += self.invalid_action_penalty
+                        self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
+                        self.rewards_dict[i_agent] += self.stop_penalty
+                        agent.moving = False
 
             # if we've just broken in this step, nothing else to do
             if new_malfunction:
@@ -410,7 +424,6 @@ class RailEnv(Environment):
                 if agent.malfunction_data['malfunction'] < 2:
                     agent.malfunction_data['malfunction'] -= 1
                     self.agents[i_agent].moving = True
-                    action = RailEnvActions.DO_NOTHING
 
                 else:
                     agent.malfunction_data['malfunction'] -= 1
@@ -438,6 +451,9 @@ class RailEnv(Environment):
                     cell_free, new_cell_valid, new_direction, new_position, transition_valid = self._check_action_on_agent(
                         agent.speed_data['transition_action_on_cellexit'], agent)
 
+                    # N.B. validity of new_cell and transition should have been verified before the action was stored!
+                    assert new_cell_valid
+                    assert transition_valid
                     if cell_free:
                         agent.position = new_position
                         agent.direction = new_direction
@@ -532,6 +548,44 @@ class RailEnv(Environment):
                 transition_valid = True
         return new_direction, transition_valid
 
+    def get_valid_move_actions(self, agent: EnvAgent) -> Set[RailEnvNextAction]:
+        valid_actions: Set[RailEnvNextAction] = set()
+        agent_position = agent.position
+        agent_direction = agent.direction
+        possible_transitions = self.rail.get_transitions(*agent_position, agent_direction)
+        num_transitions = np.count_nonzero(possible_transitions)
+
+        # Start from the current orientation, and see which transitions are available;
+        # organize them as [left, forward, right], relative to the current orientation
+        # If only one transition is possible, the forward branch is aligned with it.
+        if self.rail.is_dead_end(agent_position):
+            action = RailEnvActions.MOVE_FORWARD
+            exit_direction = (agent_direction + 2) % 4
+            if possible_transitions[exit_direction]:
+                new_position = get_new_position(agent_position, exit_direction)
+                valid_actions.add(RailEnvNextAction(action, new_position, exit_direction))
+        elif num_transitions == 1:
+            action = RailEnvActions.MOVE_FORWARD
+            for new_direction in [(agent_direction + i) % 4 for i in range(-1, 2)]:
+                if possible_transitions[new_direction]:
+                    new_position = get_new_position(agent_position, new_direction)
+                    valid_actions.add(RailEnvNextAction(action, new_position, new_direction))
+        else:
+            for new_direction in [(agent_direction + i) % 4 for i in range(-1, 2)]:
+                if possible_transitions[new_direction]:
+                    if new_direction == agent_direction:
+                        action = RailEnvActions.MOVE_FORWARD
+                    elif new_direction == (agent_direction + 1) % 4:
+                        action = RailEnvActions.MOVE_RIGHT
+                    elif new_direction == (agent_direction - 1) % 4:
+                        action = RailEnvActions.MOVE_LEFT
+                    else:
+                        raise Exception("Illegal state")
+
+                    new_position = get_new_position(agent_position, new_direction)
+                    valid_actions.add(RailEnvNextAction(action, new_position, new_direction))
+        return valid_actions
+
     def _get_observations(self):
         self.obs_dict = self.obs_builder.get_many(list(range(self.get_num_agents())))
         return self.obs_dict
diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py
index a63e9722..e74666e1 100644
--- a/tests/test_flatland_malfunction.py
+++ b/tests/test_flatland_malfunction.py
@@ -1,9 +1,15 @@
+import random
+
 import numpy as np
 
+from flatland.core.grid.grid4 import Grid4TransitionsEnum
+from flatland.envs.agent_utils import EnvAgent
 from flatland.envs.observations import TreeObsForRailEnv
-from flatland.envs.rail_env import RailEnv
-from flatland.envs.rail_generators import complex_rail_generator
-from flatland.envs.schedule_generators import complex_schedule_generator
+from flatland.envs.rail_env import RailEnv, RailEnvActions
+from flatland.envs.rail_generators import complex_rail_generator, sparse_rail_generator
+from flatland.envs.schedule_generators import complex_schedule_generator, sparse_schedule_generator
+from flatland.utils.rendertools import RenderTool
+from test_utils import Replay
 
 
 class SingleAgentNavigationObs(TreeObsForRailEnv):
@@ -145,3 +151,102 @@ def test_malfunction_process_statistically():
     # check that generation of malfunctions works as expected
     # results are different in py36 and py37, therefore no exact test on nb_malfunction
     assert nb_malfunction > 150
+
+
+# TODO test DO_NOTHING!
+def test_initial_malfunction(rendering=True):
+    random.seed(0)
+    stochastic_data = {'prop_malfunction': 1.,  # Percentage of defective agents
+                       'malfunction_rate': 70,  # Rate of malfunction occurence
+                       'min_duration': 2,  # Minimal duration of malfunction
+                       'max_duration': 5  # Max duration of malfunction
+                       }
+
+    speed_ration_map = {1.: 1.,  # Fast passenger train
+                        1. / 2.: 0.,  # Fast freight train
+                        1. / 3.: 0.,  # Slow commuter train
+                        1. / 4.: 0.}  # Slow freight train
+
+    env = RailEnv(width=25,
+                  height=30,
+                  rail_generator=sparse_rail_generator(num_cities=5,
+                                                       # Number of cities in map (where train stations are)
+                                                       num_intersections=4,
+                                                       # Number of intersections (no start / target)
+                                                       num_trainstations=25,  # Number of possible start/targets on map
+                                                       min_node_dist=6,  # Minimal distance of nodes
+                                                       node_radius=3,  # Proximity of stations to city center
+                                                       num_neighb=3,
+                                                       # Number of connections to other cities/intersections
+                                                       seed=215545,  # Random seed
+                                                       grid_mode=True,
+                                                       enhance_intersection=False
+                                                       ),
+                  schedule_generator=sparse_schedule_generator(speed_ration_map),
+                  number_of_agents=1,
+                  stochastic_data=stochastic_data,  # Malfunction data generator
+                  )
+
+    if rendering:
+        renderer = RenderTool(env)
+        renderer.render_env(show=True, frames=False, show_observations=False)
+    _action = dict()
+
+    replay_steps = [
+        Replay(
+            position=(27, 5),
+            direction=Grid4TransitionsEnum.EAST,
+            action=RailEnvActions.MOVE_FORWARD,
+            malfunction=3
+        ),
+        Replay(
+            position=(27, 5),
+            direction=Grid4TransitionsEnum.EAST,
+            action=RailEnvActions.MOVE_FORWARD,
+            malfunction=2
+        ),
+        Replay(
+            position=(27, 5),
+            direction=Grid4TransitionsEnum.EAST,
+            action=RailEnvActions.MOVE_FORWARD,
+            malfunction=1
+        ),
+        Replay(
+            position=(27, 4),
+            direction=Grid4TransitionsEnum.WEST,
+            action=RailEnvActions.MOVE_FORWARD,
+            malfunction=0
+        ),
+        Replay(
+            position=(27, 3),
+            direction=Grid4TransitionsEnum.WEST,
+            action=RailEnvActions.MOVE_FORWARD,
+            malfunction=0
+        )
+    ]
+
+    info_dict = {
+        'action_required': [True]
+    }
+
+    for i, replay in enumerate(replay_steps):
+
+        def _assert(actual, expected, msg):
+            assert actual == expected, "[{}] {}:  actual={}, expected={}".format(i, msg, actual, expected)
+
+        agent: EnvAgent = env.agents[0]
+
+        _assert(agent.position, replay.position, 'position')
+        _assert(agent.direction, replay.direction, 'direction')
+        _assert(agent.malfunction_data['malfunction'], replay.malfunction, 'malfunction')
+
+        if replay.action:
+            assert info_dict['action_required'][0] == True, "[{}] expecting action_required={}".format(i, True)
+            _, _, _, info_dict = env.step({0: replay.action})
+
+        else:
+            assert info_dict['action_required'][0] == False, "[{}] expecting action_required={}".format(i, False)
+            _, _, _, info_dict = env.step({})
+
+        if rendering:
+            renderer.render_env(show=True, show_observations=True)
diff --git a/tests/test_multi_speed.py b/tests/test_multi_speed.py
index 86edc08c..529e9412 100644
--- a/tests/test_multi_speed.py
+++ b/tests/test_multi_speed.py
@@ -1,7 +1,6 @@
-from typing import List
+import time
 
 import numpy as np
-from attr import attrib, attrs
 
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
 from flatland.envs.agent_utils import EnvAgent, EnvAgentStatic
@@ -12,6 +11,7 @@ from flatland.envs.rail_generators import complex_rail_generator, rail_from_grid
 from flatland.envs.schedule_generators import complex_schedule_generator, random_schedule_generator
 from flatland.utils.rendertools import RenderTool
 from flatland.utils.simple_rail import make_simple_rail
+from test_utils import TestConfig, Replay
 
 np.random.seed(1)
 
@@ -97,21 +97,6 @@ def test_multi_speed_init():
                 old_pos[i_agent] = env.agents[i_agent].position
 
 
-@attrs
-class Replay(object):
-    position = attrib()
-    direction = attrib()
-    action = attrib(type=RailEnvActions)
-    malfunction = attrib(default=0, type=int)
-
-
-@attrs
-class TestConfig(object):
-    replay = attrib(type=List[Replay])
-    target = attrib()
-    speed = attrib(type=float)
-
-
 def test_multispeed_actions_no_malfunction_no_blocking(rendering=True):
     """Test that actions are correctly performed on cell exit for a single agent."""
     rail, rail_map = make_simple_rail()
@@ -179,6 +164,7 @@ def test_multispeed_actions_no_malfunction_no_blocking(rendering=True):
                 direction=Grid4TransitionsEnum.SOUTH,
                 action=RailEnvActions.STOP_MOVING
             ),
+            #
             Replay(
                 position=(4, 6),
                 direction=Grid4TransitionsEnum.SOUTH,
@@ -438,13 +424,13 @@ def test_multispeed_actions_no_malfunction_blocking(rendering=True):
             _assert(a, agent.position, replay.position, 'position')
             _assert(a, agent.direction, replay.direction, 'direction')
 
-
-
             if replay.action:
-                assert info_dict['action_required'][a] == True, "[{}] agent {} expecting action_required={}".format(step, a, True)
+                assert info_dict['action_required'][a] == True, "[{}] agent {} expecting action_required={}".format(
+                    step, a, True)
                 action_dict[a] = replay.action
             else:
-                assert info_dict['action_required'][a] == False, "[{}] agent {} expecting action_required={}".format(step, a, False)
+                assert info_dict['action_required'][a] == False, "[{}] agent {} expecting action_required={}".format(
+                    step, a, False)
         _, _, _, info_dict = env.step(action_dict)
 
         if rendering:
@@ -493,7 +479,7 @@ def test_multispeed_actions_malfunction_no_blocking(rendering=True):
                 position=(3, 8),
                 direction=Grid4TransitionsEnum.WEST,
                 action=None,
-                malfunction=2 # recovers in two steps from now!
+                malfunction=2  # recovers in two steps from now!
             ),
             # agent recovers in this step
             Replay(
@@ -515,7 +501,7 @@ def test_multispeed_actions_malfunction_no_blocking(rendering=True):
                 position=(3, 6),
                 direction=Grid4TransitionsEnum.WEST,
                 action=RailEnvActions.MOVE_FORWARD,
-                malfunction=2 # recovers in two steps from now!
+                malfunction=2  # recovers in two steps from now!
             ),
             # agent recovers in this step; since we're at the beginning, we provide a different action although we're broken!
             Replay(
diff --git a/tests/test_utils.py b/tests/test_utils.py
new file mode 100644
index 00000000..4bd84e76
--- /dev/null
+++ b/tests/test_utils.py
@@ -0,0 +1,21 @@
+"""Test Utils."""
+from typing import List
+
+from attr import attrs, attrib
+
+from flatland.envs.rail_env import RailEnvActions
+
+
+@attrs
+class Replay(object):
+    position = attrib()
+    direction = attrib()
+    action = attrib(type=RailEnvActions)
+    malfunction = attrib(default=0, type=int)
+
+
+@attrs
+class TestConfig(object):
+    replay = attrib(type=List[Replay])
+    target = attrib()
+    speed = attrib(type=float)
-- 
GitLab