From 0a7ebfd07b92710c355f105fffea4d9da12108cb Mon Sep 17 00:00:00 2001
From: u214892 <u214892@sbb.ch>
Date: Wed, 18 Sep 2019 17:51:33 +0200
Subject: [PATCH] #178 bugfix step function initial malfunction

---
 flatland/envs/agent_utils.py       |  14 +-
 flatland/envs/rail_env.py          | 361 ++++++++++++++++-------------
 tests/test_flatland_malfunction.py | 230 +++++++++++++++++-
 tests/test_multi_speed.py          |   5 +-
 4 files changed, 440 insertions(+), 170 deletions(-)

diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py
index b228e10b..f659ec84 100644
--- a/flatland/envs/agent_utils.py
+++ b/flatland/envs/agent_utils.py
@@ -1,8 +1,11 @@
 from itertools import starmap
+from typing import Tuple
 
 import numpy as np
 from attr import attrs, attrib, Factory
 
+from flatland.core.grid.grid4 import Grid4TransitionsEnum
+
 
 @attrs
 class EnvAgentStatic(object):
@@ -11,10 +14,10 @@ class EnvAgentStatic(object):
         rather than where it is at the moment.
         The target should also be stored here.
     """
-    position = attrib()
-    direction = attrib()
-    target = attrib()
-    moving = attrib(default=False)
+    position = attrib(type=Tuple[int, int])
+    direction = attrib(type=Grid4TransitionsEnum)
+    target = attrib(type=Tuple[int, int])
+    moving = attrib(default=False, type=bool)
 
     # speed_data: speed is added to position_fraction on each moving step, until position_fraction>=1.0,
     # after which 'transition_action_on_cellexit' is executed (equivalent to executing that action in the previous
@@ -27,7 +30,8 @@ class EnvAgentStatic(object):
     # number of time the agent had to stop, since the last time it broke down
     malfunction_data = attrib(
         default=Factory(
-            lambda: dict({'malfunction': 0, 'malfunction_rate': 0, 'next_malfunction': 0, 'nr_malfunctions': 0})))
+            lambda: dict({'malfunction': 0, 'malfunction_rate': 0, 'next_malfunction': 0, 'nr_malfunctions': 0,
+                          'moving_before_malfunction': False})))
 
     @classmethod
     def from_lists(cls, positions, directions, targets, speeds=None, malfunction_rates=None):
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index df50b813..11df8b20 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -4,7 +4,7 @@ Definition of the RailEnv environment.
 # TODO:  _ this is a global method --> utils or remove later
 import warnings
 from enum import IntEnum
-from typing import List, Set, NamedTuple, Optional
+from typing import List, Set, NamedTuple, Optional, Tuple, Dict
 
 import msgpack
 import msgpack_numpy as m
@@ -122,7 +122,7 @@ class RailEnv(Environment):
         Environment init.
 
         Parameters
-        -------
+        ----------
         rail_generator : function
             The rail_generator function is a function that takes the width,
             height and agents handles of a  rail environment, along with the number of times
@@ -147,9 +147,6 @@ class RailEnv(Environment):
             ObservationBuilder-derived object that takes builds observation
             vectors for each agent.
         max_episode_steps : int or None
-
-        file_name: you can load a pickle file. from previously saved *.pkl file
-
         """
         super().__init__()
 
@@ -272,18 +269,10 @@ class RailEnv(Environment):
 
             agent.malfunction_data['malfunction'] = 0
 
-            initial_malfunction = self._agent_new_malfunction(i_agent)
+            initial_malfunction = self._agent_malfunction(i_agent)
+
             if initial_malfunction:
