diff --git a/examples/custom_observation_example_03_ObservePredictions.py b/examples/custom_observation_example_03_ObservePredictions.py
index 855d1f5dffbef29da26ca1dead933af846b863bd..f75cb74537f03dc6fb1aecbadff37a183432e55a 100644
--- a/examples/custom_observation_example_03_ObservePredictions.py
+++ b/examples/custom_observation_example_03_ObservePredictions.py
@@ -54,11 +54,9 @@ class ObservePredictions(ObservationBuilder):
                 pos_list.append(self.predictions[a][t][1:3])
             # We transform (x,y) coordinates to a single integer number for simpler comparison
             self.predicted_pos.update({t: coordinate_to_position(self.env.width, pos_list)})
-        observations = {}
 
-        # Collect all the different observation for all the agents
-        for h in handles:
-            observations[h] = self.get(h)
+        observations = super().get_many(handles)
+
         return observations
 
     def get(self, handle: int = 0) -> np.ndarray:
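
Both `get_many` overrides touched by this patch (here and in `LocalObsForRailEnv` below) now delegate to the base class. A minimal sketch of the fallback they rely on, assuming `ObservationBuilder.get_many` simply maps `get` over the handles, which is exactly what the removed loops did; the class below is an illustrative stand-in, not the flatland source:

```python
from typing import Dict, List, Optional


class ObservationBuilderSketch:
    """Illustrative stand-in for flatland.core.env_observation_builder.ObservationBuilder."""

    def get(self, handle: int = 0):
        raise NotImplementedError()

    def get_many(self, handles: Optional[List[int]] = None) -> Dict[int, object]:
        # Default behaviour: one observation per handle, matching the
        # per-handle loops this patch removes from the subclasses.
        if handles is None:
            handles = []
        return {h: self.get(h) for h in handles}
```
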
diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py
index 3cc21966162dd28d183493a97cd6072a34abb738..2302fff9cdeeaf162b7ab38b1e67f5052c76b3ef 100644
--- a/flatland/core/env_observation_builder.py
+++ b/flatland/core/env_observation_builder.py
@@ -24,7 +24,7 @@ class ObservationBuilder:
         self.env = None
 
     def set_env(self, env: Environment):
-        self.env = env
+        self.env: Environment = env
 
     def reset(self):
         """
diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py
index f659ec8436a941606b6d649e24d2481e5be9b66d..d8c05c2020a9524ae3e6eab232a3e320bc699187 100644
--- a/flatland/envs/agent_utils.py
+++ b/flatland/envs/agent_utils.py
@@ -1,5 +1,6 @@
+from enum import IntEnum
 from itertools import starmap
-from typing import Tuple
+from typing import Tuple, Optional
 
 import numpy as np
 from attr import attrs, attrib, Factory
@@ -7,6 +8,13 @@ from attr import attrs, attrib, Factory
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
 
 
+class RailAgentStatus(IntEnum):
+    READY_TO_DEPART = 0  # not in grid yet (position is None) -> prediction as if it were at initial position
+    ACTIVE = 1  # in grid (position is not None), not done -> prediction is remaining path
+    DONE = 2  # in grid (position is not None), but done -> prediction is stay at target forever
+    DONE_REMOVED = 3  # removed from grid (position is None) -> prediction is None
+
+
 @attrs
 class EnvAgentStatic(object):
     """ EnvAgentStatic - Stores initial position, direction and target.
@@ -14,7 +22,7 @@ class EnvAgentStatic(object):
         rather than where it is at the moment.
         The target should also be stored here.
     """
-    position = attrib(type=Tuple[int, int])
+    initial_position = attrib(type=Tuple[int, int])
     direction = attrib(type=Grid4TransitionsEnum)
     target = attrib(type=Tuple[int, int])
     moving = attrib(default=False, type=bool)
@@ -33,6 +41,9 @@ class EnvAgentStatic(object):
             lambda: dict({'malfunction': 0, 'malfunction_rate': 0, 'next_malfunction': 0, 'nr_malfunctions': 0,
                           'moving_before_malfunction': False})))
 
+    status = attrib(default=RailAgentStatus.READY_TO_DEPART, type=RailAgentStatus)
+    position = attrib(default=None, type=Optional[Tuple[int, int]])
+
     @classmethod
     def from_lists(cls, positions, directions, targets, speeds=None, malfunction_rates=None):
         """ Create a list of EnvAgentStatics from lists of positions, directions and targets
@@ -65,7 +76,7 @@ class EnvAgentStatic(object):
 
         # I can't find an expression which works on both tuples, lists and ndarrays
         # which converts them all to a list of native python ints.
-        lPos = self.position
+        lPos = self.initial_position
         if type(lPos) is np.ndarray:
             lPos = lPos.tolist()
 
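
The new `RailAgentStatus` drives a status-to-position rule that recurs throughout the rest of this patch (tree and global observations, predictors, shortest paths). A hypothetical helper restating that rule; `effective_position` is not part of the patch, it just summarizes the repeated `if/elif` chains:

```python
from typing import Optional, Tuple

from flatland.envs.agent_utils import EnvAgent, RailAgentStatus


def effective_position(agent: EnvAgent) -> Optional[Tuple[int, int]]:
    """Where downstream code treats the agent as standing, given its status."""
    if agent.status == RailAgentStatus.READY_TO_DEPART:
        return agent.initial_position  # not yet placed on the grid
    if agent.status == RailAgentStatus.ACTIVE:
        return agent.position  # somewhere on the grid
    if agent.status == RailAgentStatus.DONE:
        return agent.target  # finished, still shown at its target
    return None  # DONE_REMOVED: no position at all
```
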
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index c13c4cb7c6e3926a522adbd0055e53132079a2c1..8adf94d3d5b01946aa42659c19b6060f0d593868 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -11,11 +11,20 @@ from flatland.core.env_observation_builder import ObservationBuilder
 from flatland.core.env_prediction_builder import PredictionBuilder
 from flatland.core.grid.grid4_utils import get_new_position
 from flatland.core.grid.grid_utils import coordinate_to_position
+from flatland.envs.agent_utils import RailAgentStatus, EnvAgent
 from flatland.utils.ordered_set import OrderedSet
 
 
 class TreeObsForRailEnv(ObservationBuilder):
+    """
+    TreeObsForRailEnv object.
+
+    This object returns observation vectors for agents in the RailEnv environment.
+    The information is local to each agent and exploits the graph structure of the rail
+    network to simplify the representation of the state of the environment for each agent.
 
+    For details about the features in the tree observation see the get() function.
+    """
     Node = collections.namedtuple('Node', 'dist_own_target_encountered '
                                           'dist_other_target_encountered '
                                           'dist_other_agent_encountered '
@@ -27,19 +36,10 @@ class TreeObsForRailEnv(ObservationBuilder):
                                           'num_agents_opposite_direction '
                                           'num_agents_malfunctioning '
                                           'speed_min_fractional '
+                                          'num_agents_ready_to_depart '
                                           'childs')
 
-    tree_explorted_actions_char = ['L', 'F', 'R', 'B']
-
-    """
-    TreeObsForRailEnv object.
-
-    This object returns observation vectors for agents in the RailEnv environment.
-    The information is local to each agent and exploits the graph structure of the rail
-    network to simplify the representation of the state of the environment for each agent.
-
-    For details about the features in the tree observation see the get() function.
-    """
+    tree_explored_actions_char = ['L', 'F', 'R', 'B']
 
     def __init__(self, max_depth: int, predictor: PredictionBuilder = None):
         super().__init__()
@@ -67,19 +67,21 @@ class TreeObsForRailEnv(ObservationBuilder):
             self.predicted_dir = {}
             self.predictions = self.predictor.get()
             if self.predictions:
-
+                # TODO hacky hacky: `range(len(self.predictions[0]))` does not seem safe!!
                 for t in range(len(self.predictions[0])):
                     pos_list = []
                     dir_list = []
                     for a in handles:
+                        if self.predictions[a] is None:
+                            continue
                         pos_list.append(self.predictions[a][t][1:3])
                         dir_list.append(self.predictions[a][t][3])
                     self.predicted_pos.update({t: coordinate_to_position(self.env.width, pos_list)})
                     self.predicted_dir.update({t: dir_list})
                 self.max_prediction_depth = len(self.predicted_pos)
-        observations = {}
-        for h in handles:
-            observations[h] = self.get(h)
+
+        observations = super().get_many(handles)
+
         return observations
 
     def get(self, handle: int = 0) -> Node:
@@ -150,6 +152,8 @@ class TreeObsForRailEnv(ObservationBuilder):
             1 if no agent is observed
 
             min_fractional speed otherwise
+        #12:
+            number of agents ready to depart but not yet active
 
         Missing/padding nodes are filled in with -inf (truncated).
         Missing values in present node are filled in with +inf (truncated).
@@ -160,28 +164,41 @@ class TreeObsForRailEnv(ObservationBuilder):
         """
 
         # Update local lookup table for all agents' positions
-        self.location_has_agent = dict()
-        self.location_has_agent_direction = dict()
+        # ignore other agents not in the grid (only status active and done)
+        self.location_has_agent = {tuple(agent.position): 1 for agent in self.env.agents if
+                                   agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE]}
+        self.location_has_agent_ready_to_depart = {}
         for agent in self.env.agents:
-            if tuple(agent.position) in self.location_has_agent:
-                self.location_has_agent[tuple(agent.position)] = self.location_has_agent[tuple(agent.position)] + 1
-            else:
-                self.location_has_agent[tuple(agent.position)] = 1
-
-            if (agent.position, agent.direction) in self.location_has_agent_direction:
-                self.location_has_agent_direction[(agent.position, agent.direction)] = \
-                self.location_has_agent_direction[(agent.position, agent.direction)] + 1
-            else:
-                self.location_has_agent_direction[(agent.position, agent.direction)] = 1
-
-        self.location_has_agent_speed = {tuple(agent.position): agent.speed_data['speed'] for agent in self.env.agents}
-        self.location_has_agent_malfunction = {tuple(agent.position): agent.malfunction_data['malfunction'] for agent in
-                                               self.env.agents}
+            if agent.status == RailAgentStatus.READY_TO_DEPART:
+                self.location_has_agent_ready_to_depart[tuple(agent.initial_position)] = \
+                    self.location_has_agent_ready_to_depart.get(tuple(agent.initial_position), 0) + 1
+        self.location_has_agent_direction = {
+            tuple(agent.position): agent.direction
+            for agent in self.env.agents if agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE]
+        }
+        self.location_has_agent_speed = {
+            tuple(agent.position): agent.speed_data['speed']
+            for agent in self.env.agents if agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE]
+        }
+        self.location_has_agent_malfunction = {
+            tuple(agent.position): agent.malfunction_data['malfunction']
+            for agent in self.env.agents if agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE]
+        }
 
         if handle > len(self.env.agents):
             print("ERROR: obs _get - handle ", handle, " len(agents)", len(self.env.agents))
         agent = self.env.agents[handle]  # TODO: handle being treated as index
-        possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
+
+        if agent.status == RailAgentStatus.READY_TO_DEPART:
+            _agent_initial_position = agent.initial_position
+        elif agent.status == RailAgentStatus.ACTIVE:
+            _agent_initial_position = agent.position
+        elif agent.status == RailAgentStatus.DONE:
+            _agent_initial_position = agent.target
+        else:
+            return None
+
+        possible_transitions = self.env.rail.get_transitions(*_agent_initial_position, agent.direction)
         num_transitions = np.count_nonzero(possible_transitions)
 
         # Here information about the agent itself is stored
@@ -190,11 +207,13 @@ class TreeObsForRailEnv(ObservationBuilder):
         root_node_observation = TreeObsForRailEnv.Node(dist_own_target_encountered=0, dist_other_target_encountered=0,
                                                        dist_other_agent_encountered=0, dist_potential_conflict=0,
                                                        dist_unusable_switch=0, dist_to_next_branch=0,
-                                                       dist_min_to_target=distance_map[(handle, *agent.position,
-                                                                                        agent.direction)],
+                                                       dist_min_to_target=distance_map[
+                                                           (handle, *_agent_initial_position,
+                                                            agent.direction)],
                                                        num_agents_same_direction=0, num_agents_opposite_direction=0,
                                                        num_agents_malfunctioning=agent.malfunction_data['malfunction'],
                                                        speed_min_fractional=agent.speed_data['speed'],
+                                                       num_agents_ready_to_depart=0,
                                                        childs={})
 
         visited = OrderedSet()
@@ -210,16 +229,16 @@ class TreeObsForRailEnv(ObservationBuilder):
         for i, branch_direction in enumerate([(orientation + i) % 4 for i in range(-1, 3)]):
 
             if possible_transitions[branch_direction]:
-                new_cell = get_new_position(agent.position, branch_direction)
+                new_cell = get_new_position(_agent_initial_position, branch_direction)
 
                 branch_observation, branch_visited = \
                     self._explore_branch(handle, new_cell, branch_direction, 1, 1)
-                root_node_observation.childs[self.tree_explorted_actions_char[i]] = branch_observation
+                root_node_observation.childs[self.tree_explored_actions_char[i]] = branch_observation
 
                 visited |= branch_visited
             else:
                 # add cells filled with infinity if no transition is possible
-                root_node_observation.childs[self.tree_explorted_actions_char[i]] = -np.inf
+                root_node_observation.childs[self.tree_explored_actions_char[i]] = -np.inf
         self.env.dev_obs_dict[handle] = visited
 
         return root_node_observation
@@ -257,6 +276,7 @@ class TreeObsForRailEnv(ObservationBuilder):
         malfunctioning_agent = 0
         min_fractional_speed = 1.
         num_steps = 1
+        other_agent_ready_to_depart_encountered = 0
         while exploring:
             # #############################
             # #############################
@@ -270,9 +290,11 @@ class TreeObsForRailEnv(ObservationBuilder):
                 if self.location_has_agent_malfunction[position] > malfunctioning_agent:
                     malfunctioning_agent = self.location_has_agent_malfunction[position]
 
-                if (position, direction) in self.location_has_agent_direction:
+                other_agent_ready_to_depart_encountered += self.location_has_agent_ready_to_depart.get(position, 0)
+
+                if self.location_has_agent_direction[position] == direction:
                     # Cummulate the number of agents on branch with same direction
-                    other_agent_same_direction += self.location_has_agent_direction[(position, direction)]
+                    other_agent_same_direction += 1
 
                     # Check fractional speed of agents
                     current_fractional_speed = self.location_has_agent_speed[position]
@@ -281,9 +303,9 @@ class TreeObsForRailEnv(ObservationBuilder):
 
                     # Other direction agents
                     # TODO: Test that this behavior is as expected
-                    other_agent_opposite_direction += self.location_has_agent[position] - \
-                                                      self.location_has_agent_direction[
-                                                          (position, direction)]
+                    # the direction map is now keyed by cell, so subtract the single matching agent
+                    other_agent_opposite_direction += \
+                        self.location_has_agent[position] - 1
 
                 else:
                     # If no agent in the same direction was found all agents in that position are other direction
@@ -314,7 +336,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                                 self._reverse_dir(
                                     self.predicted_dir[predicted_time][ca])] == 1 and tot_dist < potential_conflict:
                                 potential_conflict = tot_dist
-                            if self.env.dones[ca] and tot_dist < potential_conflict:
+                            if self.env.agents[ca].status == RailAgentStatus.DONE and tot_dist < potential_conflict:
                                 potential_conflict = tot_dist
 
                     # Look for conflicting paths at distance num_step-1
@@ -325,7 +347,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                                 and cell_transitions[self._reverse_dir(self.predicted_dir[pre_step][ca])] == 1 \
                                 and tot_dist < potential_conflict:  # noqa: E125
                                 potential_conflict = tot_dist
-                            if self.env.dones[ca] and tot_dist < potential_conflict:
+                            if self.env.agents[ca].status == RailAgentStatus.DONE and tot_dist < potential_conflict:
                                 potential_conflict = tot_dist
 
                     # Look for conflicting paths at distance num_step+1
@@ -336,7 +358,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                                 self.predicted_dir[post_step][ca])] == 1 \
                                 and tot_dist < potential_conflict:  # noqa: E125
                                 potential_conflict = tot_dist
-                            if self.env.dones[ca] and tot_dist < potential_conflict:
+                            if self.env.agents[ca].status == RailAgentStatus.DONE and tot_dist < potential_conflict:
                                 potential_conflict = tot_dist
 
             if position in self.location_has_target and position != agent.target:
@@ -424,6 +446,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                                       num_agents_opposite_direction=other_agent_opposite_direction,
                                       num_agents_malfunctioning=malfunctioning_agent,
                                       speed_min_fractional=min_fractional_speed,
+                                      num_agents_ready_to_depart=other_agent_ready_to_depart_encountered,
                                       childs={})
 
         # #############################
@@ -443,7 +466,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                                                                           (branch_direction + 2) % 4,
                                                                           tot_dist + 1,
                                                                           depth + 1)
-                node.childs[self.tree_explorted_actions_char[i]] = branch_observation
+                node.childs[self.tree_explored_actions_char[i]] = branch_observation
                 if len(branch_visited) != 0:
                     visited |= branch_visited
             elif last_is_switch and possible_transitions[branch_direction]:
@@ -453,12 +476,12 @@ class TreeObsForRailEnv(ObservationBuilder):
                                                                           branch_direction,
                                                                           tot_dist + 1,
                                                                           depth + 1)
-                node.childs[self.tree_explorted_actions_char[i]] = branch_observation
+                node.childs[self.tree_explored_actions_char[i]] = branch_observation
                 if len(branch_visited) != 0:
                     visited |= branch_visited
             else:
                 # no exploring possible, add just cells with infinity
-                node.childs[self.tree_explorted_actions_char[i]] = -np.inf
+                node.childs[self.tree_explored_actions_char[i]] = -np.inf
 
         if depth == self.max_depth:
             node.childs.clear()
@@ -469,7 +492,7 @@ class TreeObsForRailEnv(ObservationBuilder):
         Utility function to print tree observations returned by this object.
         """
         self.print_node_features(tree, "root", "")
-        for direction in self.tree_explorted_actions_char:
+        for direction in self.tree_explored_actions_char:
             self.print_subtree(tree.childs[direction], direction, "\t")
 
     @staticmethod
@@ -478,7 +501,8 @@ class TreeObsForRailEnv(ObservationBuilder):
               node.dist_other_target_encountered, ", ", node.dist_other_agent_encountered, ", ",
               node.dist_potential_conflict, ", ", node.dist_unusable_switch, ", ", node.dist_to_next_branch, ", ",
               node.dist_min_to_target, ", ", node.num_agents_same_direction, ", ", node.num_agents_opposite_direction,
-              ", ", node.num_agents_malfunctioning, ", ", node.speed_min_fractional)
+              ", ", node.num_agents_malfunctioning, ", ", node.speed_min_fractional, ", ",
+              node.num_agents_ready_to_depart)
 
     def print_subtree(self, node, label, indent):
         if node == -np.inf or not node:
@@ -490,7 +514,7 @@ class TreeObsForRailEnv(ObservationBuilder):
         if not node.childs:
             return
 
-        for direction in self.tree_explorted_actions_char:
+        for direction in self.tree_explored_actions_char:
             self.print_subtree(node.childs[direction], direction, indent + "\t")
 
     def set_env(self, env: Environment):
@@ -510,15 +534,15 @@ class GlobalObsForRailEnv(ObservationBuilder):
         - transition map array with dimensions (env.height, env.width, 16),\
           assuming 16 bits encoding of transitions.
 
-        - A 3D array (map_height, map_width, 5) with
+        - obs_agents_state: A 3D array (map_height, map_width, 5) with
             - first channel containing the agents position and direction
-            - second channel containing the other agents positions and diretion
+            - second channel containing the other agents positions and direction
             - third channel containing agent/other agent malfunctions
             - fourth channel containing agent/other agent fractional speeds
-            ' fifth channel containing number of agents in cell (only larger then one at start position)
+            - fifth channel containing number of other agents ready to depart
 
-        - Two 2D arrays (map_height, map_width, 2) containing respectively the position of the given agent\
-         target and the positions of the other agents targets.
+        - obs_targets: Two 2D arrays (map_height, map_width, 2) containing respectively the position of the given agent\
+         target and the positions of the other agents targets (flag only, no counter!).
     """
 
     def __init__(self):
@@ -537,20 +561,40 @@
 
     def get(self, handle: int = 0) -> (np.ndarray, np.ndarray, np.ndarray):
 
+        agent = self.env.agents[handle]
+        if agent.status == RailAgentStatus.READY_TO_DEPART:
+            _agent_initial_position = agent.initial_position
+        elif agent.status == RailAgentStatus.ACTIVE:
+            _agent_initial_position = agent.position
+        elif agent.status == RailAgentStatus.DONE:
+            _agent_initial_position = agent.target
+        else:
+            return None
+
         obs_targets = np.zeros((self.env.height, self.env.width, 2))
         obs_agents_state = np.zeros((self.env.height, self.env.width, 5)) - 1
-        agent = self.env.agents[handle]
-        obs_agents_state[agent.position][0] = agent.direction
+
+        obs_agents_state[_agent_initial_position][0] = agent.direction
         obs_targets[agent.target][0] = 1
 
         for i in range(len(self.env.agents)):
-            other_agent = self.env.agents[i]
-            obs_agents_state[other_agent.position][4] += 1
-            if i != handle:
+            other_agent: EnvAgent = self.env.agents[i]
+
+            # ignore other agents not in the grid any more
+            if other_agent.status == RailAgentStatus.DONE_REMOVED:
+                continue
+
+            obs_targets[other_agent.target][1] = 1
+
+            # second to fourth channel only if a different agent and on the grid
+            if i != handle and other_agent.position is not None:
                 obs_agents_state[other_agent.position][1] = other_agent.direction
-                obs_targets[other_agent.target][1] = 1
-            obs_agents_state[other_agent.position][2] = other_agent.malfunction_data['malfunction']
-            obs_agents_state[other_agent.position][3] = other_agent.speed_data['speed']
+                obs_agents_state[other_agent.position][2] = other_agent.malfunction_data['malfunction']
+                obs_agents_state[other_agent.position][3] = other_agent.speed_data['speed']
+
+            # fifth channel (as documented above): other agents ready to depart, counted at their initial position
+            if other_agent.status == RailAgentStatus.READY_TO_DEPART:
+                obs_agents_state[other_agent.initial_position][4] += 1
 
         return self.rail_obs, obs_agents_state, obs_targets
 
@@ -640,18 +680,14 @@ class LocalObsForRailEnv(ObservationBuilder):
         direction = np.identity(4)[agent.direction]
         return local_rail_obs, obs_map_state, obs_other_agents_state, direction
 
-    def get_many(self, handles: Optional[List[int]] = None) -> Dict[int, Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]]:
+    def get_many(self, handles: Optional[List[int]] = None) -> Dict[
+        int, Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]]:
         """
         Called whenever an observation has to be computed for the `env` environment, for each agent with handle
         in the `handles` list.
         """
 
-        observations = {}
-        if handles is None:
-            handles = []
-        for h in handles:
-            observations[h] = self.get(h)
-        return observations
+        return super().get_many(handles)
 
     def field_of_view(self, position, direction, state=None):
         # Compute the local field of view for an agent in the environment
diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py
index 76095a2a2e1d9532951600118c6a777612641101..29b6947c28fa054cd1f4c44a204c740e4f536181 100644
--- a/flatland/envs/predictions.py
+++ b/flatland/envs/predictions.py
@@ -5,6 +5,7 @@ Collection of environment-specific PredictionBuilder.
 import numpy as np
 
 from flatland.core.env_prediction_builder import PredictionBuilder
+from flatland.envs.agent_utils import RailAgentStatus
 from flatland.envs.distance_map import DistanceMap
 from flatland.envs.rail_env import RailEnvActions
 from flatland.envs.rail_env_shortest_paths import get_shortest_paths
@@ -47,6 +48,9 @@ class DummyPredictorForRailEnv(PredictionBuilder):
         prediction_dict = {}
 
         for agent in agents:
+            if agent.status != RailAgentStatus.ACTIVE:
+                # TODO make this generic
+                continue
             action_priorities = [RailEnvActions.MOVE_FORWARD, RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT]
             _agent_initial_position = agent.position
             _agent_initial_direction = agent.direction
@@ -122,7 +126,17 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
 
         prediction_dict = {}
         for agent in agents:
-            _agent_initial_position = agent.position
+
+            if agent.status == RailAgentStatus.READY_TO_DEPART:
+                _agent_initial_position = agent.initial_position
+            elif agent.status == RailAgentStatus.ACTIVE:
+                _agent_initial_position = agent.position
+            elif agent.status == RailAgentStatus.DONE:
+                _agent_initial_position = agent.target
+            else:
+                prediction_dict[agent.handle] = None
+                continue
+
             _agent_initial_direction = agent.direction
             agent_speed = agent.speed_data["speed"]
             times_per_cell = int(np.reciprocal(agent_speed))
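
The predictor advances one cell every `1/speed` time steps; since flatland speeds are reciprocals of integers, `int(np.reciprocal(agent_speed))` recovers the whole number of steps spent per cell. A quick sanity check of that arithmetic:

```python
import numpy as np

for speed in (1.0, 1.0 / 2.0, 1.0 / 3.0, 1.0 / 4.0):
    times_per_cell = int(np.reciprocal(speed))
    print(speed, "->", times_per_cell, "steps per cell")
# 1.0 -> 1, 0.5 -> 2, 0.333... -> 3, 0.25 -> 4
```
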
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 98fa4d2e4b557edced94c6eb3692f4949abe9221..bf4eff2137b45bd87c33d6d35edb3d2d1ed2cb18 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -16,7 +16,7 @@ from flatland.core.grid.grid4 import Grid4TransitionsEnum, Grid4Transitions
 from flatland.core.grid.grid4_utils import get_new_position
 from flatland.core.grid.grid_utils import Vec2dOperations
 from flatland.core.transition_map import GridTransitionMap
-from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent
+from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent, RailAgentStatus
 from flatland.envs.distance_map import DistanceMap
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.rail_generators import random_rail_generator, RailGenerator
@@ -235,12 +235,18 @@ class RailEnv(Environment):
         self.agents_static.append(agent_static)
         return len(self.agents_static) - 1
 
+    def set_agent_active(self, handle: int):
+        agent = self.agents[handle]
+        if agent.status == RailAgentStatus.READY_TO_DEPART and self.cell_free(agent.initial_position):
+            agent.status = RailAgentStatus.ACTIVE
+            agent.position = agent.initial_position
+
     def restart_agents(self):
         """ Reset the agents to their starting positions defined in agents_static
         """
         self.agents = EnvAgent.list_from_static(self.agents_static)
 
-    def reset(self, regen_rail=True, replace_agents=True):
+    def reset(self, regen_rail=True, replace_agents=True, activate_agents=False):
         """ if regen_rail then regenerate the rails.
             if replace_agents then regenerate the agents static.
             Relies on the rail_generator returning agent_static lists (pos, dir, target)
@@ -277,8 +283,13 @@ class RailEnv(Environment):
                 *self.schedule_generator(self.rail, self.get_num_agents(), agents_hints))
         self.restart_agents()
 
-        for i_agent in range(self.get_num_agents()):
-            agent = self.agents[i_agent]
+        if activate_agents:
+            for i_agent in range(self.get_num_agents()):
+                self.set_agent_active(i_agent)
+
+        for i_agent, agent in enumerate(self.agents):
+            if agent.status != RailAgentStatus.ACTIVE:
+                continue
 
             # A proportion of agent in the environment will receive a positive malfunction rate
             if np.random.random() < self.proportion_malfunctioning_trains:
@@ -366,7 +377,8 @@ class RailEnv(Environment):
             info_dict = {
                 'action_required': {i: False for i in range(self.get_num_agents())},
                 'malfunction': {i: 0 for i in range(self.get_num_agents())},
-                'speed': {i: 0 for i in range(self.get_num_agents())}
+                'speed': {i: 0 for i in range(self.get_num_agents())},
+                'status': {i: agent.status for i, agent in enumerate(self.agents)}
             }
             return self._get_observations(), self.rewards_dict, self.dones, info_dict
 
@@ -380,21 +392,19 @@ class RailEnv(Environment):
             self.rewards_dict = {i: self.global_reward for i in range(self.get_num_agents())}
         if (self._max_episode_steps is not None) and (self._elapsed_steps >= self._max_episode_steps):
             self.dones["__all__"] = True
-            for k in self.dones.keys():
-                self.dones[k] = True
-
-        action_required_agents = {
-            i: self.agents[i].speed_data['position_fraction'] == 0.0 for i in range(self.get_num_agents())
-        }
-        malfunction_agents = {
-            i: self.agents[i].malfunction_data['malfunction'] for i in range(self.get_num_agents())
-        }
-        speed_agents = {i: self.agents[i].speed_data['speed'] for i in range(self.get_num_agents())}
+            for i in range(self.get_num_agents()):
+                self.agents[i].status = RailAgentStatus.DONE
+                self.dones[i] = True
 
         info_dict = {
-            'action_required': action_required_agents,
-            'malfunction': malfunction_agents,
-            'speed': speed_agents
+            'action_required': {
+                i: (agent.status == RailAgentStatus.ACTIVE and agent.speed_data['position_fraction'] == 0.0)
+                for i, agent in enumerate(self.agents)},
+            'malfunction': {
+                i: self.agents[i].malfunction_data['malfunction'] for i in range(self.get_num_agents())
+            },
+            'speed': {i: self.agents[i].speed_data['speed'] for i in range(self.get_num_agents())},
+            'status': {i: agent.status for i, agent in enumerate(self.agents)}
         }
 
         return self._get_observations(), self.rewards_dict, self.dones, info_dict
@@ -412,10 +422,19 @@ class RailEnv(Environment):
         action_dict_ : Dict[int,RailEnvActions]
 
         """
-        if self.dones[i_agent]:  # this agent has already completed...
+        agent = self.agents[i_agent]
+        if agent.status == RailAgentStatus.DONE:  # this agent has already completed...
             return
 
-        agent = self.agents[i_agent]
+        # the agent becomes active through a MOVE_* action, provided its initial cell is free
+        if agent.status == RailAgentStatus.READY_TO_DEPART:
+            if action in [RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT,
+                          RailEnvActions.MOVE_FORWARD] and self.cell_free(agent.initial_position):
+                agent.status = RailAgentStatus.ACTIVE
+                agent.position = agent.initial_position
+            else:
+                return
+
         agent.old_direction = agent.direction
         agent.old_position = agent.position
 
@@ -508,6 +527,7 @@ class RailEnv(Environment):
 
             # has the agent reached its target?
             if np.equal(agent.position, agent.target).all():
+                agent.status = RailAgentStatus.DONE
                 self.dones[i_agent] = True
                 agent.moving = False
 
@@ -558,9 +578,15 @@ class RailEnv(Environment):
 
         # Check the new position is not the same as any of the existing agent positions
         # (including itself, for simplicity, since it is moving)
-        cell_free = not np.any(np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1))
+        cell_free = self.cell_free(new_position)
         return cell_free, new_cell_valid, new_direction, new_position, transition_valid
 
+    def cell_free(self, position):
+        """Return True if no agent currently occupies the given position."""
+        agent_positions = [agent.position for agent in self.agents if agent.position is not None]
+        ret = len(agent_positions) == 0 or not np.any(np.equal(position, agent_positions).all(1))
+        return ret
+
     def check_action(self, agent: EnvAgent, action: RailEnvActions):
         """
 
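
Taken together, the lifecycle changes mean agents no longer stand on the grid right after `reset` unless `activate_agents=True`: they enter through `set_agent_active` or their first MOVE_* action, and only if their initial cell is free. A usage sketch under those assumptions (env construction omitted):

```python
from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.rail_env import RailEnv, RailEnvActions


def depart_all(env: RailEnv):
    env.reset(regen_rail=False, replace_agents=False, activate_agents=False)
    assert all(agent.status == RailAgentStatus.READY_TO_DEPART for agent in env.agents)

    # A MOVE_* action places each agent on its initial cell, if that cell is free.
    actions = {agent.handle: RailEnvActions.MOVE_FORWARD for agent in env.agents}
    obs, rewards, dones, info = env.step(actions)
    return info['status']  # handle -> RailAgentStatus, new in this patch
```
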
diff --git a/flatland/envs/rail_env_shortest_paths.py b/flatland/envs/rail_env_shortest_paths.py
index 793601d4d18ac38b729d15883089d5acbfc41ed3..7944a49daf388403c8a226dbf866ac810d1286a3 100644
--- a/flatland/envs/rail_env_shortest_paths.py
+++ b/flatland/envs/rail_env_shortest_paths.py
@@ -7,6 +7,7 @@ import numpy as np
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
 from flatland.core.grid.grid4_utils import get_new_position
 from flatland.core.transition_map import GridTransitionMap
+from flatland.envs.agent_utils import RailAgentStatus
 from flatland.envs.distance_map import DistanceMap
 from flatland.envs.rail_env import RailEnvNextAction, RailEnvActions
 from flatland.utils.ordered_set import OrderedSet
@@ -92,7 +93,15 @@ def get_shortest_paths(distance_map: DistanceMap, max_depth: Optional[int] = Non
     shortest_paths = dict()
 
     def _shortest_path_for_agent(agent):
-        position = agent.position
+        if agent.status == RailAgentStatus.READY_TO_DEPART:
+            position = agent.initial_position
+        elif agent.status == RailAgentStatus.ACTIVE:
+            position = agent.position
+        elif agent.status == RailAgentStatus.DONE:
+            position = agent.target
+        else:
+            shortest_paths[agent.handle] = None
+            return
         direction = agent.direction
         shortest_paths[agent.handle] = []
         distance = math.inf
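
Consistent with the rule above, `get_shortest_paths` now resolves each agent's start cell from its status and records `None` for agents that are off the grid (`DONE_REMOVED`). A short consumer sketch, assuming a `DistanceMap` from a reset environment:

```python
from flatland.envs.distance_map import DistanceMap
from flatland.envs.rail_env_shortest_paths import get_shortest_paths


def describe_paths(distance_map: DistanceMap):
    # handle -> list of waypoints, or None for agents removed from the grid
    paths = get_shortest_paths(distance_map, max_depth=30)
    for handle, path in paths.items():
        if path is None:
            print(handle, "removed from grid, no path")
        else:
            print(handle, "has", len(path), "waypoints ahead")
```
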
diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py
index bdd4b48b356d69fc7afb5cb76bf12ee08880516d..05c6cc0da831a571c4f0b4f328ff6f3f4703cff1 100644
--- a/flatland/envs/schedule_generators.py
+++ b/flatland/envs/schedule_generators.py
@@ -223,7 +223,7 @@ def schedule_from_file(filename, load_from_package=None) -> ScheduleGenerator:
             agents_static = [EnvAgentStatic(d[0], d[1], d[2], d[3]) for d in data["agents_static"]]
 
         # setup with loaded data
-        agents_position = [a.position for a in agents_static]
+        agents_position = [a.initial_position for a in agents_static]
         agents_direction = [a.direction for a in agents_static]
         agents_target = [a.target for a in agents_static]
         if len(data['agents_static'][0]) > 5:
diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py
index a26aa5ef3191e0b968c5cd1396f397d6cafd1dd9..c238319ad17bb09d9dbaea80335685ce1283feb1 100644
--- a/flatland/utils/rendertools.py
+++ b/flatland/utils/rendertools.py
@@ -147,6 +147,9 @@ class RenderTool(object):
         Plot a simple agent.
         Assumes a working graphics layer context (cf a MPL figure).
         """
+        if position_row_col is None:
+            return
+
         rt = self.__class__
 
         direction_row_col = rt.transitions_row_col[direction]  # agent direction in RC
@@ -537,7 +540,7 @@ class RenderTool(object):
 
         for agent_idx, agent in enumerate(self.env.agents):
 
-            if agent is None:
+            if agent is None or agent.position is None:
                 continue
 
             if self.agent_render_variant == AgentRenderVariant.BOX_ONLY:
diff --git a/tests/test_flatland_rail_agent_status.py b/tests/test_flatland_rail_agent_status.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4396c101d1879bf8f5f74f9a79f9a8541d45e31
--- /dev/null
+++ b/tests/test_flatland_rail_agent_status.py
@@ -0,0 +1,124 @@
+import numpy as np
+
+from flatland.core.grid.grid4 import Grid4TransitionsEnum
+from flatland.envs.agent_utils import RailAgentStatus
+from flatland.envs.observations import TreeObsForRailEnv
+from flatland.envs.predictions import ShortestPathPredictorForRailEnv
+from flatland.envs.rail_env import RailEnv, RailEnvActions
+from flatland.envs.rail_generators import rail_from_grid_transition_map
+from flatland.envs.schedule_generators import random_schedule_generator
+from flatland.utils.simple_rail import make_simple_rail
+from test_utils import ReplayConfig, Replay, run_replay_config, set_penalties_for_replay
+
+np.random.seed(1)
+
+
+def test_initial_status():
+    """Test that agent lifecycle works correctly ready-to-depart -> active -> done."""
+    rail, rail_map = make_simple_rail()
+    env = RailEnv(width=rail_map.shape[1],
+                  height=rail_map.shape[0],
+                  rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(),
+                  number_of_agents=1,
+                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
+                  )
+
+    set_penalties_for_replay(env)
+    test_config = ReplayConfig(
+        replay=[
+            Replay(
+                position=None,  # not entered grid yet
+                direction=Grid4TransitionsEnum.EAST,
+                status=RailAgentStatus.READY_TO_DEPART,
+                action=RailEnvActions.DO_NOTHING,
+                reward=0,
+
+            ),
+            Replay(
+                position=None,  # not entered grid yet before step
+                direction=Grid4TransitionsEnum.EAST,
+                status=RailAgentStatus.READY_TO_DEPART,
+                action=RailEnvActions.MOVE_LEFT,
+                reward=env.start_penalty + env.step_penalty * 0.5,  # auto-correction left to forward without penalty!
+            ),
+            Replay(
+                position=(3, 9),
+                direction=Grid4TransitionsEnum.EAST,
+                status=RailAgentStatus.ACTIVE,
+                action=None,
+                reward=env.step_penalty * 0.5,  # running at speed 0.5
+            ),
+            Replay(
+                position=(3, 8),
+                direction=Grid4TransitionsEnum.WEST,
+                status=RailAgentStatus.ACTIVE,
+                action=RailEnvActions.MOVE_FORWARD,
+                reward=env.step_penalty * 0.5,  # running at speed 0.5
+            ),
+            Replay(
+                position=(3, 8),
+                direction=Grid4TransitionsEnum.WEST,
+                status=RailAgentStatus.ACTIVE,
+                action=None,
+                reward=env.step_penalty * 0.5,  # running at speed 0.5
+            ),
+            Replay(
+                position=(3, 7),
+                direction=Grid4TransitionsEnum.WEST,
+                status=RailAgentStatus.ACTIVE,
+                action=RailEnvActions.MOVE_FORWARD,
+                reward=env.step_penalty * 0.5,  # running at speed 0.5
+
+            ),
+            Replay(
+                position=(3, 7),
+                direction=Grid4TransitionsEnum.WEST,
+                action=None,
+                reward=env.step_penalty * 0.5,  # running at speed 0.5
+                status=RailAgentStatus.ACTIVE
+            ),
+            Replay(
+                position=(3, 6),
+                direction=Grid4TransitionsEnum.WEST,
+                action=RailEnvActions.MOVE_RIGHT,
+                reward=env.step_penalty * 0.5,  # wrong action is corrected to forward without penalty!
+                status=RailAgentStatus.ACTIVE
+            ),
+            Replay(
+                position=(3, 6),
+                direction=Grid4TransitionsEnum.WEST,
+                action=None,
+                reward=env.global_reward,  # done
+                status=RailAgentStatus.ACTIVE
+            ),
+            Replay(
+                position=(3, 5),
+                direction=Grid4TransitionsEnum.WEST,
+                action=None,
+                reward=env.global_reward,  # already done
+                status=RailAgentStatus.DONE
+            ),
+            Replay(
+                position=(3, 5),
+                direction=Grid4TransitionsEnum.WEST,
+                action=None,
+                reward=env.global_reward,  # already done
+                status=RailAgentStatus.DONE
+            ),
+            Replay(
+                position=(3, 5),
+                direction=Grid4TransitionsEnum.WEST,
+                action=None,
+                reward=env.global_reward,  # already done
+                status=RailAgentStatus.DONE
+            )
+
+        ],
+        initial_position=(3, 9),  # east dead-end
+        initial_direction=Grid4TransitionsEnum.EAST,
+        target=(3, 5),
+        speed=0.5
+    )
+
+    run_replay_config(env, [test_config], activate_agents=False)
diff --git a/tests/test_flatland_envs_observations.py b/tests/test_flatland_envs_observations.py
index 0d6d309765690b1f95c681d7d109a13071d7f86b..52b047244850a5b1c6b39dadd74568fc5f92deec 100644
--- a/tests/test_flatland_envs_observations.py
+++ b/tests/test_flatland_envs_observations.py
@@ -5,7 +5,7 @@ import numpy as np
 
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
 from flatland.core.grid.grid4_utils import get_new_position
-from flatland.envs.agent_utils import EnvAgent
+from flatland.envs.agent_utils import EnvAgent, RailAgentStatus
 from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv, RailEnvActions
@@ -29,6 +29,9 @@ def test_global_obs():
 
     global_obs = env.reset()
 
+    # we have to take a step for the agent to enter the grid.
+    global_obs, _, _, _ = env.step({0: RailEnvActions.MOVE_FORWARD})
+
     assert (global_obs[0][0].shape == rail_map.shape + (16,))
 
     rail_map_recons = np.zeros_like(rail_map)
@@ -109,12 +112,14 @@ def test_reward_function_conflict(rendering=False):
     agent.direction = 0  # north
     agent.target = (3, 9)  # east dead-end
     agent.moving = True
+    agent.status = RailAgentStatus.ACTIVE
 
     agent = env.agents_static[1]
     agent.position = (3, 8)  # east dead-end
     agent.direction = 3  # west
     agent.target = (6, 6)  # south dead-end
     agent.moving = True
+    agent.status = RailAgentStatus.ACTIVE
 
     # reset to set agents from agents_static
     env.reset(False, False)
@@ -184,16 +189,20 @@ def test_reward_function_waiting(rendering=False):
 
     # set the initial position
     agent = env.agents_static[0]
+    agent.initial_position = (3, 8)  # east dead-end
     agent.position = (3, 8)  # east dead-end
     agent.direction = 3  # west
     agent.target = (3, 1)  # west dead-end
     agent.moving = True
+    agent.status = RailAgentStatus.ACTIVE
 
     agent = env.agents_static[1]
+    agent.initial_position = (5, 6)  # south dead-end
     agent.position = (5, 6)  # south dead-end
     agent.direction = 0  # north
     agent.target = (3, 8)  # east dead-end
     agent.moving = True
+    agent.status = RailAgentStatus.ACTIVE
 
     # reset to set agents from agents_static
     env.reset(False, False)
diff --git a/tests/test_flatland_envs_predictions.py b/tests/test_flatland_envs_predictions.py
index f4ab68bc45a82b8f196fcfbebb26fd68a36c37a4..2517fb84de00c1346e5c601f903a9a1a677cfae4 100644
--- a/tests/test_flatland_envs_predictions.py
+++ b/tests/test_flatland_envs_predictions.py
@@ -5,6 +5,7 @@ import pprint
 import numpy as np
 
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
+from flatland.envs.agent_utils import RailAgentStatus
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv, RailEnvActions, RailEnvNextAction
@@ -31,12 +32,13 @@ def test_dummy_predictor(rendering=False):
     env.reset()
 
     # set initial position and direction for testing...
-    env.agents_static[0].position = (5, 6)
+    env.agents_static[0].initial_position = (5, 6)
     env.agents_static[0].direction = 0
     env.agents_static[0].target = (3, 0)
 
     # reset to set agents from agents_static
     env.reset(False, False)
+    env.set_agent_active(0)
 
     if rendering:
         renderer = RenderTool(env, gl="PILSVG")
@@ -124,10 +126,12 @@ def test_shortest_path_predictor(rendering=False):
 
     # set the initial position
     agent = env.agents_static[0]
+    agent.initial_position = (5, 6)  # south dead-end
     agent.position = (5, 6)  # south dead-end
     agent.direction = 0  # north
     agent.target = (3, 9)  # east dead-end
     agent.moving = True
+    agent.status = RailAgentStatus.ACTIVE
 
     # reset to set agents from agents_static
     env.reset(False, False)
@@ -139,9 +143,9 @@ def test_shortest_path_predictor(rendering=False):
 
     # compute the observations and predictions
     distance_map = env.distance_map.get()
-    assert distance_map[0, agent.position[0], agent.position[
-        1], agent.direction] == 5.0, "found {} instead of {}".format(
-        distance_map[agent.handle, agent.position[0], agent.position[1], agent.direction], 5.0)
+    assert distance_map[0, agent.initial_position[0], agent.initial_position[1], agent.direction] == 5.0, \
+        "found {} instead of {}".format(
+            distance_map[agent.handle, agent.initial_position[0], agent.initial_position[1], agent.direction], 5.0)
 
     paths = get_shortest_paths(env.distance_map)[0]
     assert paths == [
@@ -259,19 +263,23 @@ def test_shortest_path_predictor_conflicts(rendering=False):
 
     # set the initial position
     agent = env.agents_static[0]
+    agent.initial_position = (5, 6)  # south dead-end
     agent.position = (5, 6)  # south dead-end
     agent.direction = 0  # north
     agent.target = (3, 9)  # east dead-end
     agent.moving = True
+    agent.status = RailAgentStatus.ACTIVE
 
     agent = env.agents_static[1]
+    agent.initial_position = (3, 8)  # east dead-end
     agent.position = (3, 8)  # east dead-end
     agent.direction = 3  # west
     agent.target = (6, 6)  # south dead-end
     agent.moving = True
+    agent.status = RailAgentStatus.ACTIVE
 
     # reset to set agents from agents_static
-    observations = env.reset(False, False)
+    observations = env.reset(False, False, True)
 
     if rendering:
         renderer = RenderTool(env, gl="PILSVG")
@@ -295,14 +303,14 @@ def test_shortest_path_predictor_conflicts(rendering=False):
 
 def _check_expected_conflicts(expected_conflicts, obs_builder, tree: TreeObsForRailEnv.Node, prompt=''):
     assert (tree.num_agents_opposite_direction > 0) == (() in expected_conflicts), "{}[]".format(prompt)
-    for a_1 in obs_builder.tree_explorted_actions_char:
+    for a_1 in obs_builder.tree_explored_actions_char:
         if tree.childs[a_1] == -np.inf:
             assert False == ((a_1) in expected_conflicts), "{}[{}]".format(prompt, a_1)
             continue
         else:
             conflict = tree.childs[a_1].num_agents_opposite_direction
             assert (conflict > 0) == ((a_1) in expected_conflicts), "{}[{}]".format(prompt, a_1)
-        for a_2 in obs_builder.tree_explorted_actions_char:
+        for a_2 in obs_builder.tree_explored_actions_char:
             if tree.childs[a_1].childs[a_2] == -np.inf:
                 assert False == ((a_1, a_2) in expected_conflicts), "{}[{}][{}]".format(prompt, a_1, a_2)
             else:
diff --git a/tests/test_flatland_envs_rail_env.py b/tests/test_flatland_envs_rail_env.py
index 0fefd3e212ddb5f084c1e219f4063079e03dabdf..e0281bb0d3f21a8e25fdff86ade5bed05f5dea13 100644
--- a/tests/test_flatland_envs_rail_env.py
+++ b/tests/test_flatland_envs_rail_env.py
@@ -85,7 +85,7 @@ def test_rail_environment_single_agent():
                        obs_builder_object=GlobalObsForRailEnv())
 
     for _ in range(200):
-        _ = rail_env.reset()
+        _ = rail_env.reset(False, False, True)
 
         # We do not care about target for the moment
         agent = rail_env.agents[0]
@@ -130,9 +130,6 @@ def test_rail_environment_single_agent():
                 done = dones['__all__']
 
 
-test_rail_environment_single_agent()
-
-
 def test_dead_end():
     transitions = RailEnvTransitions()
 
@@ -164,32 +161,12 @@ def test_dead_end():
                        number_of_agents=1,
                        obs_builder_object=GlobalObsForRailEnv())
 
-    def check_consistency(rail_env):
-        # We run step to check that trains do not move anymore
-        # after being done.
-        # TODO: GIACOMO: this is deprecated and should be updated; thenew behavior is that agents keep moving
-        # until they are manually stopped.
-        for i in range(7):
-            prev_pos = rail_env.agents[0].position
-
-            # The train cannot turn, so we check that when it tries,
-            # it stays where it is.
-            _ = rail_env.step({0: 1})
-            _ = rail_env.step({0: 3})
-            assert (rail_env.agents[0].position == prev_pos)
-            _, _, dones, _ = rail_env.step({0: 2})
-
-            if i < 5:
-                assert (not dones[0] and not dones['__all__'])
-            else:
-                assert (dones[0] and dones['__all__'])
-
     # We try the configuration in the 4 directions:
     rail_env.reset()
-    rail_env.agents = [EnvAgent(position=(0, 2), direction=1, target=(0, 0), moving=False)]
+    rail_env.agents = [EnvAgent(initial_position=(0, 2), direction=1, target=(0, 0), moving=False)]
 
     rail_env.reset()
-    rail_env.agents = [EnvAgent(position=(0, 2), direction=3, target=(0, 4), moving=False)]
+    rail_env.agents = [EnvAgent(initial_position=(0, 2), direction=3, target=(0, 4), moving=False)]
 
     # In the vertical configuration:
     rail_map = np.array(
@@ -210,10 +187,12 @@ def test_dead_end():
                        obs_builder_object=GlobalObsForRailEnv())
 
     rail_env.reset()
-    rail_env.agents = [EnvAgent(position=(2, 0), direction=2, target=(0, 0), moving=False)]
+    rail_env.agents = [EnvAgent(initial_position=(2, 0), direction=2, target=(0, 0), moving=False)]
 
     rail_env.reset()
-    rail_env.agents = [EnvAgent(position=(2, 0), direction=0, target=(4, 0), moving=False)]
+    rail_env.agents = [EnvAgent(initial_position=(2, 0), direction=0, target=(4, 0), moving=False)]
+
+    # TODO make assertions
 
 
 def test_get_entry_directions():
diff --git a/tests/test_flatland_envs_rail_env_shortest_paths.py b/tests/test_flatland_envs_rail_env_shortest_paths.py
index 4600c4a3002995e1238a0ccbda762501ac985408..65d2d68c45155efda24536ecfd776bef5ebaab0c 100644
--- a/tests/test_flatland_envs_rail_env_shortest_paths.py
+++ b/tests/test_flatland_envs_rail_env_shortest_paths.py
@@ -25,6 +25,7 @@ def test_get_shortest_paths_unreachable():
     # set the initial position
     agent = env.agents_static[0]
     agent.position = (3, 1)  # west dead-end
+    agent.initial_position = (3, 1)  # west dead-end
     agent.direction = Grid4TransitionsEnum.WEST
     agent.target = (3, 9)  # east dead-end
     agent.moving = True
diff --git a/tests/test_flatland_envs_sparse_rail_generator.py b/tests/test_flatland_envs_sparse_rail_generator.py
index 416efea5c909a5c89a9e6b8d782e18a1dd4bae62..ecba37641b740e0d2d71028b188b48f936730e74 100644
--- a/tests/test_flatland_envs_sparse_rail_generator.py
+++ b/tests/test_flatland_envs_sparse_rail_generator.py
@@ -21,6 +21,7 @@ def test_sparse_rail_generator():
                   schedule_generator=sparse_schedule_generator(),
                   number_of_agents=10,
                   obs_builder_object=GlobalObsForRailEnv())
+    env.reset(False, False, True)
     expected_grid_map = np.zeros((50, 50), dtype=env.rail.transitions.get_type())
     expected_grid_map[1][33] = 8192
     expected_grid_map[2][33] = 32800
@@ -1522,6 +1523,9 @@ def test_rail_env_action_required_info():
                                           obs_builder_object=GlobalObsForRailEnv())
     env_renderer = RenderTool(env_always_action, gl="PILSVG", )
 
+    env_always_action.reset(False, False, True)
+    env_only_if_action_required.reset(False, False, True)
+
     for step in range(100):
         print("step {}".format(step))
 
@@ -1575,6 +1579,7 @@ def test_rail_env_malfunction_speed_info():
                   number_of_agents=10,
                   obs_builder_object=GlobalObsForRailEnv(),
                   stochastic_data=stochastic_data)
+    env.reset(False, False, True)
 
     env_renderer = RenderTool(env, gl="PILSVG", )
     for step in range(100):
diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py
index 1bcf9c5112a4423fd22e1102500890a39f853c74..9800a53735b2698fc8cf30d59fb85a5e631f5580 100644
--- a/tests/test_flatland_malfunction.py
+++ b/tests/test_flatland_malfunction.py
@@ -6,6 +6,7 @@ import numpy as np
 from flatland.core.env_observation_builder import ObservationBuilder
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
 from flatland.core.grid.grid4_utils import get_new_position
+from flatland.envs.agent_utils import RailAgentStatus
 from flatland.envs.rail_env import RailEnv, RailEnvActions
 from flatland.envs.rail_generators import complex_rail_generator, sparse_rail_generator
 from flatland.envs.schedule_generators import complex_schedule_generator, sparse_schedule_generator
@@ -29,7 +30,16 @@ class SingleAgentNavigationObs(ObservationBuilder):
     def get(self, handle: int = 0) -> List[int]:
         agent = self.env.agents[handle]
 
-        possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
+        if agent.status == RailAgentStatus.READY_TO_DEPART:
+            _agent_initial_position = agent.initial_position
+        elif agent.status == RailAgentStatus.ACTIVE:
+            _agent_initial_position = agent.position
+        elif agent.status == RailAgentStatus.DONE:
+            _agent_initial_position = agent.target
+        else:
+            return None
+
+        possible_transitions = self.env.rail.get_transitions(*_agent_initial_position, agent.direction)
         num_transitions = np.count_nonzero(possible_transitions)
 
         # Start from the current orientation, and see which transitions are available;
@@ -41,7 +51,7 @@ class SingleAgentNavigationObs(ObservationBuilder):
             min_distances = []
             for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
                 if possible_transitions[direction]:
-                    new_position = get_new_position(agent.position, direction)
+                    new_position = get_new_position(agent_virtual_position, direction)
                     min_distances.append(
                         self.env.distance_map.get()[handle, new_position[0], new_position[1], direction])
                 else:
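
The status-to-cell mapping above recurs wherever an observation builder needs a
position for agents that are off the grid. Factored out, it would look like the
following sketch (virtual_position is a hypothetical helper, not a library function):

    from typing import Optional, Tuple

    from flatland.envs.agent_utils import EnvAgent, RailAgentStatus

    def virtual_position(agent: EnvAgent) -> Optional[Tuple[int, int]]:
        """Cell an observation builder should treat as the agent's position."""
        if agent.status == RailAgentStatus.READY_TO_DEPART:
            return agent.initial_position  # not yet placed on the grid
        if agent.status == RailAgentStatus.ACTIVE:
            return agent.position  # currently somewhere on the grid
        if agent.status == RailAgentStatus.DONE:
            return agent.target  # finished: parked on its target cell
        return None  # DONE_REMOVED: no longer on the grid
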
@@ -70,16 +80,19 @@ def test_malfunction_process():
                   obs_builder_object=SingleAgentNavigationObs(),
                   stochastic_data=stochastic_data)
 
-    obs = env.reset()
+    obs = env.reset(False, False, True)
 
     # Check that an initial duration for malfunction was assigned
     assert env.agents[0].malfunction_data['next_malfunction'] > 0
+    # Ensure every agent counts as active for the remainder of the test
+    for agent in env.agents:
+        agent.status = RailAgentStatus.ACTIVE
 
     agent_halts = 0
     total_down_time = 0
     agent_old_position = env.agents[0].position
     for step in range(100):
         actions = {}
         for i in range(len(obs)):
             actions[i] = np.argmax(obs[i]) + 1
 
@@ -104,7 +117,8 @@ def test_malfunction_process():
         total_down_time += env.agents[0].malfunction_data['malfunction']
 
     # Check that the appropriate number of malfunctions is achieved
-    assert env.agents[0].malfunction_data['nr_malfunctions'] == 21
+    assert env.agents[0].malfunction_data['nr_malfunctions'] == 21, "Actual {}".format(
+        env.agents[0].malfunction_data['nr_malfunctions'])
 
     # Check that 20 stops were performed
     assert agent_halts == 20
@@ -120,8 +134,6 @@ def test_malfunction_process_statistically():
                        'malfunction_rate': 2,
                        'min_duration': 3,
                        'max_duration': 3}
-    np.random.seed(5)
-    random.seed(0)
 
     env = RailEnv(width=20,
                   height=20,
@@ -131,8 +143,9 @@ def test_malfunction_process_statistically():
                   number_of_agents=2,
                   obs_builder_object=SingleAgentNavigationObs(),
                   stochastic_data=stochastic_data)
-
-    env.reset()
+    np.random.seed(5)
+    random.seed(0)
+    env.reset(False, False, True)
     nb_malfunction = 0
     for step in range(100):
         action_dict: Dict[int, RailEnvActions] = {}
@@ -149,9 +162,6 @@ def test_malfunction_process_statistically():
 
 
 def test_initial_malfunction():
-    random.seed(0)
-    np.random.seed(0)
-
     stochastic_data = {'prop_malfunction': 1.,  # Percentage of defective agents
                        'malfunction_rate': 70,  # Rate of malfunction occurrence
                        'min_duration': 2,  # Minimal duration of malfunction
@@ -162,7 +172,8 @@ def test_initial_malfunction():
                         1. / 2.: 0.,  # Fast freight train
                         1. / 3.: 0.,  # Slow commuter train
                         1. / 4.: 0.}  # Slow freight train
-
+    np.random.seed(5)
+    random.seed(0)
     env = RailEnv(width=25,
                   height=30,
                   rail_generator=sparse_rail_generator(max_num_cities=5,
@@ -218,15 +229,15 @@ def test_initial_malfunction():
             )
         ],
         speed=env.agents[0].speed_data['speed'],
-        target=env.agents[0].target
+        target=env.agents[0].target,
+        initial_position=(28, 5),
+        initial_direction=Grid4TransitionsEnum.EAST,
     )
+
     run_replay_config(env, [replay_config])
 
 
 def test_initial_malfunction_stop_moving():
-    random.seed(0)
-    np.random.seed(0)
-
     stochastic_data = {'prop_malfunction': 1.,  # Percentage of defective agents
                        'malfunction_rate': 70,  # Rate of malfunction occurrence
                        'min_duration': 2,  # Minimal duration of malfunction
@@ -253,19 +264,21 @@ def test_initial_malfunction_stop_moving():
     replay_config = ReplayConfig(
         replay=[
             Replay(
-                position=(28, 5),
+                position=None,
                 direction=Grid4TransitionsEnum.EAST,
-                action=RailEnvActions.DO_NOTHING,
+                action=RailEnvActions.MOVE_FORWARD,
                 set_malfunction=3,
                 malfunction=3,
-                reward=env.step_penalty  # full step penalty when stopped
+                reward=env.step_penalty,  # full step penalty when stopped
+                status=RailAgentStatus.READY_TO_DEPART
             ),
             Replay(
                 position=(28, 5),
                 direction=Grid4TransitionsEnum.EAST,
                 action=RailEnvActions.DO_NOTHING,
                 malfunction=2,
-                reward=env.step_penalty  # full step penalty when stopped
+                reward=env.step_penalty,  # full step penalty when stopped
+                status=RailAgentStatus.ACTIVE
             ),
             # malfunction stops in the next step and we're still at the beginning of the cell
             # --> if we take action STOP_MOVING, agent should restart without moving
@@ -275,7 +288,8 @@ def test_initial_malfunction_stop_moving():
                 direction=Grid4TransitionsEnum.EAST,
                 action=RailEnvActions.STOP_MOVING,
                 malfunction=1,
-                reward=env.step_penalty  # full step penalty while stopped
+                reward=env.step_penalty,  # full step penalty while stopped
+                status=RailAgentStatus.ACTIVE
             ),
             # we have stopped and do nothing --> should stand still
             Replay(
@@ -283,7 +297,8 @@ def test_initial_malfunction_stop_moving():
                 direction=Grid4TransitionsEnum.EAST,
                 action=RailEnvActions.DO_NOTHING,
                 malfunction=0,
-                reward=env.step_penalty  # full step penalty while stopped
+                reward=env.step_penalty,  # full step penalty while stopped
+                status=RailAgentStatus.ACTIVE
             ),
             # we start to move forward --> should go to next cell now
             Replay(
@@ -291,21 +306,24 @@ def test_initial_malfunction_stop_moving():
                 direction=Grid4TransitionsEnum.EAST,
                 action=RailEnvActions.MOVE_FORWARD,
                 malfunction=0,
-                reward=env.start_penalty + env.step_penalty * 1.0  # full step penalty while stopped
+                reward=env.start_penalty + env.step_penalty * 1.0,  # start penalty + step penalty for speed 1.0
+                status=RailAgentStatus.ACTIVE
             ),
             Replay(
                 position=(28, 6),
                 direction=Grid4TransitionsEnum.EAST,
                 action=RailEnvActions.MOVE_FORWARD,
                 malfunction=0,
-                reward=env.step_penalty * 1.0  # full step penalty while stopped
+                reward=env.step_penalty * 1.0,  # step penalty for speed 1.0
+                status=RailAgentStatus.ACTIVE
             )
         ],
         speed=env.agents[0].speed_data['speed'],
-        target=env.agents[0].target
+        target=env.agents[0].target,
+        initial_position=(28, 5),
+        initial_direction=Grid4TransitionsEnum.EAST,
     )
-
-    run_replay_config(env, [replay_config])
+    run_replay_config(env, [replay_config], activate_agents=False)
 
 
 def test_initial_malfunction_do_nothing():
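
With activate_agents=False the replayed agent starts off the grid, and its first
MOVE_FORWARD both activates it and places it on its initial position. A sketch of
the transition the first two replay entries above assert, assuming a freshly reset
environment:

    # Step 0: the agent exists only in the schedule.
    assert env.agents[0].position is None
    assert env.agents[0].status == RailAgentStatus.READY_TO_DEPART
    env.step({0: RailEnvActions.MOVE_FORWARD})
    # Step 1: MOVE_FORWARD has put it on the grid at its initial position.
    assert env.agents[0].status == RailAgentStatus.ACTIVE
    assert env.agents[0].position == env.agents[0].initial_position
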
@@ -336,20 +354,23 @@ def test_initial_malfunction_do_nothing():
                   )
     set_penalties_for_replay(env)
     replay_config = ReplayConfig(
-        replay=[Replay(
-            position=(28, 5),
-            direction=Grid4TransitionsEnum.EAST,
-            action=RailEnvActions.DO_NOTHING,
-            set_malfunction=3,
-            malfunction=3,
-            reward=env.step_penalty  # full step penalty while malfunctioning
-        ),
+        replay=[
+            Replay(
+                position=None,
+                direction=Grid4TransitionsEnum.EAST,
+                action=RailEnvActions.MOVE_FORWARD,
+                set_malfunction=3,
+                malfunction=3,
+                reward=env.step_penalty,  # full step penalty while malfunctioning
+                status=RailAgentStatus.READY_TO_DEPART
+            ),
             Replay(
                 position=(28, 5),
                 direction=Grid4TransitionsEnum.EAST,
                 action=RailEnvActions.DO_NOTHING,
                 malfunction=2,
-                reward=env.step_penalty  # full step penalty while malfunctioning
+                reward=env.step_penalty,  # full step penalty while malfunctioning
+                status=RailAgentStatus.ACTIVE
             ),
             # malfunction stops in the next step and we're still at the beginning of the cell
             # --> if we take action DO_NOTHING, agent should restart without moving
@@ -359,7 +380,8 @@ def test_initial_malfunction_do_nothing():
                 direction=Grid4TransitionsEnum.EAST,
                 action=RailEnvActions.DO_NOTHING,
                 malfunction=1,
-                reward=env.step_penalty  # full step penalty while stopped
+                reward=env.step_penalty,  # full step penalty while stopped
+                status=RailAgentStatus.ACTIVE
             ),
             # we haven't started moving yet --> stay here
             Replay(
@@ -367,7 +389,8 @@ def test_initial_malfunction_do_nothing():
                 direction=Grid4TransitionsEnum.EAST,
                 action=RailEnvActions.DO_NOTHING,
                 malfunction=0,
-                reward=env.step_penalty  # full step penalty while stopped
+                reward=env.step_penalty,  # full step penalty while stopped
+                status=RailAgentStatus.ACTIVE
             ),
             # we start to move forward --> should go to next cell now
             Replay(
@@ -375,21 +398,25 @@ def test_initial_malfunction_do_nothing():
                 direction=Grid4TransitionsEnum.EAST,
                 action=RailEnvActions.MOVE_FORWARD,
                 malfunction=0,
-                reward=env.start_penalty + env.step_penalty * 1.0  # start penalty + step penalty for speed 1.0
+                reward=env.start_penalty + env.step_penalty * 1.0,  # start penalty + step penalty for speed 1.0
+                status=RailAgentStatus.ACTIVE
             ),
             Replay(
                 position=(28, 6),
                 direction=Grid4TransitionsEnum.EAST,
                 action=RailEnvActions.MOVE_FORWARD,
                 malfunction=0,
-                reward=env.step_penalty * 1.0  # step penalty for speed 1.0
+                reward=env.step_penalty * 1.0,  # step penalty for speed 1.0
+                status=RailAgentStatus.ACTIVE
             )
         ],
         speed=env.agents[0].speed_data['speed'],
-        target=env.agents[0].target
+        target=env.agents[0].target,
+        initial_position=(28, 5),
+        initial_direction=Grid4TransitionsEnum.EAST,
     )
 
-    run_replay_config(env, [replay_config])
+    run_replay_config(env, [replay_config], activate_agents=False)
 
 
 def test_initial_nextmalfunction_not_below_zero():
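
ReplayConfig now carries initial_position and initial_direction so the harness can
seed agents that start off the grid, and each Replay entry may pin the expected
status. A condensed sketch of the pattern, mirroring the first two steps of the
replays above:

    replay_config = ReplayConfig(
        replay=[
            Replay(position=None, direction=Grid4TransitionsEnum.EAST,
                   action=RailEnvActions.MOVE_FORWARD,
                   set_malfunction=3, malfunction=3,
                   reward=env.step_penalty,
                   status=RailAgentStatus.READY_TO_DEPART),
            Replay(position=(28, 5), direction=Grid4TransitionsEnum.EAST,
                   action=RailEnvActions.DO_NOTHING,
                   malfunction=2,
                   reward=env.step_penalty,
                   status=RailAgentStatus.ACTIVE),
        ],
        speed=env.agents[0].speed_data['speed'],
        target=env.agents[0].target,
        initial_position=(28, 5),
        initial_direction=Grid4TransitionsEnum.EAST,
    )
    run_replay_config(env, [replay_config], activate_agents=False)
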
diff --git a/tests/test_global_observation.py b/tests/test_global_observation.py
index 9035062bad47dbdf432c6cd980c9d2361d42de6e..393d0e0087d3068050759575684d749d7224bc6b 100644
--- a/tests/test_global_observation.py
+++ b/tests/test_global_observation.py
@@ -1,7 +1,8 @@
 import numpy as np
 
+from flatland.envs.agent_utils import EnvAgent
 from flatland.envs.observations import GlobalObsForRailEnv
-from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_env import RailEnv, RailEnvActions
 from flatland.envs.rail_generators import sparse_rail_generator
 from flatland.envs.schedule_generators import sparse_schedule_generator
 
@@ -31,24 +32,27 @@ def test_get_global_observation():
                   schedule_generator=sparse_schedule_generator(speed_ration_map),
                   number_of_agents=number_of_agents, stochastic_data=stochastic_data,  # Malfunction data generator
                   obs_builder_object=GlobalObsForRailEnv())
-    obs, all_rewards, done, _ = env.step({0: 0})
 
+    obs, all_rewards, done, _ = env.step({i: RailEnvActions.MOVE_FORWARD for i in range(number_of_agents)})
     for i in range(len(env.agents)):
+        agent: EnvAgent = env.agents[i]
+        print("[{}] status={}, position={}, target={}, initial_position={}".format(i, agent.status, agent.position,
+                                                                                   agent.target,
+                                                                                   agent.initial_position))
+
+    for i, agent in enumerate(env.agents):
         obs_agents_state = obs[i][1]
         obs_targets = obs[i][2]
 
         nr_agents = np.count_nonzero(obs_targets[:, :, 0])
-        nr_agents_other = np.count_nonzero(obs_targets[:, :, 1])
         assert nr_agents == 1
-        assert nr_agents_other == (number_of_agents - 1)
-
-        # since the array is initialized with -1 add one in order to used np.count_nonzero
-        obs_agents_state += 1
-        obs_agents_state_0 = np.count_nonzero(obs_agents_state[:, :, 0])
-        obs_agents_state_1 = np.count_nonzero(obs_agents_state[:, :, 1])
-        obs_agents_state_2 = np.count_nonzero(obs_agents_state[:, :, 2])
-        obs_agents_state_3 = np.count_nonzero(obs_agents_state[:, :, 3])
-        assert obs_agents_state_0 == 1
-        assert obs_agents_state_1 == (number_of_agents - 1)
-        assert obs_agents_state_2 == number_of_agents
-        assert obs_agents_state_3 == number_of_agents
+
+        for r in range(env.height):
+            for c in range(env.width):
+                _other_agent_target = 0
+                for other_agent in env.agents:
+                    if other_agent.target == (r, c):
+                        _other_agent_target = 1
+                        break
+                assert obs_targets[(r, c)][1] == _other_agent_target, \
+                    "agent {} at {} expected {}".format(i, (r, c), _other_agent_target)
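
The quadratic scan above can also be written as a single array comparison; a sketch
of the equivalent check, assuming target cells are marked with exactly 1 in channel
1 of obs_targets:

    expected_targets = np.zeros((env.height, env.width))
    for other_agent in env.agents:
        expected_targets[other_agent.target] = 1
    assert np.array_equal(obs_targets[:, :, 1], expected_targets)
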
diff --git a/tests/test_multi_speed.py b/tests/test_multi_speed.py
index b0f274ba4c4b5453140fcc50bc6137e39e8e4f04..3cd0a4c1d9812c728089255eb1150111b783a463 100644
--- a/tests/test_multi_speed.py
+++ b/tests/test_multi_speed.py
@@ -63,7 +63,8 @@ def test_multi_speed_init():
 
     # Set all the different speeds
     # Reset environment and get initial observations for all agents
-    env.reset()
+    env.reset(False, False, True)
+
     # Here you can also further enhance the provided observation by means of normalization
     # See training navigation example in the baseline repository
     old_pos = []
@@ -188,7 +189,9 @@ def test_multispeed_actions_no_malfunction_no_blocking():
             ),
         ],
         target=(3, 0),  # west dead-end
-        speed=0.5
+        speed=0.5,
+        initial_position=(3, 9),  # east dead-end
+        initial_direction=Grid4TransitionsEnum.EAST,
     )
 
     run_replay_config(env, [test_config])
@@ -285,7 +288,10 @@ def test_multispeed_actions_no_malfunction_blocking():
                 )
             ],
             target=(3, 0),  # west dead-end
-            speed=1 / 3),
+            speed=1 / 3,
+            initial_position=(3, 8),
+            initial_direction=Grid4TransitionsEnum.WEST,
+        ),
         ReplayConfig(
             replay=[
                 Replay(
@@ -369,7 +375,9 @@ def test_multispeed_actions_no_malfunction_blocking():
                 ),
             ],
             target=(3, 0),  # west dead-end
-            speed=0.5
+            speed=0.5,
+            initial_position=(3, 9),  # east dead-end
+            initial_direction=Grid4TransitionsEnum.EAST,
         )
 
     ]
@@ -505,7 +513,9 @@ def test_multispeed_actions_malfunction_no_blocking():
 
         ],
         target=(3, 0),  # west dead-end
-        speed=0.5
+        speed=0.5,
+        initial_position=(3, 9),  # east dead-end
+        initial_direction=Grid4TransitionsEnum.EAST,
     )
     run_replay_config(env, [test_config])
 
@@ -587,7 +597,9 @@ def test_multispeed_actions_no_malfunction_invalid_actions():
 
         ],
         target=(3, 0),  # west dead-end
-        speed=0.5
+        speed=0.5,
+        initial_position=(3, 9),  # east dead-end
+        initial_direction=Grid4TransitionsEnum.EAST,
     )
 
     run_replay_config(env, [test_config])
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 903120d868aa65833e7c2393ddfcc821c26da4f6..f5d1cd5c957d1e81ca2f2465fb8265c6b2795342 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -5,7 +5,7 @@ import numpy as np
 from attr import attrs, attrib
 
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
-from flatland.envs.agent_utils import EnvAgent
+from flatland.envs.agent_utils import EnvAgent, RailAgentStatus
 from flatland.envs.rail_env import RailEnvActions, RailEnv
 from flatland.utils.rendertools import RenderTool
 
@@ -18,6 +18,7 @@ class Replay(object):
     malfunction = attrib(default=0, type=int)
     set_malfunction = attrib(default=None, type=Optional[int])
     reward = attrib(default=None, type=Optional[float])
+    status = attrib(default=None, type=Optional[RailAgentStatus])
 
 
 @attrs
@@ -25,6 +26,8 @@ class ReplayConfig(object):
     replay = attrib(type=List[Replay])
     target = attrib(type=Tuple[int, int])
     speed = attrib(type=float)
+    initial_position = attrib(type=Tuple[int, int])
+    initial_direction = attrib(type=Grid4TransitionsEnum)
 
 
 # ensure that env is working correctly with start/stop/invalidaction penalty different from 0
@@ -35,7 +38,7 @@ def set_penalties_for_replay(env: RailEnv):
     env.invalid_action_penalty = -29
 
 
-def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering: bool = False):
+def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering: bool = False, activate_agents=True):
     """
     Runs the replay configs and checks assertions.
 
@@ -47,10 +50,12 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering:
     - position, direction before step are verified
     - optionally, set_malfunction is applied
     - malfunction is verified
+    - optionally, status is verified
 
     *After each step*
     - reward is verified after step
 
     Parameters
     ----------
     env
@@ -67,18 +72,20 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering:
     for step in range(len(test_configs[0].replay)):
         if step == 0:
             for a, test_config in enumerate(test_configs):
-                agent: EnvAgent = env.agents[a]
-                replay = test_config.replay[0]
+                agent: EnvAgent = env.agents_static[a]
                 # set the initial position
-                agent.position = replay.position
-                agent.direction = replay.direction
+                agent.initial_position = test_config.initial_position
+                agent.direction = test_config.initial_direction
                 agent.target = test_config.target
                 agent.speed_data['speed'] = test_config.speed
+            env.reset(False, False, activate_agents)
 
         def _assert(a, actual, expected, msg):
-            assert np.allclose(actual, expected), "[{}] agent {} {}:  actual={}, expected={}".format(step, a, msg,
-                                                                                                    actual,
-                                                                                                    expected)
+            print("[{}] verifying {} on agent {}: actual={}, expected={}".format(step, msg, a, actual, expected))
+            assert (actual == expected) or np.allclose(actual, expected), \
+                "[{}] agent {} {}: actual={}, expected={}".format(step, a, msg, actual, expected)
 
         action_dict = {}
 
@@ -88,26 +95,29 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering:
 
             _assert(a, agent.position, replay.position, 'position')
             _assert(a, agent.direction, replay.direction, 'direction')
+            if replay.status is not None:
+                _assert(a, agent.status, replay.status, 'status')
 
             if replay.action is not None:
-                assert info_dict['action_required'][a] == True, "[{}] agent {} expecting action_required={}".format(
+                assert info_dict['action_required'][a] or agent.status == RailAgentStatus.READY_TO_DEPART, \
+                    "[{}] agent {} expecting action_required={} or agent status READY_TO_DEPART".format(
                     step, a, True)
                 action_dict[a] = replay.action
             else:
-                assert info_dict['action_required'][a] == False, "[{}] agent {} expecting action_required={}".format(
-                    step, a, False)
+                assert not info_dict['action_required'][a], \
+                    "[{}] agent {} expecting action_required={}, but found {}".format(
+                        step, a, False, info_dict['action_required'][a])
 
             if replay.set_malfunction is not None:
                 agent.malfunction_data['malfunction'] = replay.set_malfunction
                 agent.malfunction_data['moving_before_malfunction'] = agent.moving
             _assert(a, agent.malfunction_data['malfunction'], replay.malfunction, 'malfunction')
-
         _, rewards_dict, _, info_dict = env.step(action_dict)
         if rendering:
             renderer.render_env(show=True, show_observations=True)
 
         for a, test_config in enumerate(test_configs):
             replay = test_config.replay[step]
-            _assert(a, rewards_dict[a], replay.reward, 'reward')
-
 
+            _assert(a, rewards_dict[a], replay.reward, 'reward')
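
The seeding pattern above configures the static agents and then resets, so that the
EnvAgent instances are rebuilt from agents_static with the requested start (as the
replay asserts imply). The same steps in isolation (values illustrative):

    static = env.agents_static[0]
    static.initial_position = (3, 9)
    static.direction = Grid4TransitionsEnum.EAST
    static.target = (3, 0)
    static.speed_data['speed'] = 0.5
    env.reset(False, False, True)  # rebuild agents from agents_static, activated
    assert env.agents[0].initial_position == (3, 9)
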
diff --git a/tests/tests_generators.py b/tests/tests_generators.py
index 4c925789e6560077d637e2a594c736df8850d00a..9bfe3a47d0871cbfbbabb4e35eac0bd6ecaaedf0 100644
--- a/tests/tests_generators.py
+++ b/tests/tests_generators.py
@@ -102,6 +102,7 @@ def test_rail_from_grid_transition_map():
                   schedule_generator=random_schedule_generator(),
                   number_of_agents=n_agents
                   )
+    env.reset(False, False, True)
     nr_rail_elements = np.count_nonzero(env.rail.grid)
 
     # Check if the number of non-empty rail cells is ok