-                valid_actions = set(map(lambda x: x.action, self.get_valid_move_actions(agent)))
-                if RailEnvActions.MOVE_FORWARD in valid_actions:
-                    agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.MOVE_FORWARD
-                elif RailEnvActions.MOVE_LEFT in valid_actions:
-                    agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.MOVE_LEFT
-                elif RailEnvActions.MOVE_RIGHT in valid_actions:
-                    agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.MOVE_RIGHT
-                else:
-                    raise Exception(
-                        "Agent {} cannot move forward/left/right from initial position".format(agent.handle))
+                agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.DO_NOTHING
 
         self.num_resets += 1
         self._elapsed_steps = 0
@@ -299,7 +288,7 @@ class RailEnv(Environment):
         # Return the new observation vectors for each agent
         return self._get_observations()
 
-    def _agent_new_malfunction(self, i_agent) -> bool:
+    def _agent_malfunction(self, i_agent) -> bool:
         """
         Returns true if the agent enters into malfunction. (False, if not broken down or already broken down before).
         """
@@ -326,12 +315,28 @@ class RailEnv(Environment):
             num_broken_steps = np.random.randint(self.min_number_of_steps_broken,
                                                  self.max_number_of_steps_broken + 1) + 1
             agent.malfunction_data['malfunction'] = num_broken_steps
+            agent.malfunction_data['moving_before_malfunction'] = agent.moving
 
             return True
+        else:
+            # The train was broken before...
+            if agent.malfunction_data['malfunction'] > 0:
+
+                # Last step of malfunction --> Agent starts moving again after getting fixed
+                if agent.malfunction_data['malfunction'] < 2:
+                    agent.malfunction_data['malfunction'] -= 1
+
+                    # restore moving state before malfunction without further penalty
+                    self.agents[i_agent].moving = agent.malfunction_data['moving_before_malfunction']
+
+                else:
+                    agent.malfunction_data['malfunction'] -= 1
+
+                    # Nothing left to do with broken agent
+                    return True
         return False
 
-    # TODO refactor to decrease length of this method!
-    def step(self, action_dict_):
+    def step(self, action_dict_: Dict[int, RailEnvActions]):
         self._elapsed_steps += 1
 
         # Reset the step rewards
@@ -349,126 +354,7 @@ class RailEnv(Environment):
             return self._get_observations(), self.rewards_dict, self.dones, info_dict
 
         for i_agent in range(self.get_num_agents()):
-
-            if self.dones[i_agent]:  # this agent has already completed...
-                continue
-
-            agent = self.agents[i_agent]
-            agent.old_direction = agent.direction
-            agent.old_position = agent.position
-
-            # Check if agent breaks at this step
-            new_malfunction = self._agent_new_malfunction(i_agent)
-
-            # Is the agent at the beginning of the cell? Then, it can take an action
-            # Design choice (Erik+Christian):
-            #  as long as we're broken down at the beginning of the cell, we can choose other actions!
-            if agent.speed_data['position_fraction'] == 0.0:
-                # No action has been supplied for this agent -> set DO_NOTHING as default
-                if i_agent not in action_dict_:
-                    action = RailEnvActions.DO_NOTHING
-                else:
-                    action = action_dict_[i_agent]
-
-                if action < 0 or action > len(RailEnvActions):
-                    print('ERROR: illegal action=', action,
-                          'for agent with index=', i_agent,
-                          '"DO NOTHING" will be executed instead')
-                    action = RailEnvActions.DO_NOTHING
-
-                if action == RailEnvActions.DO_NOTHING and agent.moving:
-                    # Keep moving
-                    action = RailEnvActions.MOVE_FORWARD
-
-                if action == RailEnvActions.STOP_MOVING and agent.moving:
-                    # Only allow halting an agent on entering new cells.
-                    agent.moving = False
-                    self.rewards_dict[i_agent] += self.stop_penalty
-
-                if not agent.moving and not (
-                    action == RailEnvActions.DO_NOTHING or action == RailEnvActions.STOP_MOVING):
-                    # Allow agent to start with any forward or direction action
-                    agent.moving = True
-                    self.rewards_dict[i_agent] += self.start_penalty
-
-                # Store the action
-                if agent.moving:
-                    _action_stored = False
-                    _, new_cell_valid, new_direction, new_position, transition_valid = \
-                        self._check_action_on_agent(action, agent)
-
-                    if all([new_cell_valid, transition_valid]):
-                        agent.speed_data['transition_action_on_cellexit'] = action
-                        _action_stored = True
-                    else:
-                        # But, if the chosen invalid action was LEFT/RIGHT, and the agent is moving,
-                        # try to keep moving forward!
-                        if (action == RailEnvActions.MOVE_LEFT or action == RailEnvActions.MOVE_RIGHT):
-                            _, new_cell_valid, new_direction, new_position, transition_valid = \
-                                self._check_action_on_agent(RailEnvActions.MOVE_FORWARD, agent)
-
-                            if all([new_cell_valid, transition_valid]):
-                                agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.MOVE_FORWARD
-                                _action_stored = True
-
-                    if not _action_stored:
-                        # If the agent cannot move due to an invalid transition, we set its state to not moving
-                        self.rewards_dict[i_agent] += self.invalid_action_penalty
-                        self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
-                        self.rewards_dict[i_agent] += self.stop_penalty
-                        agent.moving = False
-
-            # if we've just broken in this step, nothing else to do
-            if new_malfunction:
-                continue
-
-            # The train was broken before...
-            if agent.malfunction_data['malfunction'] > 0:
-
-                # Last step of malfunction --> Agent starts moving again after getting fixed
-                if agent.malfunction_data['malfunction'] < 2:
-                    agent.malfunction_data['malfunction'] -= 1
-                    self.agents[i_agent].moving = True
-
-                else:
-                    agent.malfunction_data['malfunction'] -= 1
-
-                    # Broken agents are stopped
-                    self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
-                    self.agents[i_agent].moving = False
-
-                    # Nothing left to do with broken agent
-                    continue
-
-            # Now perform a movement.
-            # If agent.moving, increment the position_fraction by the speed of the agent
-            # If the new position fraction is >= 1, reset to 0, and perform the stored
-            #   transition_action_on_cellexit if the cell is free.
-            if agent.moving:
-
-                agent.speed_data['position_fraction'] += agent.speed_data['speed']
-                if agent.speed_data['position_fraction'] >= 1.0:
-                    # Perform stored action to transition to the next cell as soon as cell is free
-                    # Notice that we've already check new_cell_valid and transition valid when we stored the action,
-                    # so we only have to check cell_free now!
-
-                    # cell and transition validity was checked when we stored transition_action_on_cellexit!
-                    cell_free, new_cell_valid, new_direction, new_position, transition_valid = self._check_action_on_agent(
-                        agent.speed_data['transition_action_on_cellexit'], agent)
-
-                    # N.B. validity of new_cell and transition should have been verified before the action was stored!
-                    assert new_cell_valid
-                    assert transition_valid
-                    if cell_free:
-                        agent.position = new_position
-                        agent.direction = new_direction
-                        agent.speed_data['position_fraction'] = 0.0
-
-            if np.equal(agent.position, agent.target).all():
-                self.dones[i_agent] = True
-                agent.moving = False
-            else:
-                self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
+            self._step_agent(i_agent, action_dict_)
 
         # Check for end of episode + add global reward to all rewards!
         if np.all([np.array_equal(agent2.position, agent2.target) for agent2 in self.agents]):
@@ -496,17 +382,144 @@ class RailEnv(Environment):
 
         return self._get_observations(), self.rewards_dict, self.dones, info_dict
 
-    def _check_action_on_agent(self, action, agent):
+    def _step_agent(self, i_agent, action_dict_: Dict[int, RailEnvActions]):
+        """
+        Performs a step on a single agent, applying step, start and stop penalties, in the following sub steps:
+        - malfunction
+        - action handling if at the beginning of cell
+        - movement
+        Parameters
+        ----------
+        i_agent : int
+        action_dict_ : Dict[int,RailEnvActions]
+
+        """
+        if self.dones[i_agent]:  # this agent has already completed...
+            return
+
+        agent = self.agents[i_agent]
+        agent.old_direction = agent.direction
+        agent.old_position = agent.position
+
+        # is the agent malfunctioning?
+        malfunction = self._agent_malfunction(i_agent)
+
+        # if agent is broken, actions are ignored and agent does not move,
+        # the agent is not penalized in this step!
+        if malfunction:
+            self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
+            return
+
+        # Is the agent at the beginning of the cell? Then, it can take an action.
+        if agent.speed_data['position_fraction'] == 0.0:
+            # No action has been supplied for this agent -> set DO_NOTHING as default
+            if i_agent not in action_dict_:
+                action = RailEnvActions.DO_NOTHING
+            else:
+                action = action_dict_[i_agent]
+
+            if action < 0 or action > len(RailEnvActions):
+                print('ERROR: illegal action=', action,
+                      'for agent with index=', i_agent,
+                      '"DO NOTHING" will be executed instead')
+                action = RailEnvActions.DO_NOTHING
+
+            if action == RailEnvActions.DO_NOTHING and agent.moving:
+                # Keep moving
+                action = RailEnvActions.MOVE_FORWARD
+
+            if action == RailEnvActions.STOP_MOVING and agent.moving:
+                # Only allow halting an agent on entering new cells.
+                agent.moving = False
+                self.rewards_dict[i_agent] += self.stop_penalty
+
+            if not agent.moving and not (
+                action == RailEnvActions.DO_NOTHING or action == RailEnvActions.STOP_MOVING):
+                # Allow agent to start with any forward or direction action
+                agent.moving = True
+                self.rewards_dict[i_agent] += self.start_penalty
+
+            # Store the action if action is moving
+            # If not moving, the action will be stored when the agent starts moving again.
+            if agent.moving:
+                _action_stored = False
+                _, new_cell_valid, new_direction, new_position, transition_valid = \
+                    self._check_action_on_agent(action, agent)
+
+                if all([new_cell_valid, transition_valid]):
+                    agent.speed_data['transition_action_on_cellexit'] = action
+                    _action_stored = True
+                else:
+                    # But, if the chosen invalid action was LEFT/RIGHT, and the agent is moving,
+                    # try to keep moving forward!
+                    if (action == RailEnvActions.MOVE_LEFT or action == RailEnvActions.MOVE_RIGHT):
+                        _, new_cell_valid, new_direction, new_position, transition_valid = \
+                            self._check_action_on_agent(RailEnvActions.MOVE_FORWARD, agent)
+
+                        if all([new_cell_valid, transition_valid]):
+                            agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.MOVE_FORWARD
+                            _action_stored = True
+
+                if not _action_stored:
+                    # If the agent cannot move due to an invalid transition, we set its state to not moving
+                    self.rewards_dict[i_agent] += self.invalid_action_penalty
+                    self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
+                    self.rewards_dict[i_agent] += self.stop_penalty
+                    agent.moving = False
+
+        # Now perform a movement.
+        # If agent.moving, increment the position_fraction by the speed of the agent
+        # If the new position fraction is >= 1, reset to 0, and perform the stored
+        #   transition_action_on_cellexit if the cell is free.
+        if agent.moving:
+            agent.speed_data['position_fraction'] += agent.speed_data['speed']
+            if agent.speed_data['position_fraction'] >= 1.0:
+                # Perform stored action to transition to the next cell as soon as cell is free
+                # Notice that we've already checked new_cell_valid and transition valid when we stored the action,
+                # so we only have to check cell_free now!
+
+                # cell and transition validity was checked when we stored transition_action_on_cellexit!
+                cell_free, new_cell_valid, new_direction, new_position, transition_valid = self._check_action_on_agent(
+                    agent.speed_data['transition_action_on_cellexit'], agent)
+
+                # N.B. validity of new_cell and transition should have been verified before the action was stored!
+                assert new_cell_valid
+                assert transition_valid
+                if cell_free:
+                    agent.position = new_position
+                    agent.direction = new_direction
+                    agent.speed_data['position_fraction'] = 0.0
+
+            # has the agent reached its target?
+            if np.equal(agent.position, agent.target).all():
+                self.dones[i_agent] = True
+                agent.moving = False
+            else:
+                self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
+
+    def _check_action_on_agent(self, action: RailEnvActions, agent: EnvAgent):
+        """
+
+        Parameters
+        ----------
+        action : RailEnvActions
+        agent : EnvAgent
+
+        Returns
+        -------
+        Tuple[bool, bool, Grid4TransitionsEnum, Tuple[int, int], bool]
+            (cell_free, new_cell_valid, new_direction, new_position, transition_valid), where
+            1) transition_valid: the transition allows new_direction in the current cell,
+            2) new_cell_valid: the new position is inside the grid and the new cell is not empty (case 0),
+            3) cell_free: no agent currently occupies the new cell
+
 
+        """
         # compute number of possible transitions in the current
         # cell used to check for invalid actions
         new_direction, transition_valid = self.check_action(agent, action)
         new_position = get_new_position(agent.position, new_direction)
 
-        # Is it a legal move?
-        # 1) transition allows the new_direction in the cell,
-        # 2) the new cell is not empty (case 0),
-        # 3) the cell is free, i.e., no agent is currently in that cell
         new_cell_valid = (
             np.array_equal(  # Check the new position is still in the grid
                 new_position,
@@ -522,11 +535,24 @@ class RailEnv(Environment):
 
         # Check the new position is not the same as any of the existing agent positions
         # (including itself, for simplicity, since it is moving)
-        cell_free = not np.any(
-            np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1))
+        cell_free = not np.any(np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1))
         return cell_free, new_cell_valid, new_direction, new_position, transition_valid
 
-    def check_action(self, agent, action):
+    def check_action(self, agent: EnvAgent, action: RailEnvActions):
+        """
+
+        Parameters
+        ----------
+        agent : EnvAgent
+        action : RailEnvActions
+
+        Returns
+        -------
+        Tuple[Grid4TransitionsEnum, Optional[bool]]
+
+
+
+        """
         transition_valid = None
         possible_transitions = self.rail.get_transitions(*agent.position, agent.direction)
         num_transitions = np.count_nonzero(possible_transitions)
@@ -544,26 +570,41 @@ class RailEnv(Environment):
 
         new_direction %= 4
 
-        if action == RailEnvActions.MOVE_FORWARD:
-            if num_transitions == 1:
-                # - dead-end, straight line or curved line;
-                # new_direction will be the only valid transition
-                # - take only available transition
-                new_direction = np.argmax(possible_transitions)
-                transition_valid = True
+        if action == RailEnvActions.MOVE_FORWARD and num_transitions == 1:
+            # - dead-end, straight line or curved line;
+            # new_direction will be the only valid transition
+            # - take only available transition
+            new_direction = np.argmax(possible_transitions)
+            transition_valid = True
         return new_direction, transition_valid
 
-    def get_valid_move_actions(self, agent: EnvAgent) -> Set[RailEnvNextAction]:
+    @staticmethod
+    def get_valid_move_actions_(agent_direction: Grid4TransitionsEnum,
+                                agent_position: Tuple[int, int],
+                                rail: GridTransitionMap) -> Set[RailEnvNextAction]:
+        """
+        Get the valid move actions (forward, left, right) for an agent.
+
+        Parameters
+        ----------
+        agent_direction : Grid4TransitionsEnum
+        agent_position: Tuple[int,int]
+        rail : GridTransitionMap
+
+
+        Returns
+        -------
+        Set of `RailEnvNextAction` (tuples of (action,position,direction))
+            Possible move actions (forward,left,right) and the next position/direction they lead to.
+            It is not checked that the next cell is free.
+        """
         valid_actions: Set[RailEnvNextAction] = OrderedSet()
-        agent_position = agent.position
-        agent_direction = agent.direction
-        possible_transitions = self.rail.get_transitions(*agent_position, agent_direction)
+        possible_transitions = rail.get_transitions(*agent_position, agent_direction)
         num_transitions = np.count_nonzero(possible_transitions)
-
         # Start from the current orientation, and see which transitions are available;
         # organize them as [left, forward, right], relative to the current orientation
         # If only one transition is possible, the forward branch is aligned with it.
-        if self.rail.is_dead_end(agent_position):
+        if rail.is_dead_end(agent_position):
             action = RailEnvActions.MOVE_FORWARD
             exit_direction = (agent_direction + 2) % 4
             if possible_transitions[exit_direction]:
diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py
index 70be7a0e..fde9df58 100644
--- a/tests/test_flatland_malfunction.py
+++ b/tests/test_flatland_malfunction.py
@@ -3,6 +3,7 @@ import random
 import numpy as np
 
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
+from flatland.core.grid.grid4_utils import get_new_position
 from flatland.envs.agent_utils import EnvAgent
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.rail_env import RailEnv, RailEnvActions
@@ -46,7 +47,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv):
             min_distances = []
             for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
                 if possible_transitions[direction]:
-                    new_position = self._new_position(agent.position, direction)
+                    new_position = get_new_position(agent.position, direction)
                     min_distances.append(self.distance_map[handle, new_position[0], new_position[1], direction])
                 else:
                     min_distances.append(np.inf)
@@ -150,8 +151,7 @@ def test_malfunction_process_statistically():
         env.step(action_dict)
 
     # check that generation of malfunctions works as expected
-    # results are different in py36 and py37, therefore no exact test on nb_malfunction
-    assert nb_malfunction == 149, "nb_malfunction={}".format(nb_malfunction)
+    assert nb_malfunction == 156, "nb_malfunction={}".format(nb_malfunction)
 
 
 def test_initial_malfunction(rendering=True):
@@ -207,6 +207,8 @@ def test_initial_malfunction(rendering=True):
             action=RailEnvActions.MOVE_FORWARD,
             malfunction=2
         ),
+        # malfunction stops in the next step and we're still at the beginning of the cell
+        # --> if we take action MOVE_FORWARD, agent should restart and move to the next cell
         Replay(
             position=(28, 5),
             direction=Grid4TransitionsEnum.EAST,
@@ -252,3 +254,225 @@ def test_initial_malfunction(rendering=True):
 
         if rendering:
             renderer.render_env(show=True, show_observations=True)
+
+
+def test_initial_malfunction_stop_moving(rendering=True):
+    random.seed(0)
+    np.random.seed(0)
+
+    stochastic_data = {'prop_malfunction': 1.,  # Percentage of defective agents
+                       'malfunction_rate': 70,  # Rate of malfunction occurrence
+                       'min_duration': 2,  # Minimal duration of malfunction
+                       'max_duration': 5  # Max duration of malfunction
+                       }
+
+    speed_ration_map = {1.: 1.,  # Fast passenger train
+                        1. / 2.: 0.,  # Fast freight train
+                        1. / 3.: 0.,  # Slow commuter train
+                        1. / 4.: 0.}  # Slow freight train
+
+    env = RailEnv(width=25,
+                  height=30,
+                  rail_generator=sparse_rail_generator(num_cities=5,
+                                                       # Number of cities in map (where train stations are)
+                                                       num_intersections=4,
+                                                       # Number of intersections (no start / target)
+                                                       num_trainstations=25,  # Number of possible start/targets on map
+                                                       min_node_dist=6,  # Minimal distance of nodes
+                                                       node_radius=3,  # Proximity of stations to city center
+                                                       num_neighb=3,
+                                                       # Number of connections to other cities/intersections
+                                                       seed=215545,  # Random seed
+                                                       grid_mode=True,
+                                                       enhance_intersection=False
+                                                       ),
+                  schedule_generator=sparse_schedule_generator(speed_ration_map),
+                  number_of_agents=1,
+                  stochastic_data=stochastic_data,  # Malfunction data generator
+                  )
+
+    if rendering:
+        renderer = RenderTool(env)
+        renderer.render_env(show=True, frames=False, show_observations=False)
+    _action = dict()
+
+    replay_steps = [
+        Replay(
+            position=(28, 5),
+            direction=Grid4TransitionsEnum.EAST,
+            action=RailEnvActions.DO_NOTHING,
+            malfunction=3
+        ),
+        Replay(
+            position=(28, 5),
+            direction=Grid4TransitionsEnum.EAST,
+            action=RailEnvActions.DO_NOTHING,
+            malfunction=2
+        ),
+        # malfunction stops in the next step and we're still at the beginning of the cell
+        # --> if we take action STOP_MOVING, agent should stop and stay stopped once fixed
+        #
+        Replay(
+            position=(28, 5),
+            direction=Grid4TransitionsEnum.EAST,
+            action=RailEnvActions.STOP_MOVING,
+            malfunction=1
+        ),
+        # we have stopped and do nothing --> should stand still
+        Replay(
+            position=(28, 5),
+            direction=Grid4TransitionsEnum.EAST,
+            action=RailEnvActions.DO_NOTHING,
+            malfunction=0
+        ),
+        # we start to move forward --> should go to next cell now
+        Replay(
+            position=(28, 5),
+            direction=Grid4TransitionsEnum.EAST,
+            action=RailEnvActions.MOVE_FORWARD,
+            malfunction=0
+        ),
+        Replay(
+            position=(28, 4),
+            direction=Grid4TransitionsEnum.WEST,
+            action=RailEnvActions.MOVE_FORWARD,
+            malfunction=0
+        )
+    ]
+
+    info_dict = {
+        'action_required': [True]
+    }
+
+    for i, replay in enumerate(replay_steps):
+
+        def _assert(actual, expected, msg):
+            assert actual == expected, "[{}] {}:  actual={}, expected={}".format(i, msg, actual, expected)
+
+        agent: EnvAgent = env.agents[0]
+
+        _assert(agent.position, replay.position, 'position')
+        _assert(agent.direction, replay.direction, 'direction')
+        _assert(agent.malfunction_data['malfunction'], replay.malfunction, 'malfunction')
+
+        if replay.action is not None:
+            assert info_dict['action_required'][0] == True, "[{}] expecting action_required={}".format(i, True)
+            _, _, _, info_dict = env.step({0: replay.action})
+
+        else:
+            assert info_dict['action_required'][0] == False, "[{}] expecting action_required={}".format(i, False)
+            _, _, _, info_dict = env.step({})
+
+        if rendering:
+            renderer.render_env(show=True, show_observations=True)
+
+
+def test_initial_malfunction_do_nothing(rendering=True):
+    random.seed(0)
+    np.random.seed(0)
+
+    stochastic_data = {'prop_malfunction': 1.,  # Percentage of defective agents
+                       'malfunction_rate': 70,  # Rate of malfunction occurrence
+                       'min_duration': 2,  # Minimal duration of malfunction
+                       'max_duration': 5  # Max duration of malfunction
+                       }
+
+    speed_ration_map = {1.: 1.,  # Fast passenger train
+                        1. / 2.: 0.,  # Fast freight train
+                        1. / 3.: 0.,  # Slow commuter train
+                        1. / 4.: 0.}  # Slow freight train
+
+    env = RailEnv(width=25,
+                  height=30,
+                  rail_generator=sparse_rail_generator(num_cities=5,
+                                                       # Number of cities in map (where train stations are)
+                                                       num_intersections=4,
+                                                       # Number of intersections (no start / target)
+                                                       num_trainstations=25,  # Number of possible start/targets on map
+                                                       min_node_dist=6,  # Minimal distance of nodes
+                                                       node_radius=3,  # Proximity of stations to city center
+                                                       num_neighb=3,
+                                                       # Number of connections to other cities/intersections
+                                                       seed=215545,  # Random seed
+                                                       grid_mode=True,
+                                                       enhance_intersection=False
+                                                       ),
+                  schedule_generator=sparse_schedule_generator(speed_ration_map),
+                  number_of_agents=1,
+                  stochastic_data=stochastic_data,  # Malfunction data generator
+                  )
+
+    if rendering:
+        renderer = RenderTool(env)
+        renderer.render_env(show=True, frames=False, show_observations=False)
+    _action = dict()
+
+    replay_steps = [
+        Replay(
+            position=(28, 5),
+            direction=Grid4TransitionsEnum.EAST,
+            action=RailEnvActions.DO_NOTHING,
+            malfunction=3
+        ),
+        Replay(
+            position=(28, 5),
+            direction=Grid4TransitionsEnum.EAST,
+            action=RailEnvActions.DO_NOTHING,
+            malfunction=2
+        ),
+        # malfunction stops in the next step and we're still at the beginning of the cell
+        # --> if we take action DO_NOTHING, agent should restart without moving
+        #
+        Replay(
+            position=(28, 5),
+            direction=Grid4TransitionsEnum.EAST,
+            action=RailEnvActions.DO_NOTHING,
+            malfunction=1
+        ),
+        # we haven't started moving yet --> stay here
+        Replay(
+            position=(28, 5),
+            direction=Grid4TransitionsEnum.EAST,
+            action=RailEnvActions.DO_NOTHING,
+            malfunction=0
+        ),
+        # we start to move forward --> should go to next cell now
+        Replay(
+            position=(28, 5),
+            direction=Grid4TransitionsEnum.EAST,
+            action=RailEnvActions.MOVE_FORWARD,
+            malfunction=0
+        ),
+        Replay(
+            position=(28, 4),
+            direction=Grid4TransitionsEnum.WEST,
+            action=RailEnvActions.MOVE_FORWARD,
+            malfunction=0
+        )
+    ]
+
+    info_dict = {
+        'action_required': [True]
+    }
+
+    for i, replay in enumerate(replay_steps):
+
+        def _assert(actual, expected, msg):
+            assert actual == expected, "[{}] {}:  actual={}, expected={}".format(i, msg, actual, expected)
+
+        agent: EnvAgent = env.agents[0]
+
+        _assert(agent.position, replay.position, 'position')
+        _assert(agent.direction, replay.direction, 'direction')
+        _assert(agent.malfunction_data['malfunction'], replay.malfunction, 'malfunction')
+
+        if replay.action is not None:
+            assert info_dict['action_required'][0] == True, "[{}] expecting action_required={}".format(i, True)
+            _, _, _, info_dict = env.step({0: replay.action})
+
+        else:
+            assert info_dict['action_required'][0] == False, "[{}] expecting action_required={}".format(i, False)
+            _, _, _, info_dict = env.step({})
+
+        if rendering:
+            renderer.render_env(show=True, show_observations=True)
diff --git a/tests/test_multi_speed.py b/tests/test_multi_speed.py
index 4b14a5cd..1cf0c325 100644
--- a/tests/test_multi_speed.py
+++ b/tests/test_multi_speed.py
@@ -580,8 +580,9 @@ def test_multispeed_actions_malfunction_no_blocking(rendering=True):
         _assert(agent.position, replay.position, 'position')
         _assert(agent.direction, replay.direction, 'direction')
 
-        if replay.malfunction:
-            agent.malfunction_data['malfunction'] = 2
+        if replay.malfunction > 0:
+            agent.malfunction_data['malfunction'] = replay.malfunction
+            agent.malfunction_data['moving_before_malfunction'] = agent.moving
 
         if replay.action is not None:
             assert info_dict['action_required'][0] == True, "[{}] expecting action_required={}".format(i, True)
-- 
GitLab