diff --git a/examples/custom_observation_example_03_ObservePredictions.py b/examples/custom_observation_example_03_ObservePredictions.py
index 855d1f5dffbef29da26ca1dead933af846b863bd..f75cb74537f03dc6fb1aecbadff37a183432e55a 100644
--- a/examples/custom_observation_example_03_ObservePredictions.py
+++ b/examples/custom_observation_example_03_ObservePredictions.py
@@ -54,11 +54,9 @@ class ObservePredictions(ObservationBuilder):
                 pos_list.append(self.predictions[a][t][1:3])
             # We transform (x,y) coordinates to a single integer number for simpler comparison
             self.predicted_pos.update({t: coordinate_to_position(self.env.width, pos_list)})
-        observations = {}
 
-        # Collect all the different observation for all the agents
-        for h in handles:
-            observations[h] = self.get(h)
+        observations = super().get_many(handles)
+
         return observations
 
     def get(self, handle: int = 0) -> np.ndarray:
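The override above now delegates to the base class. For reference, a minimal sketch of the default loop in `ObservationBuilder.get_many` that it reuses (consistent with the code removed from `LocalObsForRailEnv` further down; not part of this diff):

```python
# Sketch of the ObservationBuilder.get_many default that subclasses now reuse.
def get_many(self, handles=None):
    observations = {}
    if handles is None:
        handles = []
    for h in handles:
        # delegate each handle to the subclass's get()
        observations[h] = self.get(h)
    return observations
```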
diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py
index 3cc21966162dd28d183493a97cd6072a34abb738..2302fff9cdeeaf162b7ab38b1e67f5052c76b3ef 100644
--- a/flatland/core/env_observation_builder.py
+++ b/flatland/core/env_observation_builder.py
@@ -24,7 +24,7 @@ class ObservationBuilder:
         self.env = None
 
     def set_env(self, env: Environment):
-        self.env = env
+        self.env: Environment = env
 
     def reset(self):
         """
diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py
index f659ec8436a941606b6d649e24d2481e5be9b66d..d8c05c2020a9524ae3e6eab232a3e320bc699187 100644
--- a/flatland/envs/agent_utils.py
+++ b/flatland/envs/agent_utils.py
@@ -1,5 +1,6 @@
+from enum import IntEnum
 from itertools import starmap
-from typing import Tuple
+from typing import Tuple, Optional
 
 import numpy as np
 from attr import attrs, attrib, Factory
@@ -7,6 +8,13 @@ from attr import attrs, attrib, Factory
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
 
 
+class RailAgentStatus(IntEnum):
+    READY_TO_DEPART = 0  # not in grid yet (position is None) -> prediction as if it were at initial position
+    ACTIVE = 1  # in grid (position is not None), not done -> prediction is remaining path
+    DONE = 2  # in grid (position is not None), but done -> prediction is stay at target forever
+    DONE_REMOVED = 3  # removed from grid (position is None) -> prediction is None
+
+
 @attrs
 class EnvAgentStatic(object):
     """ EnvAgentStatic - Stores initial position, direction and target.
@@ -14,7 +22,7 @@ class EnvAgentStatic(object):
         rather than where it is at the moment.
         The target should also be stored here.
     """
-    position = attrib(type=Tuple[int, int])
+    initial_position = attrib(type=Tuple[int, int])
     direction = attrib(type=Grid4TransitionsEnum)
     target = attrib(type=Tuple[int, int])
     moving = attrib(default=False, type=bool)
@@ -33,6 +41,9 @@ class EnvAgentStatic(object):
             lambda: dict({'malfunction': 0, 'malfunction_rate': 0, 'next_malfunction': 0, 'nr_malfunctions': 0,
                           'moving_before_malfunction': False})))
 
+    status = attrib(default=RailAgentStatus.READY_TO_DEPART, type=RailAgentStatus)
+    position = attrib(default=None, type=Optional[Tuple[int, int]])
+
     @classmethod
     def from_lists(cls, positions, directions, targets, speeds=None, malfunction_rates=None):
         """ Create a list of EnvAgentStatics from lists of positions, directions and targets
@@ -65,7 +76,7 @@ class EnvAgentStatic(object):
 
         # I can't find an expression which works on both tuples, lists and ndarrays
         # which converts them all to a list of native python ints.
-        lPos = self.position
+        lPos = self.initial_position
         if type(lPos) is np.ndarray:
             lPos = lPos.tolist()
 
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index c23d4345a03c761ad4c4ac1d936db817f8acc529..5cc6a8c11c2b7d5045d66a820f312dc0fd61a492 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -11,11 +11,20 @@ from flatland.core.env_observation_builder import ObservationBuilder
 from flatland.core.env_prediction_builder import PredictionBuilder
 from flatland.core.grid.grid4_utils import get_new_position
 from flatland.core.grid.grid_utils import coordinate_to_position
+from flatland.envs.agent_utils import RailAgentStatus, EnvAgent
 from flatland.utils.ordered_set import OrderedSet
 
 
 class TreeObsForRailEnv(ObservationBuilder):
+    """
+    TreeObsForRailEnv object.
 
+    This object returns observation vectors for agents in the RailEnv environment.
+    The information is local to each agent and exploits the graph structure of the rail
+    network to simplify the representation of the state of the environment for each agent.
+
+    For details about the features in the tree observation see the get() function.
+    """
     Node = collections.namedtuple('Node', 'dist_own_target_encountered '
                                           'dist_other_target_encountered '
                                           'dist_other_agent_encountered '
@@ -27,19 +36,10 @@ class TreeObsForRailEnv(ObservationBuilder):
                                           'num_agents_opposite_direction '
                                           'num_agents_malfunctioning '
                                           'speed_min_fractional '
+                                          'num_agents_ready_to_depart '
                                           'childs')
 
-    tree_explorted_actions_char = ['L', 'F', 'R', 'B']
-
-    """
-    TreeObsForRailEnv object.
-
-    This object returns observation vectors for agents in the RailEnv environment.
-    The information is local to each agent and exploits the graph structure of the rail
-    network to simplify the representation of the state of the environment for each agent.
-
-    For details about the features in the tree observation see the get() function.
-    """
+    tree_explored_actions_char = ['L', 'F', 'R', 'B']
 
     def __init__(self, max_depth: int, predictor: PredictionBuilder = None):
         super().__init__()
@@ -67,19 +67,21 @@ class TreeObsForRailEnv(ObservationBuilder):
             self.predicted_dir = {}
             self.predictions = self.predictor.get()
             if self.predictions:
-
+                # TODO: `range(len(self.predictions[0]))` is not safe if agent 0 has no prediction
                 for t in range(len(self.predictions[0])):
                     pos_list = []
                     dir_list = []
                     for a in handles:
+                        if self.predictions[a] is None:
+                            continue
                         pos_list.append(self.predictions[a][t][1:3])
                         dir_list.append(self.predictions[a][t][3])
                     self.predicted_pos.update({t: coordinate_to_position(self.env.width, pos_list)})
                     self.predicted_dir.update({t: dir_list})
                 self.max_prediction_depth = len(self.predicted_pos)
-        observations = {}
-        for h in handles:
-            observations[h] = self.get(h)
+
+        observations = super().get_many(handles)
+
         return observations
 
     def get(self, handle: int = 0) -> Node:
@@ -150,6 +152,8 @@ class TreeObsForRailEnv(ObservationBuilder):
             1 if no agent is observed
 
             min_fractional speed otherwise
+        #12:
+            number of agents ready to depart but not yet active
 
         Missing/padding nodes are filled in with -inf (truncated).
         Missing values in present node are filled in with +inf (truncated).
@@ -160,16 +164,41 @@ class TreeObsForRailEnv(ObservationBuilder):
         """
 
         # Update local lookup table for all agents' positions
-        self.location_has_agent = {tuple(agent.position): 1 for agent in self.env.agents}
-        self.location_has_agent_direction = {tuple(agent.position): agent.direction for agent in self.env.agents}
-        self.location_has_agent_speed = {tuple(agent.position): agent.speed_data['speed'] for agent in self.env.agents}
-        self.location_has_agent_malfunction = {tuple(agent.position): agent.malfunction_data['malfunction'] for agent in
-                                               self.env.agents}
+        # ignore other agents not in the grid (only status active and done)
+        self.location_has_agent = {tuple(agent.position): 1 for agent in self.env.agents if
+                                   agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE]}
+        self.location_has_agent_ready_to_depart = {}
+        for agent in self.env.agents:
+            if agent.status == RailAgentStatus.READY_TO_DEPART:
+                self.location_has_agent_ready_to_depart[tuple(agent.initial_position)] = \
+                    self.location_has_agent_ready_to_depart.get(tuple(agent.initial_position), 0) + 1
+        self.location_has_agent_direction = {
+            tuple(agent.position): agent.direction
+            for agent in self.env.agents if agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE]
+        }
+        self.location_has_agent_speed = {
+            tuple(agent.position): agent.speed_data['speed']
+            for agent in self.env.agents if agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE]
+        }
+        self.location_has_agent_malfunction = {
+            tuple(agent.position): agent.malfunction_data['malfunction']
+            for agent in self.env.agents if agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE]
+        }
 
         if handle > len(self.env.agents):
             print("ERROR: obs _get - handle ", handle, " len(agents)", len(self.env.agents))
         agent = self.env.agents[handle]  # TODO: handle being treated as index
-        possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
+
+        if agent.status == RailAgentStatus.READY_TO_DEPART:
+            _agent_initial_position = agent.initial_position
+        elif agent.status == RailAgentStatus.ACTIVE:
+            _agent_initial_position = agent.position
+        elif agent.status == RailAgentStatus.DONE:
+            _agent_initial_position = agent.target
+        else:
+            return None
+
+        possible_transitions = self.env.rail.get_transitions(*_agent_initial_position, agent.direction)
         num_transitions = np.count_nonzero(possible_transitions)
 
         # Here information about the agent itself is stored
@@ -178,11 +207,13 @@ class TreeObsForRailEnv(ObservationBuilder):
         root_node_observation = TreeObsForRailEnv.Node(dist_own_target_encountered=0, dist_other_target_encountered=0,
                                                        dist_other_agent_encountered=0, dist_potential_conflict=0,
                                                        dist_unusable_switch=0, dist_to_next_branch=0,
-                                                       dist_min_to_target=distance_map[(handle, *agent.position,
-                                                                                        agent.direction)],
+                                                       dist_min_to_target=distance_map[
+                                                           (handle, *_agent_initial_position,
+                                                            agent.direction)],
                                                        num_agents_same_direction=0, num_agents_opposite_direction=0,
                                                        num_agents_malfunctioning=agent.malfunction_data['malfunction'],
                                                        speed_min_fractional=agent.speed_data['speed'],
+                                                       num_agents_ready_to_depart=0,
                                                        childs={})
 
         visited = OrderedSet()
@@ -198,16 +229,16 @@ class TreeObsForRailEnv(ObservationBuilder):
         for i, branch_direction in enumerate([(orientation + i) % 4 for i in range(-1, 3)]):
 
             if possible_transitions[branch_direction]:
-                new_cell = get_new_position(agent.position, branch_direction)
+                new_cell = get_new_position(_agent_initial_position, branch_direction)
 
                 branch_observation, branch_visited = \
                     self._explore_branch(handle, new_cell, branch_direction, 1, 1)
-                root_node_observation.childs[self.tree_explorted_actions_char[i]] = branch_observation
+                root_node_observation.childs[self.tree_explored_actions_char[i]] = branch_observation
 
                 visited |= branch_visited
             else:
                 # add cells filled with infinity if no transition is possible
-                root_node_observation.childs[self.tree_explorted_actions_char[i]] = -np.inf
+                root_node_observation.childs[self.tree_explored_actions_char[i]] = -np.inf
         self.env.dev_obs_dict[handle] = visited
 
         return root_node_observation
@@ -245,6 +276,7 @@ class TreeObsForRailEnv(ObservationBuilder):
         malfunctioning_agent = 0
         min_fractional_speed = 1.
         num_steps = 1
+        other_agent_ready_to_depart_encountered = 0
         while exploring:
             # #############################
             # #############################
@@ -258,6 +290,8 @@ class TreeObsForRailEnv(ObservationBuilder):
                 if self.location_has_agent_malfunction[position] > malfunctioning_agent:
                     malfunctioning_agent = self.location_has_agent_malfunction[position]
 
+                other_agent_ready_to_depart_encountered += self.location_has_agent_ready_to_depart.get(position, 0)
+
                 if self.location_has_agent_direction[position] == direction:
                     # Accumulate the number of agents on branch with same direction
                     other_agent_same_direction += 1
@@ -296,7 +330,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                                 self._reverse_dir(
                                     self.predicted_dir[predicted_time][ca])] == 1 and tot_dist < potential_conflict:
                                 potential_conflict = tot_dist
-                            if self.env.dones[ca] and tot_dist < potential_conflict:
+                            if self.env.agents[ca].status == RailAgentStatus.DONE and tot_dist < potential_conflict:
                                 potential_conflict = tot_dist
 
                     # Look for conflicting paths at distance num_step-1
@@ -307,7 +341,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                                 and cell_transitions[self._reverse_dir(self.predicted_dir[pre_step][ca])] == 1 \
                                 and tot_dist < potential_conflict:  # noqa: E125
                                 potential_conflict = tot_dist
-                            if self.env.dones[ca] and tot_dist < potential_conflict:
+                            if self.env.agents[ca].status == RailAgentStatus.DONE and tot_dist < potential_conflict:
                                 potential_conflict = tot_dist
 
                     # Look for conflicting paths at distance num_step+1
@@ -318,7 +352,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                                 self.predicted_dir[post_step][ca])] == 1 \
                                 and tot_dist < potential_conflict:  # noqa: E125
                                 potential_conflict = tot_dist
-                            if self.env.dones[ca] and tot_dist < potential_conflict:
+                            if self.env.agents[ca].status == RailAgentStatus.DONE and tot_dist < potential_conflict:
                                 potential_conflict = tot_dist
 
             if position in self.location_has_target and position != agent.target:
@@ -406,6 +440,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                                       num_agents_opposite_direction=other_agent_opposite_direction,
                                       num_agents_malfunctioning=malfunctioning_agent,
                                       speed_min_fractional=min_fractional_speed,
+                                      num_agents_ready_to_depart=other_agent_ready_to_depart_encountered,
                                       childs={})
 
         # #############################
@@ -425,7 +460,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                                                                           (branch_direction + 2) % 4,
                                                                           tot_dist + 1,
                                                                           depth + 1)
-                node.childs[self.tree_explorted_actions_char[i]] = branch_observation
+                node.childs[self.tree_explored_actions_char[i]] = branch_observation
                 if len(branch_visited) != 0:
                     visited |= branch_visited
             elif last_is_switch and possible_transitions[branch_direction]:
@@ -435,12 +470,12 @@ class TreeObsForRailEnv(ObservationBuilder):
                                                                           branch_direction,
                                                                           tot_dist + 1,
                                                                           depth + 1)
-                node.childs[self.tree_explorted_actions_char[i]] = branch_observation
+                node.childs[self.tree_explored_actions_char[i]] = branch_observation
                 if len(branch_visited) != 0:
                     visited |= branch_visited
             else:
                 # no exploring possible, add just cells with infinity
-                node.childs[self.tree_explorted_actions_char[i]] = -np.inf
+                node.childs[self.tree_explored_actions_char[i]] = -np.inf
 
         if depth == self.max_depth:
             node.childs.clear()
@@ -451,7 +486,7 @@ class TreeObsForRailEnv(ObservationBuilder):
         Utility function to print tree observations returned by this object.
         """
         self.print_node_features(tree, "root", "")
-        for direction in self.tree_explorted_actions_char:
+        for direction in self.tree_explored_actions_char:
             self.print_subtree(tree.childs[direction], direction, "\t")
 
     @staticmethod
@@ -460,7 +495,8 @@ class TreeObsForRailEnv(ObservationBuilder):
               node.dist_other_target_encountered, ", ", node.dist_other_agent_encountered, ", ",
               node.dist_potential_conflict, ", ", node.dist_unusable_switch, ", ", node.dist_to_next_branch, ", ",
               node.dist_min_to_target, ", ", node.num_agents_same_direction, ", ", node.num_agents_opposite_direction,
-              ", ", node.num_agents_malfunctioning, ", ", node.speed_min_fractional)
+              ", ", node.num_agents_malfunctioning, ", ", node.speed_min_fractional, ", ",
+              node.num_agents_ready_to_depart)
 
     def print_subtree(self, node, label, indent):
         if node == -np.inf or not node:
@@ -472,7 +508,7 @@ class TreeObsForRailEnv(ObservationBuilder):
         if not node.childs:
             return
 
-        for direction in self.tree_explorted_actions_char:
+        for direction in self.tree_explored_actions_char:
             self.print_subtree(node.childs[direction], direction, indent + "\t")
 
     def set_env(self, env: Environment):
@@ -497,6 +533,7 @@ class GlobalObsForRailEnv(ObservationBuilder):
             - second channel containing the other agents' positions and directions
             - third channel containing agent/other agent malfunctions
             - fourth channel containing agent/other agent fractional speeds
+            - fifth channel containing number of other agents ready to depart
 
         - Two 2D arrays (map_height, map_width, 2) containing respectively the position of the given agent\
          target and the positions of the other agents targets.
@@ -518,18 +555,33 @@ class GlobalObsForRailEnv(ObservationBuilder):
 
     def get(self, handle: int = 0) -> (np.ndarray, np.ndarray, np.ndarray):
 
+        agent = self.env.agents[handle]
+        if agent.status == RailAgentStatus.READY_TO_DEPART:
+            _agent_initial_position = agent.initial_position
+        elif agent.status == RailAgentStatus.ACTIVE:
+            _agent_initial_position = agent.position
+        elif agent.status == RailAgentStatus.DONE:
+            _agent_initial_position = agent.target
+        else:
+            return None
+
         obs_targets = np.zeros((self.env.height, self.env.width, 2))
-        obs_agents_state = np.zeros((self.env.height, self.env.width, 4)) - 1
+        obs_agents_state = np.zeros((self.env.height, self.env.width, 5)) - 1
 
-        agent = self.env.agents[handle]
-        obs_agents_state[agent.position][0] = agent.direction
+        obs_agents_state[_agent_initial_position][0] = agent.direction
         obs_targets[agent.target][0] = 1
 
         for i in range(len(self.env.agents)):
-            other_agent = self.env.agents[i]
+            other_agent: EnvAgent = self.env.agents[i]
+            # agents ready to depart still have position None: count them for the fifth
+            # channel from initial_position, then skip agents that are not in the grid
+            if i != handle and other_agent.status == RailAgentStatus.READY_TO_DEPART:
+                obs_agents_state[other_agent.initial_position][4] += 1
+            if other_agent.position is None:
+                continue
             if i != handle:
                 obs_agents_state[other_agent.position][1] = other_agent.direction
                 obs_targets[other_agent.target][1] = 1
             obs_agents_state[other_agent.position][2] = other_agent.malfunction_data['malfunction']
             obs_agents_state[other_agent.position][3] = other_agent.speed_data['speed']
 
@@ -621,18 +673,14 @@ class LocalObsForRailEnv(ObservationBuilder):
         direction = np.identity(4)[agent.direction]
         return local_rail_obs, obs_map_state, obs_other_agents_state, direction
 
-    def get_many(self, handles: Optional[List[int]] = None) -> Dict[int, Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]]:
+    def get_many(self, handles: Optional[List[int]] = None) -> \
+            Dict[int, Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]]:
         """
         Called whenever an observation has to be computed for the `env` environment, for each agent with handle
         in the `handles` list.
         """
 
-        observations = {}
-        if handles is None:
-            handles = []
-        for h in handles:
-            observations[h] = self.get(h)
-        return observations
+        return super().get_many(handles)
 
     def field_of_view(self, position, direction, state=None):
         # Compute the local field of view for an agent in the environment
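A minimal sketch of reading back the five channels of `obs_agents_state` built in `GlobalObsForRailEnv.get` above; `summarize_agents_state` is a hypothetical helper, and cells without information carry -1:

```python
import numpy as np

def summarize_agents_state(obs_agents_state: np.ndarray) -> dict:
    # channel 0: direction of the observing agent, 1: directions of the others,
    # 2: remaining malfunction steps, 3: fractional speeds,
    # 4: ready-to-depart counter (starts at -1, incremented once per waiting agent)
    other_directions = obs_agents_state[..., 1]
    malfunctions = obs_agents_state[..., 2]
    speeds = obs_agents_state[..., 3]
    ready_to_depart = obs_agents_state[..., 4]
    return {
        'n_other_agents': int(np.count_nonzero(other_directions >= 0)),
        'n_malfunctioning': int(np.count_nonzero(malfunctions > 0)),
        'min_speed': float(speeds[speeds >= 0].min()) if (speeds >= 0).any() else None,
        'cells_with_waiting_agents': int(np.count_nonzero(ready_to_depart > -1)),
    }
```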
diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py
index 76095a2a2e1d9532951600118c6a777612641101..29b6947c28fa054cd1f4c44a204c740e4f536181 100644
--- a/flatland/envs/predictions.py
+++ b/flatland/envs/predictions.py
@@ -5,6 +5,7 @@ Collection of environment-specific PredictionBuilder.
 import numpy as np
 
 from flatland.core.env_prediction_builder import PredictionBuilder
+from flatland.envs.agent_utils import RailAgentStatus
 from flatland.envs.distance_map import DistanceMap
 from flatland.envs.rail_env import RailEnvActions
 from flatland.envs.rail_env_shortest_paths import get_shortest_paths
@@ -47,6 +48,9 @@ class DummyPredictorForRailEnv(PredictionBuilder):
         prediction_dict = {}
 
         for agent in agents:
+            if agent.status != RailAgentStatus.ACTIVE:
+                # TODO make this generic: handle non-ACTIVE statuses as ShortestPathPredictorForRailEnv does
+                continue
             action_priorities = [RailEnvActions.MOVE_FORWARD, RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT]
             _agent_initial_position = agent.position
             _agent_initial_direction = agent.direction
@@ -122,7 +126,17 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
 
         prediction_dict = {}
         for agent in agents:
-            _agent_initial_position = agent.position
+
+            if agent.status == RailAgentStatus.READY_TO_DEPART:
+                _agent_initial_position = agent.initial_position
+            elif agent.status == RailAgentStatus.ACTIVE:
+                _agent_initial_position = agent.position
+            elif agent.status == RailAgentStatus.DONE:
+                _agent_initial_position = agent.target
+            else:
+                prediction_dict[agent.handle] = None
+                continue
+
             _agent_initial_direction = agent.direction
             agent_speed = agent.speed_data["speed"]
             times_per_cell = int(np.reciprocal(agent_speed))
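The status-to-start-cell mapping above is repeated in `TreeObsForRailEnv`, `GlobalObsForRailEnv`, `ShortestPathPredictorForRailEnv` and `get_shortest_paths`. A minimal sketch of a shared helper (hypothetical, not part of this diff):

```python
from typing import Optional, Tuple

from flatland.envs.agent_utils import RailAgentStatus

def position_for_prediction(agent) -> Optional[Tuple[int, int]]:
    if agent.status == RailAgentStatus.READY_TO_DEPART:
        return agent.initial_position  # predict as if already on the spawn cell
    if agent.status == RailAgentStatus.ACTIVE:
        return agent.position          # predict the remaining path
    if agent.status == RailAgentStatus.DONE:
        return agent.target            # done agents park on their target
    return None                        # DONE_REMOVED: nothing to predict
```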
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 4fb1c55e5a179e99e6c4985fff3c5596ba1f0a66..1b5ca23f36e3571aa75d9fd69431c119054ab05c 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -15,7 +15,7 @@ from flatland.core.env_observation_builder import ObservationBuilder
 from flatland.core.grid.grid4 import Grid4TransitionsEnum, Grid4Transitions
 from flatland.core.grid.grid4_utils import get_new_position
 from flatland.core.transition_map import GridTransitionMap
-from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent
+from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent, RailAgentStatus
 from flatland.envs.distance_map import DistanceMap
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.rail_generators import random_rail_generator, RailGenerator
@@ -224,12 +224,18 @@ class RailEnv(Environment):
         self.agents_static.append(agent_static)
         return len(self.agents_static) - 1
 
+    def set_agent_active(self, handle: int):
+        agent = self.agents[handle]
+        if agent.status == RailAgentStatus.READY_TO_DEPART and self.cell_free(agent.initial_position):
+            agent.status = RailAgentStatus.ACTIVE
+            agent.position = agent.initial_position
+
     def restart_agents(self):
         """ Reset the agents to their starting positions defined in agents_static
         """
         self.agents = EnvAgent.list_from_static(self.agents_static)
 
-    def reset(self, regen_rail=True, replace_agents=True):
+    def reset(self, regen_rail=True, replace_agents=True, activate_agents=False):
         """ if regen_rail then regenerate the rails.
             if replace_agents then regenerate the agents static.
             Relies on the rail_generator returning agent_static lists (pos, dir, target)
@@ -265,8 +271,13 @@ class RailEnv(Environment):
                 *self.schedule_generator(self.rail, self.get_num_agents(), agents_hints))
         self.restart_agents()
 
-        for i_agent in range(self.get_num_agents()):
-            agent = self.agents[i_agent]
+        if activate_agents:
+            for i_agent in range(self.get_num_agents()):
+                self.set_agent_active(i_agent)
+
+        for i_agent, agent in enumerate(self.agents):
+            if agent.status != RailAgentStatus.ACTIVE:
+                continue
 
             # A proportion of agents in the environment will receive a positive malfunction rate
             if np.random.random() < self.proportion_malfunctioning_trains:
@@ -354,7 +365,8 @@ class RailEnv(Environment):
             info_dict = {
                 'action_required': {i: False for i in range(self.get_num_agents())},
                 'malfunction': {i: 0 for i in range(self.get_num_agents())},
-                'speed': {i: 0 for i in range(self.get_num_agents())}
+                'speed': {i: 0 for i in range(self.get_num_agents())},
+                'status': {i: agent.status for i, agent in enumerate(self.agents)}
             }
             return self._get_observations(), self.rewards_dict, self.dones, info_dict
 
@@ -369,21 +381,19 @@ class RailEnv(Environment):
 
         if (self._max_episode_steps is not None) and (self._elapsed_steps >= self._max_episode_steps):
             self.dones["__all__"] = True
-            for k in self.dones.keys():
-                self.dones[k] = True
-
-        action_required_agents = {
-            i: self.agents[i].speed_data['position_fraction'] == 0.0 for i in range(self.get_num_agents())
-        }
-        malfunction_agents = {
-            i: self.agents[i].malfunction_data['malfunction'] for i in range(self.get_num_agents())
-        }
-        speed_agents = {i: self.agents[i].speed_data['speed'] for i in range(self.get_num_agents())}
+            for i in range(self.get_num_agents()):
+                self.agents[i].status = RailAgentStatus.DONE
+                self.dones[i] = True
 
         info_dict = {
-            'action_required': action_required_agents,
-            'malfunction': malfunction_agents,
-            'speed': speed_agents
+            'action_required': {
+                i: (agent.status == RailAgentStatus.ACTIVE and agent.speed_data['position_fraction'] == 0.0)
+                for i, agent in enumerate(self.agents)},
+            'malfunction': {
+                i: self.agents[i].malfunction_data['malfunction'] for i in range(self.get_num_agents())
+            },
+            'speed': {i: self.agents[i].speed_data['speed'] for i in range(self.get_num_agents())},
+            'status': {i: agent.status for i, agent in enumerate(self.agents)}
         }
 
         return self._get_observations(), self.rewards_dict, self.dones, info_dict
@@ -401,10 +411,19 @@ class RailEnv(Environment):
         action_dict_ : Dict[int,RailEnvActions]
 
         """
-        if self.dones[i_agent]:  # this agent has already completed...
+        agent = self.agents[i_agent]
+        if agent.status == RailAgentStatus.DONE:  # this agent has already completed...
             return
 
-        agent = self.agents[i_agent]
+        # an agent becomes active through a MOVE_* action, provided its initial cell is free
+        if agent.status == RailAgentStatus.READY_TO_DEPART:
+            if action in [RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT,
+                          RailEnvActions.MOVE_FORWARD] and self.cell_free(agent.initial_position):
+                agent.status = RailAgentStatus.ACTIVE
+                agent.position = agent.initial_position
+            else:
+                return
+
         agent.old_direction = agent.direction
         agent.old_position = agent.position
 
@@ -497,6 +516,7 @@ class RailEnv(Environment):
 
             # has the agent reached its target?
             if np.equal(agent.position, agent.target).all():
+                agent.status = RailAgentStatus.DONE
                 self.dones[i_agent] = True
                 agent.moving = False
             else:
@@ -543,9 +563,15 @@ class RailEnv(Environment):
 
         # Check the new position is not the same as any of the existing agent positions
         # (including itself, for simplicity, since it is moving)
-        cell_free = not np.any(np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1))
+        cell_free = self.cell_free(new_position)
         return cell_free, new_cell_valid, new_direction, new_position, transition_valid
 
+    def cell_free(self, position):
+        """Return True if no agent on the grid occupies the given position."""
+        agent_positions = [agent.position for agent in self.agents if agent.position is not None]
+        ret = len(agent_positions) == 0 or not np.any(np.equal(position, agent_positions).all(1))
+        return ret
+
     def check_action(self, agent: EnvAgent, action: RailEnvActions):
         """
 
@@ -591,7 +617,7 @@ class RailEnv(Environment):
         return self.obs_dict
 
     def get_valid_directions_on_grid(self, row: int, col: int) -> List[int]:
-        return Grid4Transitions.get_entry_directions(self.rail.get_full_transitions(row,col))
+        return Grid4Transitions.get_entry_directions(self.rail.get_full_transitions(row, col))
 
     def get_full_state_msg(self):
         grid_data = self.rail.grid.tolist()
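A minimal usage sketch of the new activation flow, assuming `env` is a `RailEnv` configured as in the test below; whether `reset` returns observations is unchanged by this diff, so its return value is ignored here:

```python
from flatland.envs.rail_env import RailEnvActions

env.reset(activate_agents=False)  # all agents start READY_TO_DEPART with position None

env.set_agent_active(0)           # agent 0 enters the grid iff its initial cell is free

# the remaining agents enter implicitly through their first MOVE_* action
obs, rewards, dones, info = env.step(
    {h: RailEnvActions.MOVE_FORWARD for h in range(env.get_num_agents())})
print(info['status'])             # handle -> RailAgentStatus, as added to info_dict above
```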
diff --git a/flatland/envs/rail_env_shortest_paths.py b/flatland/envs/rail_env_shortest_paths.py
index 793601d4d18ac38b729d15883089d5acbfc41ed3..7944a49daf388403c8a226dbf866ac810d1286a3 100644
--- a/flatland/envs/rail_env_shortest_paths.py
+++ b/flatland/envs/rail_env_shortest_paths.py
@@ -7,6 +7,7 @@ import numpy as np
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
 from flatland.core.grid.grid4_utils import get_new_position
 from flatland.core.transition_map import GridTransitionMap
+from flatland.envs.agent_utils import RailAgentStatus
 from flatland.envs.distance_map import DistanceMap
 from flatland.envs.rail_env import RailEnvNextAction, RailEnvActions
 from flatland.utils.ordered_set import OrderedSet
@@ -92,7 +93,15 @@ def get_shortest_paths(distance_map: DistanceMap, max_depth: Optional[int] = Non
     shortest_paths = dict()
 
     def _shortest_path_for_agent(agent):
-        position = agent.position
+        if agent.status == RailAgentStatus.READY_TO_DEPART:
+            position = agent.initial_position
+        elif agent.status == RailAgentStatus.ACTIVE:
+            position = agent.position
+        elif agent.status == RailAgentStatus.DONE:
+            position = agent.target
+        else:
+            shortest_paths[agent.handle] = None
+            return
         direction = agent.direction
         shortest_paths[agent.handle] = []
         distance = math.inf
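A minimal sketch of the effect on `get_shortest_paths`, assuming `env` exposes its `DistanceMap` as `env.distance_map`:

```python
from flatland.envs.rail_env_shortest_paths import get_shortest_paths

shortest_paths = get_shortest_paths(env.distance_map)
for handle, path in shortest_paths.items():
    # entries are None for agents whose status yields no start cell,
    # mirroring the early return added above
    if path is None:
        print("agent {}: no path prediction".format(handle))
```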
diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py
index 7f42feeacd0ef50b56846540a9b2af9d147eafb0..5442a0af191ce31e415964224557cb18d05407f1 100644
--- a/flatland/envs/schedule_generators.py
+++ b/flatland/envs/schedule_generators.py
@@ -235,7 +235,7 @@ def schedule_from_file(filename, load_from_package=None) -> ScheduleGenerator:
             agents_static = [EnvAgentStatic(d[0], d[1], d[2], d[3]) for d in data["agents_static"]]
 
         # setup with loaded data
-        agents_position = [a.position for a in agents_static]
+        agents_position = [a.initial_position for a in agents_static]
         agents_direction = [a.direction for a in agents_static]
         agents_target = [a.target for a in agents_static]
         if len(data['agents_static'][0]) > 5:
diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py
index 99958bf38449ef8eb58c519990f1975106409c4e..e7b1e72679937bc5b3093cebfa58bd2e9894ba94 100644
--- a/flatland/utils/rendertools.py
+++ b/flatland/utils/rendertools.py
@@ -146,6 +146,9 @@ class RenderTool(object):
         Plot a simple agent.
         Assumes a working graphics layer context (cf a MPL figure).
         """
+        if position_row_col is None:
+            return
+
         rt = self.__class__
 
         direction_row_col = rt.transitions_row_col[direction]  # agent direction in RC
@@ -535,7 +538,7 @@ class RenderTool(object):
 
         for agent_idx, agent in enumerate(self.env.agents):
 
-            if agent is None:
+            if agent is None or agent.position is None:
                 continue
 
             if self.agent_render_variant == AgentRenderVariant.BOX_ONLY:
diff --git a/tests/test_flatland_rail_agent_status.py b/tests/test_flatland_rail_agent_status.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4396c101d1879bf8f5f74f9a79f9a8541d45e31
--- /dev/null
+++ b/tests/test_flatland_rail_agent_status.py
@@ -0,0 +1,124 @@
+import numpy as np
+
+from flatland.core.grid.grid4 import Grid4TransitionsEnum
+from flatland.envs.agent_utils import RailAgentStatus
+from flatland.envs.observations import TreeObsForRailEnv
+from flatland.envs.predictions import ShortestPathPredictorForRailEnv
+from flatland.envs.rail_env import RailEnv, RailEnvActions
+from flatland.envs.rail_generators import rail_from_grid_transition_map
+from flatland.envs.schedule_generators import random_schedule_generator
+from flatland.utils.simple_rail import make_simple_rail
+from test_utils import ReplayConfig, Replay, run_replay_config, set_penalties_for_replay
+
+np.random.seed(1)
+
+
+def test_initial_status():
+    """Test that agent lifecycle works correctly ready-to-depart -> active -> done."""
+    rail, rail_map = make_simple_rail()
+    env = RailEnv(width=rail_map.shape[1],
+                  height=rail_map.shape[0],
+                  rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(),
+                  number_of_agents=1,
+                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
+                  )
+
+    set_penalties_for_replay(env)
+    test_config = ReplayConfig(
+        replay=[
+            Replay(
+                position=None,  # not entered grid yet
+                direction=Grid4TransitionsEnum.EAST,
+                status=RailAgentStatus.READY_TO_DEPART,
+                action=RailEnvActions.DO_NOTHING,
+                reward=0,
+
+            ),
+            Replay(
+                position=None,  # not entered grid yet before step
+                direction=Grid4TransitionsEnum.EAST,
+                status=RailAgentStatus.READY_TO_DEPART,
+                action=RailEnvActions.MOVE_LEFT,
+                reward=env.start_penalty + env.step_penalty * 0.5,  # auto-correction left to forward without penalty!
+            ),
+            Replay(
+                position=(3, 9),
+                direction=Grid4TransitionsEnum.EAST,
+                status=RailAgentStatus.ACTIVE,
+                action=None,
+                reward=env.step_penalty * 0.5,  # running at speed 0.5
+            ),
+            Replay(
+                position=(3, 8),
+                direction=Grid4TransitionsEnum.WEST,
+                status=RailAgentStatus.ACTIVE,
+                action=RailEnvActions.MOVE_FORWARD,
+                reward=env.step_penalty * 0.5,  # running at speed 0.5
+            ),
+            Replay(
+                position=(3, 8),
+                direction=Grid4TransitionsEnum.WEST,
+                status=RailAgentStatus.ACTIVE,
+                action=None,
+                reward=env.step_penalty * 0.5,  # running at speed 0.5
+            ),
+            Replay(
+                position=(3, 7),
+                direction=Grid4TransitionsEnum.WEST,
+                status=RailAgentStatus.ACTIVE,
+                action=RailEnvActions.MOVE_FORWARD,
+                reward=env.step_penalty * 0.5,  # running at speed 0.5
+
+            ),
+            Replay(
+                position=(3, 7),
+                direction=Grid4TransitionsEnum.WEST,
+                action=None,
+                reward=env.step_penalty * 0.5,  # running at speed 0.5
+                status=RailAgentStatus.ACTIVE
+            ),
+            Replay(
+                position=(3, 6),
+                direction=Grid4TransitionsEnum.WEST,
+                action=RailEnvActions.MOVE_RIGHT,
+                reward=env.step_penalty * 0.5,  # wrong action is corrected to forward without penalty!
+                status=RailAgentStatus.ACTIVE
+            ),
+            Replay(
+                position=(3, 6),
+                direction=Grid4TransitionsEnum.WEST,
+                action=None,
+                reward=env.global_reward,  # done
+                status=RailAgentStatus.ACTIVE
+            ),
+            Replay(
+                position=(3, 5),
+                direction=Grid4TransitionsEnum.WEST,
+                action=None,
+                reward=env.global_reward,  # already done
+                status=RailAgentStatus.DONE
+            ),
+            Replay(
+                position=(3, 5),
+                direction=Grid4TransitionsEnum.WEST,
+                action=None,
+                reward=env.global_reward,  # already done
+                status=RailAgentStatus.DONE
+            ),
+            Replay(
+                position=(3, 5),
+                direction=Grid4TransitionsEnum.WEST,
+                action=None,
+                reward=env.global_reward,  # already done
+                status=RailAgentStatus.DONE
+            )
+
+        ],
+        initial_position=(3, 9),  # east dead-end
+        initial_direction=Grid4TransitionsEnum.EAST,
+        target=(3, 5),
+        speed=0.5
+    )
+
+    run_replay_config(env, [test_config], activate_agents=False)
diff --git a/tests/test_flatland_envs_city_generator.py b/tests/test_flatland_envs_city_generator.py
index fe39d785712c88f03af09c7a6d1dac715a585db3..1d386df225e7d025116752e26d5c55cf2a292214 100644
--- a/tests/test_flatland_envs_city_generator.py
+++ b/tests/test_flatland_envs_city_generator.py
@@ -28,274 +28,274 @@ def test_city_generator():
 
     expected_grid_map = np.zeros((50, 50), dtype=env.rail.transitions.get_type())
 
-    expected_grid_map[8][16]=4
-    expected_grid_map[8][17]=5633
-    expected_grid_map[8][18]=1025
-    expected_grid_map[8][19]=1025
-    expected_grid_map[8][20]=17411
-    expected_grid_map[8][21]=1025
-    expected_grid_map[8][22]=1025
-    expected_grid_map[8][23]=1025
-    expected_grid_map[8][24]=1025
-    expected_grid_map[8][25]=1025
-    expected_grid_map[8][26]=4608
-    expected_grid_map[9][16]=16386
-    expected_grid_map[9][17]=50211
-    expected_grid_map[9][18]=1025
-    expected_grid_map[9][19]=1025
-    expected_grid_map[9][20]=3089
-    expected_grid_map[9][21]=1025
-    expected_grid_map[9][22]=256
-    expected_grid_map[9][26]=32800
-    expected_grid_map[10][6]=16386
-    expected_grid_map[10][7]=1025
-    expected_grid_map[10][8]=1025
-    expected_grid_map[10][9]=1025
-    expected_grid_map[10][10]=1025
-    expected_grid_map[10][11]=1025
-    expected_grid_map[10][12]=1025
-    expected_grid_map[10][13]=1025
-    expected_grid_map[10][14]=1025
-    expected_grid_map[10][15]=1025
-    expected_grid_map[10][16]=33825
-    expected_grid_map[10][17]=34864
-    expected_grid_map[10][26]=32800
-    expected_grid_map[11][6]=32800
-    expected_grid_map[11][16]=32800
-    expected_grid_map[11][17]=32800
-    expected_grid_map[11][26]=32800
-    expected_grid_map[12][6]=32800
-    expected_grid_map[12][16]=32800
-    expected_grid_map[12][17]=32800
-    expected_grid_map[12][26]=32800
-    expected_grid_map[13][6]=32800
-    expected_grid_map[13][16]=32800
-    expected_grid_map[13][17]=32800
-    expected_grid_map[13][26]=32800
-    expected_grid_map[14][6]=32800
-    expected_grid_map[14][16]=32800
-    expected_grid_map[14][17]=32800
-    expected_grid_map[14][26]=32800
-    expected_grid_map[15][6]=32800
-    expected_grid_map[15][16]=32800
-    expected_grid_map[15][17]=32800
-    expected_grid_map[15][26]=32800
-    expected_grid_map[16][6]=32800
-    expected_grid_map[16][16]=32800
-    expected_grid_map[16][17]=32800
-    expected_grid_map[16][26]=32800
-    expected_grid_map[17][6]=32800
-    expected_grid_map[17][16]=72
-    expected_grid_map[17][17]=1097
-    expected_grid_map[17][18]=1025
-    expected_grid_map[17][19]=1025
-    expected_grid_map[17][20]=1025
-    expected_grid_map[17][21]=1025
-    expected_grid_map[17][22]=1025
-    expected_grid_map[17][23]=1025
-    expected_grid_map[17][24]=1025
-    expected_grid_map[17][25]=1025
-    expected_grid_map[17][26]=33825
-    expected_grid_map[17][27]=4608
-    expected_grid_map[18][6]=32800
-    expected_grid_map[18][26]=72
-    expected_grid_map[18][27]=52275
-    expected_grid_map[18][28]=5633
-    expected_grid_map[18][29]=17411
-    expected_grid_map[18][30]=1025
-    expected_grid_map[18][31]=1025
-    expected_grid_map[18][32]=256
-    expected_grid_map[19][6]=32800
-    expected_grid_map[19][25]=16386
-    expected_grid_map[19][26]=1025
-    expected_grid_map[19][27]=2136
-    expected_grid_map[19][28]=1097
-    expected_grid_map[19][29]=1097
-    expected_grid_map[19][30]=5633
-    expected_grid_map[19][31]=1025
-    expected_grid_map[19][32]=256
-    expected_grid_map[20][6]=32800
-    expected_grid_map[20][25]=32800
-    expected_grid_map[20][26]=16386
-    expected_grid_map[20][27]=17411
-    expected_grid_map[20][28]=1025
-    expected_grid_map[20][29]=1025
-    expected_grid_map[20][30]=3089
-    expected_grid_map[20][31]=1025
-    expected_grid_map[20][32]=256
-    expected_grid_map[21][6]=32800
-    expected_grid_map[21][16]=16386
-    expected_grid_map[21][17]=1025
-    expected_grid_map[21][18]=1025
-    expected_grid_map[21][19]=1025
-    expected_grid_map[21][20]=1025
-    expected_grid_map[21][21]=1025
-    expected_grid_map[21][22]=1025
-    expected_grid_map[21][23]=1025
-    expected_grid_map[21][24]=1025
-    expected_grid_map[21][25]=33825
-    expected_grid_map[21][26]=33825
-    expected_grid_map[21][27]=2064
-    expected_grid_map[22][6]=32800
-    expected_grid_map[22][16]=32800
-    expected_grid_map[22][25]=32800
-    expected_grid_map[22][26]=32800
-    expected_grid_map[23][6]=32800
-    expected_grid_map[23][16]=32800
-    expected_grid_map[23][25]=32800
-    expected_grid_map[23][26]=32800
-    expected_grid_map[24][6]=32800
-    expected_grid_map[24][16]=32800
-    expected_grid_map[24][25]=32800
-    expected_grid_map[24][26]=32800
-    expected_grid_map[25][6]=32800
-    expected_grid_map[25][16]=32800
-    expected_grid_map[25][25]=32800
-    expected_grid_map[25][26]=32800
-    expected_grid_map[26][6]=32800
-    expected_grid_map[26][16]=32800
-    expected_grid_map[26][25]=32800
-    expected_grid_map[26][26]=32800
-    expected_grid_map[27][6]=72
-    expected_grid_map[27][7]=1025
-    expected_grid_map[27][8]=1025
-    expected_grid_map[27][9]=17411
-    expected_grid_map[27][10]=1025
-    expected_grid_map[27][11]=1025
-    expected_grid_map[27][12]=1025
-    expected_grid_map[27][13]=1025
-    expected_grid_map[27][14]=1025
-    expected_grid_map[27][15]=4608
-    expected_grid_map[27][16]=72
-    expected_grid_map[27][17]=17411
-    expected_grid_map[27][18]=5633
-    expected_grid_map[27][19]=1025
-    expected_grid_map[27][20]=1025
-    expected_grid_map[27][21]=1025
-    expected_grid_map[27][22]=1025
-    expected_grid_map[27][23]=1025
-    expected_grid_map[27][24]=1025
-    expected_grid_map[27][25]=33825
-    expected_grid_map[27][26]=2064
-    expected_grid_map[28][6]=4
-    expected_grid_map[28][7]=1025
-    expected_grid_map[28][8]=1025
-    expected_grid_map[28][9]=3089
-    expected_grid_map[28][10]=1025
-    expected_grid_map[28][11]=1025
-    expected_grid_map[28][12]=1025
-    expected_grid_map[28][13]=1025
-    expected_grid_map[28][14]=4608
-    expected_grid_map[28][15]=72
-    expected_grid_map[28][16]=1025
-    expected_grid_map[28][17]=2136
-    expected_grid_map[28][18]=1097
-    expected_grid_map[28][19]=5633
-    expected_grid_map[28][20]=5633
-    expected_grid_map[28][21]=1025
-    expected_grid_map[28][22]=256
-    expected_grid_map[28][25]=32800
-    expected_grid_map[29][6]=4
-    expected_grid_map[29][7]=5633
-    expected_grid_map[29][8]=20994
-    expected_grid_map[29][9]=5633
-    expected_grid_map[29][10]=1025
-    expected_grid_map[29][11]=1025
-    expected_grid_map[29][12]=1025
-    expected_grid_map[29][13]=1025
-    expected_grid_map[29][14]=1097
-    expected_grid_map[29][15]=5633
-    expected_grid_map[29][16]=1025
-    expected_grid_map[29][17]=17411
-    expected_grid_map[29][18]=5633
-    expected_grid_map[29][19]=1097
-    expected_grid_map[29][20]=3089
-    expected_grid_map[29][21]=20994
-    expected_grid_map[29][22]=1025
-    expected_grid_map[29][23]=1025
-    expected_grid_map[29][24]=1025
-    expected_grid_map[29][25]=2064
-    expected_grid_map[30][6]=16386
-    expected_grid_map[30][7]=38505
-    expected_grid_map[30][8]=3089
-    expected_grid_map[30][9]=1097
-    expected_grid_map[30][10]=1025
-    expected_grid_map[30][11]=1025
-    expected_grid_map[30][12]=256
-    expected_grid_map[30][15]=32800
-    expected_grid_map[30][16]=16386
-    expected_grid_map[30][17]=52275
-    expected_grid_map[30][18]=1097
-    expected_grid_map[30][19]=1025
-    expected_grid_map[30][20]=1025
-    expected_grid_map[30][21]=3089
-    expected_grid_map[30][22]=256
-    expected_grid_map[31][6]=32800
-    expected_grid_map[31][7]=32800
-    expected_grid_map[31][15]=72
-    expected_grid_map[31][16]=37408
-    expected_grid_map[31][17]=32800
-    expected_grid_map[32][6]=32800
-    expected_grid_map[32][7]=32800
-    expected_grid_map[32][16]=32800
-    expected_grid_map[32][17]=32800
-    expected_grid_map[33][6]=32800
-    expected_grid_map[33][7]=32800
-    expected_grid_map[33][16]=32800
-    expected_grid_map[33][17]=32800
-    expected_grid_map[34][6]=32800
-    expected_grid_map[34][7]=32800
-    expected_grid_map[34][16]=32800
-    expected_grid_map[34][17]=32800
-    expected_grid_map[35][6]=32800
-    expected_grid_map[35][7]=32800
-    expected_grid_map[35][16]=32800
-    expected_grid_map[35][17]=32800
-    expected_grid_map[36][6]=32800
-    expected_grid_map[36][7]=32800
-    expected_grid_map[36][16]=32800
-    expected_grid_map[36][17]=32800
-    expected_grid_map[37][6]=72
-    expected_grid_map[37][7]=1097
-    expected_grid_map[37][8]=1025
-    expected_grid_map[37][9]=1025
-    expected_grid_map[37][10]=1025
-    expected_grid_map[37][11]=1025
-    expected_grid_map[37][12]=1025
-    expected_grid_map[37][13]=1025
-    expected_grid_map[37][14]=1025
-    expected_grid_map[37][15]=1025
-    expected_grid_map[37][16]=33897
-    expected_grid_map[37][17]=37408
-    expected_grid_map[38][16]=72
-    expected_grid_map[38][17]=52275
-    expected_grid_map[38][18]=5633
-    expected_grid_map[38][19]=17411
-    expected_grid_map[38][20]=1025
-    expected_grid_map[38][21]=1025
-    expected_grid_map[38][22]=256
-    expected_grid_map[39][16]=4
-    expected_grid_map[39][17]=52275
-    expected_grid_map[39][18]=3089
-    expected_grid_map[39][19]=1097
-    expected_grid_map[39][20]=5633
-    expected_grid_map[39][21]=1025
-    expected_grid_map[39][22]=256
-    expected_grid_map[40][16]=4
-    expected_grid_map[40][17]=1097
-    expected_grid_map[40][18]=1025
-    expected_grid_map[40][19]=1025
-    expected_grid_map[40][20]=3089
-    expected_grid_map[40][21]=1025
-    expected_grid_map[40][22]=256
+    expected_grid_map[8][16] = 4
+    expected_grid_map[8][17] = 5633
+    expected_grid_map[8][18] = 1025
+    expected_grid_map[8][19] = 1025
+    expected_grid_map[8][20] = 17411
+    expected_grid_map[8][21] = 1025
+    expected_grid_map[8][22] = 1025
+    expected_grid_map[8][23] = 1025
+    expected_grid_map[8][24] = 1025
+    expected_grid_map[8][25] = 1025
+    expected_grid_map[8][26] = 4608
+    expected_grid_map[9][16] = 16386
+    expected_grid_map[9][17] = 50211
+    expected_grid_map[9][18] = 1025
+    expected_grid_map[9][19] = 1025
+    expected_grid_map[9][20] = 3089
+    expected_grid_map[9][21] = 1025
+    expected_grid_map[9][22] = 256
+    expected_grid_map[9][26] = 32800
+    expected_grid_map[10][6] = 16386
+    expected_grid_map[10][7] = 1025
+    expected_grid_map[10][8] = 1025
+    expected_grid_map[10][9] = 1025
+    expected_grid_map[10][10] = 1025
+    expected_grid_map[10][11] = 1025
+    expected_grid_map[10][12] = 1025
+    expected_grid_map[10][13] = 1025
+    expected_grid_map[10][14] = 1025
+    expected_grid_map[10][15] = 1025
+    expected_grid_map[10][16] = 33825
+    expected_grid_map[10][17] = 34864
+    expected_grid_map[10][26] = 32800
+    expected_grid_map[11][6] = 32800
+    expected_grid_map[11][16] = 32800
+    expected_grid_map[11][17] = 32800
+    expected_grid_map[11][26] = 32800
+    expected_grid_map[12][6] = 32800
+    expected_grid_map[12][16] = 32800
+    expected_grid_map[12][17] = 32800
+    expected_grid_map[12][26] = 32800
+    expected_grid_map[13][6] = 32800
+    expected_grid_map[13][16] = 32800
+    expected_grid_map[13][17] = 32800
+    expected_grid_map[13][26] = 32800
+    expected_grid_map[14][6] = 32800
+    expected_grid_map[14][16] = 32800
+    expected_grid_map[14][17] = 32800
+    expected_grid_map[14][26] = 32800
+    expected_grid_map[15][6] = 32800
+    expected_grid_map[15][16] = 32800
+    expected_grid_map[15][17] = 32800
+    expected_grid_map[15][26] = 32800
+    expected_grid_map[16][6] = 32800
+    expected_grid_map[16][16] = 32800
+    expected_grid_map[16][17] = 32800
+    expected_grid_map[16][26] = 32800
+    expected_grid_map[17][6] = 32800
+    expected_grid_map[17][16] = 72
+    expected_grid_map[17][17] = 1097
+    expected_grid_map[17][18] = 1025
+    expected_grid_map[17][19] = 1025
+    expected_grid_map[17][20] = 1025
+    expected_grid_map[17][21] = 1025
+    expected_grid_map[17][22] = 1025
+    expected_grid_map[17][23] = 1025
+    expected_grid_map[17][24] = 1025
+    expected_grid_map[17][25] = 1025
+    expected_grid_map[17][26] = 33825
+    expected_grid_map[17][27] = 4608
+    expected_grid_map[18][6] = 32800
+    expected_grid_map[18][26] = 72
+    expected_grid_map[18][27] = 52275
+    expected_grid_map[18][28] = 5633
+    expected_grid_map[18][29] = 17411
+    expected_grid_map[18][30] = 1025
+    expected_grid_map[18][31] = 1025
+    expected_grid_map[18][32] = 256
+    expected_grid_map[19][6] = 32800
+    expected_grid_map[19][25] = 16386
+    expected_grid_map[19][26] = 1025
+    expected_grid_map[19][27] = 2136
+    expected_grid_map[19][28] = 1097
+    expected_grid_map[19][29] = 1097
+    expected_grid_map[19][30] = 5633
+    expected_grid_map[19][31] = 1025
+    expected_grid_map[19][32] = 256
+    expected_grid_map[20][6] = 32800
+    expected_grid_map[20][25] = 32800
+    expected_grid_map[20][26] = 16386
+    expected_grid_map[20][27] = 17411
+    expected_grid_map[20][28] = 1025
+    expected_grid_map[20][29] = 1025
+    expected_grid_map[20][30] = 3089
+    expected_grid_map[20][31] = 1025
+    expected_grid_map[20][32] = 256
+    expected_grid_map[21][6] = 32800
+    expected_grid_map[21][16] = 16386
+    expected_grid_map[21][17] = 1025
+    expected_grid_map[21][18] = 1025
+    expected_grid_map[21][19] = 1025
+    expected_grid_map[21][20] = 1025
+    expected_grid_map[21][21] = 1025
+    expected_grid_map[21][22] = 1025
+    expected_grid_map[21][23] = 1025
+    expected_grid_map[21][24] = 1025
+    expected_grid_map[21][25] = 33825
+    expected_grid_map[21][26] = 33825
+    expected_grid_map[21][27] = 2064
+    expected_grid_map[22][6] = 32800
+    expected_grid_map[22][16] = 32800
+    expected_grid_map[22][25] = 32800
+    expected_grid_map[22][26] = 32800
+    expected_grid_map[23][6] = 32800
+    expected_grid_map[23][16] = 32800
+    expected_grid_map[23][25] = 32800
+    expected_grid_map[23][26] = 32800
+    expected_grid_map[24][6] = 32800
+    expected_grid_map[24][16] = 32800
+    expected_grid_map[24][25] = 32800
+    expected_grid_map[24][26] = 32800
+    expected_grid_map[25][6] = 32800
+    expected_grid_map[25][16] = 32800
+    expected_grid_map[25][25] = 32800
+    expected_grid_map[25][26] = 32800
+    expected_grid_map[26][6] = 32800
+    expected_grid_map[26][16] = 32800
+    expected_grid_map[26][25] = 32800
+    expected_grid_map[26][26] = 32800
+    expected_grid_map[27][6] = 72
+    expected_grid_map[27][7] = 1025
+    expected_grid_map[27][8] = 1025
+    expected_grid_map[27][9] = 17411
+    expected_grid_map[27][10] = 1025
+    expected_grid_map[27][11] = 1025
+    expected_grid_map[27][12] = 1025
+    expected_grid_map[27][13] = 1025
+    expected_grid_map[27][14] = 1025
+    expected_grid_map[27][15] = 4608
+    expected_grid_map[27][16] = 72
+    expected_grid_map[27][17] = 17411
+    expected_grid_map[27][18] = 5633
+    expected_grid_map[27][19] = 1025
+    expected_grid_map[27][20] = 1025
+    expected_grid_map[27][21] = 1025
+    expected_grid_map[27][22] = 1025
+    expected_grid_map[27][23] = 1025
+    expected_grid_map[27][24] = 1025
+    expected_grid_map[27][25] = 33825
+    expected_grid_map[27][26] = 2064
+    expected_grid_map[28][6] = 4
+    expected_grid_map[28][7] = 1025
+    expected_grid_map[28][8] = 1025
+    expected_grid_map[28][9] = 3089
+    expected_grid_map[28][10] = 1025
+    expected_grid_map[28][11] = 1025
+    expected_grid_map[28][12] = 1025
+    expected_grid_map[28][13] = 1025
+    expected_grid_map[28][14] = 4608
+    expected_grid_map[28][15] = 72
+    expected_grid_map[28][16] = 1025
+    expected_grid_map[28][17] = 2136
+    expected_grid_map[28][18] = 1097
+    expected_grid_map[28][19] = 5633
+    expected_grid_map[28][20] = 5633
+    expected_grid_map[28][21] = 1025
+    expected_grid_map[28][22] = 256
+    expected_grid_map[28][25] = 32800
+    expected_grid_map[29][6] = 4
+    expected_grid_map[29][7] = 5633
+    expected_grid_map[29][8] = 20994
+    expected_grid_map[29][9] = 5633
+    expected_grid_map[29][10] = 1025
+    expected_grid_map[29][11] = 1025
+    expected_grid_map[29][12] = 1025
+    expected_grid_map[29][13] = 1025
+    expected_grid_map[29][14] = 1097
+    expected_grid_map[29][15] = 5633
+    expected_grid_map[29][16] = 1025
+    expected_grid_map[29][17] = 17411
+    expected_grid_map[29][18] = 5633
+    expected_grid_map[29][19] = 1097
+    expected_grid_map[29][20] = 3089
+    expected_grid_map[29][21] = 20994
+    expected_grid_map[29][22] = 1025
+    expected_grid_map[29][23] = 1025
+    expected_grid_map[29][24] = 1025
+    expected_grid_map[29][25] = 2064
+    expected_grid_map[30][6] = 16386
+    expected_grid_map[30][7] = 38505
+    expected_grid_map[30][8] = 3089
+    expected_grid_map[30][9] = 1097
+    expected_grid_map[30][10] = 1025
+    expected_grid_map[30][11] = 1025
+    expected_grid_map[30][12] = 256
+    expected_grid_map[30][15] = 32800
+    expected_grid_map[30][16] = 16386
+    expected_grid_map[30][17] = 52275
+    expected_grid_map[30][18] = 1097
+    expected_grid_map[30][19] = 1025
+    expected_grid_map[30][20] = 1025
+    expected_grid_map[30][21] = 3089
+    expected_grid_map[30][22] = 256
+    expected_grid_map[31][6] = 32800
+    expected_grid_map[31][7] = 32800
+    expected_grid_map[31][15] = 72
+    expected_grid_map[31][16] = 37408
+    expected_grid_map[31][17] = 32800
+    expected_grid_map[32][6] = 32800
+    expected_grid_map[32][7] = 32800
+    expected_grid_map[32][16] = 32800
+    expected_grid_map[32][17] = 32800
+    expected_grid_map[33][6] = 32800
+    expected_grid_map[33][7] = 32800
+    expected_grid_map[33][16] = 32800
+    expected_grid_map[33][17] = 32800
+    expected_grid_map[34][6] = 32800
+    expected_grid_map[34][7] = 32800
+    expected_grid_map[34][16] = 32800
+    expected_grid_map[34][17] = 32800
+    expected_grid_map[35][6] = 32800
+    expected_grid_map[35][7] = 32800
+    expected_grid_map[35][16] = 32800
+    expected_grid_map[35][17] = 32800
+    expected_grid_map[36][6] = 32800
+    expected_grid_map[36][7] = 32800
+    expected_grid_map[36][16] = 32800
+    expected_grid_map[36][17] = 32800
+    expected_grid_map[37][6] = 72
+    expected_grid_map[37][7] = 1097
+    expected_grid_map[37][8] = 1025
+    expected_grid_map[37][9] = 1025
+    expected_grid_map[37][10] = 1025
+    expected_grid_map[37][11] = 1025
+    expected_grid_map[37][12] = 1025
+    expected_grid_map[37][13] = 1025
+    expected_grid_map[37][14] = 1025
+    expected_grid_map[37][15] = 1025
+    expected_grid_map[37][16] = 33897
+    expected_grid_map[37][17] = 37408
+    expected_grid_map[38][16] = 72
+    expected_grid_map[38][17] = 52275
+    expected_grid_map[38][18] = 5633
+    expected_grid_map[38][19] = 17411
+    expected_grid_map[38][20] = 1025
+    expected_grid_map[38][21] = 1025
+    expected_grid_map[38][22] = 256
+    expected_grid_map[39][16] = 4
+    expected_grid_map[39][17] = 52275
+    expected_grid_map[39][18] = 3089
+    expected_grid_map[39][19] = 1097
+    expected_grid_map[39][20] = 5633
+    expected_grid_map[39][21] = 1025
+    expected_grid_map[39][22] = 256
+    expected_grid_map[40][16] = 4
+    expected_grid_map[40][17] = 1097
+    expected_grid_map[40][18] = 1025
+    expected_grid_map[40][19] = 1025
+    expected_grid_map[40][20] = 3089
+    expected_grid_map[40][21] = 1025
+    expected_grid_map[40][22] = 256
 
     assert np.array_equal(env.rail.grid, expected_grid_map), "actual={}, expected={}".format(env.rail.grid,
                                                                                              expected_grid_map)
-    
+
     s0 = 0
     s1 = 0
     for a in range(env.get_num_agents()):
-        s0 = Vec2d.get_manhattan_distance(env.agents[a].position, (0, 0))
-        s1 = Vec2d.get_chebyshev_distance(env.agents[a].position, (0, 0))
+        s0 = Vec2d.get_manhattan_distance(env.agents[a].initial_position, (0, 0))
+        s1 = Vec2d.get_chebyshev_distance(env.agents[a].initial_position, (0, 0))
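+        # note that s0 and s1 are overwritten on every iteration, so only the last agent's distances are asserted below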
     assert s0 == 58, "actual={}".format(s0)
     assert s1 == 38, "actual={}".format(s1)
diff --git a/tests/test_flatland_envs_observations.py b/tests/test_flatland_envs_observations.py
index 0d6d309765690b1f95c681d7d109a13071d7f86b..52b047244850a5b1c6b39dadd74568fc5f92deec 100644
--- a/tests/test_flatland_envs_observations.py
+++ b/tests/test_flatland_envs_observations.py
@@ -5,7 +5,7 @@ import numpy as np
 
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
 from flatland.core.grid.grid4_utils import get_new_position
-from flatland.envs.agent_utils import EnvAgent
+from flatland.envs.agent_utils import EnvAgent, RailAgentStatus
 from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv, RailEnvActions
@@ -29,6 +29,9 @@ def test_global_obs():
 
     global_obs = env.reset()
 
+    # we have to take a step for the agent to enter the grid.
+    global_obs, _, _, _ = env.step({0: RailEnvActions.MOVE_FORWARD})
+
     assert (global_obs[0][0].shape == rail_map.shape + (16,))
 
     rail_map_recons = np.zeros_like(rail_map)
@@ -109,12 +112,14 @@ def test_reward_function_conflict(rendering=False):
     agent.direction = 0  # north
     agent.target = (3, 9)  # east dead-end
     agent.moving = True
+    agent.status = RailAgentStatus.ACTIVE
 
     agent = env.agents_static[1]
     agent.position = (3, 8)  # east dead-end
     agent.direction = 3  # west
     agent.target = (6, 6)  # south dead-end
     agent.moving = True
+    agent.status = RailAgentStatus.ACTIVE
 
     # reset to set agents from agents_static
     env.reset(False, False)
@@ -184,16 +189,20 @@ def test_reward_function_waiting(rendering=False):
 
     # set the initial position
     agent = env.agents_static[0]
+    agent.initial_position = (3, 8)  # east dead-end
     agent.position = (3, 8)  # east dead-end
     agent.direction = 3  # west
     agent.target = (3, 1)  # west dead-end
     agent.moving = True
+    agent.status = RailAgentStatus.ACTIVE
 
     agent = env.agents_static[1]
+    agent.initial_position = (5, 6)  # south dead-end
     agent.position = (5, 6)  # south dead-end
     agent.direction = 0  # north
     agent.target = (3, 8)  # east dead-end
     agent.moving = True
+    agent.status = RailAgentStatus.ACTIVE
 
     # reset to set agents from agents_static
     env.reset(False, False)
diff --git a/tests/test_flatland_envs_predictions.py b/tests/test_flatland_envs_predictions.py
index f4ab68bc45a82b8f196fcfbebb26fd68a36c37a4..2517fb84de00c1346e5c601f903a9a1a677cfae4 100644
--- a/tests/test_flatland_envs_predictions.py
+++ b/tests/test_flatland_envs_predictions.py
@@ -5,6 +5,7 @@ import pprint
 import numpy as np
 
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
+from flatland.envs.agent_utils import RailAgentStatus
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv, RailEnvActions, RailEnvNextAction
@@ -31,12 +32,13 @@ def test_dummy_predictor(rendering=False):
     env.reset()
 
     # set initial position and direction for testing...
-    env.agents_static[0].position = (5, 6)
+    env.agents_static[0].initial_position = (5, 6)
     env.agents_static[0].direction = 0
     env.agents_static[0].target = (3, 0)
 
     # reset to set agents from agents_static
     env.reset(False, False)
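+    # agents now start as READY_TO_DEPART without a grid position, so activate agent 0 explicitly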
+    env.set_agent_active(0)
 
     if rendering:
         renderer = RenderTool(env, gl="PILSVG")
@@ -124,10 +126,12 @@ def test_shortest_path_predictor(rendering=False):
 
     # set the initial position
     agent = env.agents_static[0]
+    agent.initial_position = (5, 6)  # south dead-end
     agent.position = (5, 6)  # south dead-end
     agent.direction = 0  # north
     agent.target = (3, 9)  # east dead-end
     agent.moving = True
+    agent.status = RailAgentStatus.ACTIVE
 
     # reset to set agents from agents_static
     env.reset(False, False)
@@ -139,9 +143,9 @@ def test_shortest_path_predictor(rendering=False):
 
     # compute the observations and predictions
     distance_map = env.distance_map.get()
-    assert distance_map[0, agent.position[0], agent.position[
-        1], agent.direction] == 5.0, "found {} instead of {}".format(
-        distance_map[agent.handle, agent.position[0], agent.position[1], agent.direction], 5.0)
+    assert distance_map[0, agent.initial_position[0], agent.initial_position[1], agent.direction] == 5.0, \
+        "found {} instead of {}".format(
+            distance_map[agent.handle, agent.initial_position[0], agent.initial_position[1], agent.direction], 5.0)
 
     paths = get_shortest_paths(env.distance_map)[0]
     assert paths == [
@@ -259,19 +263,23 @@ def test_shortest_path_predictor_conflicts(rendering=False):
 
     # set the initial position
     agent = env.agents_static[0]
+    agent.initial_position = (5, 6)  # south dead-end
     agent.position = (5, 6)  # south dead-end
     agent.direction = 0  # north
     agent.target = (3, 9)  # east dead-end
     agent.moving = True
+    agent.status = RailAgentStatus.ACTIVE
 
     agent = env.agents_static[1]
+    agent.initial_position = (3, 8)  # east dead-end
     agent.position = (3, 8)  # east dead-end
     agent.direction = 3  # west
     agent.target = (6, 6)  # south dead-end
     agent.moving = True
+    agent.status = RailAgentStatus.ACTIVE
 
     # reset to set agents from agents_static
-    observations = env.reset(False, False)
+    observations = env.reset(False, False, True)
 
     if rendering:
         renderer = RenderTool(env, gl="PILSVG")
@@ -295,14 +303,14 @@ def test_shortest_path_predictor_conflicts(rendering=False):
 
 def _check_expected_conflicts(expected_conflicts, obs_builder, tree: TreeObsForRailEnv.Node, prompt=''):
     assert (tree.num_agents_opposite_direction > 0) == (() in expected_conflicts), "{}[]".format(prompt)
-    for a_1 in obs_builder.tree_explorted_actions_char:
+    for a_1 in obs_builder.tree_explored_actions_char:
         if tree.childs[a_1] == -np.inf:
             assert False == ((a_1) in expected_conflicts), "{}[{}]".format(prompt, a_1)
             continue
         else:
             conflict = tree.childs[a_1].num_agents_opposite_direction
             assert (conflict > 0) == ((a_1) in expected_conflicts), "{}[{}]".format(prompt, a_1)
-        for a_2 in obs_builder.tree_explorted_actions_char:
+        for a_2 in obs_builder.tree_explored_actions_char:
             if tree.childs[a_1].childs[a_2] == -np.inf:
                 assert False == ((a_1, a_2) in expected_conflicts), "{}[{}][{}]".format(prompt, a_1, a_2)
             else:
diff --git a/tests/test_flatland_envs_rail_env.py b/tests/test_flatland_envs_rail_env.py
index 0fefd3e212ddb5f084c1e219f4063079e03dabdf..e0281bb0d3f21a8e25fdff86ade5bed05f5dea13 100644
--- a/tests/test_flatland_envs_rail_env.py
+++ b/tests/test_flatland_envs_rail_env.py
@@ -85,7 +85,7 @@ def test_rail_environment_single_agent():
                        obs_builder_object=GlobalObsForRailEnv())
 
     for _ in range(200):
-        _ = rail_env.reset()
+        _ = rail_env.reset(False, False, True)
 
         # We do not care about target for the moment
         agent = rail_env.agents[0]
@@ -130,9 +130,6 @@ def test_rail_environment_single_agent():
                 done = dones['__all__']
 
 
-test_rail_environment_single_agent()
-
-
 def test_dead_end():
     transitions = RailEnvTransitions()
 
@@ -164,32 +161,12 @@ def test_dead_end():
                        number_of_agents=1,
                        obs_builder_object=GlobalObsForRailEnv())
 
-    def check_consistency(rail_env):
-        # We run step to check that trains do not move anymore
-        # after being done.
-        # TODO: GIACOMO: this is deprecated and should be updated; thenew behavior is that agents keep moving
-        # until they are manually stopped.
-        for i in range(7):
-            prev_pos = rail_env.agents[0].position
-
-            # The train cannot turn, so we check that when it tries,
-            # it stays where it is.
-            _ = rail_env.step({0: 1})
-            _ = rail_env.step({0: 3})
-            assert (rail_env.agents[0].position == prev_pos)
-            _, _, dones, _ = rail_env.step({0: 2})
-
-            if i < 5:
-                assert (not dones[0] and not dones['__all__'])
-            else:
-                assert (dones[0] and dones['__all__'])
-
     # We try the configuration in the 4 directions:
     rail_env.reset()
-    rail_env.agents = [EnvAgent(position=(0, 2), direction=1, target=(0, 0), moving=False)]
+    rail_env.agents = [EnvAgent(initial_position=(0, 2), direction=1, target=(0, 0), moving=False)]
 
     rail_env.reset()
-    rail_env.agents = [EnvAgent(position=(0, 2), direction=3, target=(0, 4), moving=False)]
+    rail_env.agents = [EnvAgent(initial_position=(0, 2), direction=3, target=(0, 4), moving=False)]
 
     # In the vertical configuration:
     rail_map = np.array(
@@ -210,10 +187,12 @@ def test_dead_end():
                        obs_builder_object=GlobalObsForRailEnv())
 
     rail_env.reset()
-    rail_env.agents = [EnvAgent(position=(2, 0), direction=2, target=(0, 0), moving=False)]
+    rail_env.agents = [EnvAgent(initial_position=(2, 0), direction=2, target=(0, 0), moving=False)]
 
     rail_env.reset()
-    rail_env.agents = [EnvAgent(position=(2, 0), direction=0, target=(4, 0), moving=False)]
+    rail_env.agents = [EnvAgent(initial_position=(2, 0), direction=0, target=(4, 0), moving=False)]
+
+    # TODO make assertions
 
 
 def test_get_entry_directions():
diff --git a/tests/test_flatland_envs_sparse_rail_generator.py b/tests/test_flatland_envs_sparse_rail_generator.py
index e164752483e2b4ad5896d754d378a5519c960237..8b2cdbea9431d970778acb8d973cc0002d5a90f5 100644
--- a/tests/test_flatland_envs_sparse_rail_generator.py
+++ b/tests/test_flatland_envs_sparse_rail_generator.py
@@ -25,6 +25,7 @@ def test_sparse_rail_generator():
                   schedule_generator=sparse_schedule_generator(),
                   number_of_agents=10,
                   obs_builder_object=GlobalObsForRailEnv())
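+    # the third reset flag activates all agents immediately (the same flag run_replay_config passes as activate_agents)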
+    env.reset(False, False, True)
     expected_grid_map = np.zeros((50, 50), dtype=env.rail.transitions.get_type())
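+    # each expected cell value is a 16-bit transition bitmap (4 entry directions x 4 exit directions)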
     expected_grid_map[1][33] = 8192
     expected_grid_map[2][33] = 32800
@@ -1549,6 +1550,9 @@ def test_rail_env_action_required_info():
                                           obs_builder_object=GlobalObsForRailEnv())
     env_renderer = RenderTool(env_always_action, gl="PILSVG", )
 
+    env_always_action.reset(False, False, True)
+    env_only_if_action_required.reset(False, False, True)
+
     for step in range(100):
         print("step {}".format(step))
 
@@ -1610,6 +1614,7 @@ def test_rail_env_malfunction_speed_info():
                   number_of_agents=10,
                   obs_builder_object=GlobalObsForRailEnv(),
                   stochastic_data=stochastic_data)
+    env.reset(False, False, True)
 
     env_renderer = RenderTool(env, gl="PILSVG", )
     for step in range(100):
diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py
index fa2920a9fa7c78331cc7c32ae5308633b4d3f8da..8b4fc8bb64b753532b96f0158d3fa754bf855e2b 100644
--- a/tests/test_flatland_malfunction.py
+++ b/tests/test_flatland_malfunction.py
@@ -2,14 +2,15 @@ import random
 from typing import Dict, List
 
 import numpy as np
-from test_utils import Replay, ReplayConfig, run_replay_config, set_penalties_for_replay
 
 from flatland.core.env_observation_builder import ObservationBuilder
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
 from flatland.core.grid.grid4_utils import get_new_position
+from flatland.envs.agent_utils import RailAgentStatus
 from flatland.envs.rail_env import RailEnv, RailEnvActions
 from flatland.envs.rail_generators import complex_rail_generator, sparse_rail_generator
 from flatland.envs.schedule_generators import complex_schedule_generator, sparse_schedule_generator
+from test_utils import Replay, ReplayConfig, run_replay_config, set_penalties_for_replay
 
 
 class SingleAgentNavigationObs(ObservationBuilder):
@@ -29,7 +30,16 @@ class SingleAgentNavigationObs(ObservationBuilder):
     def get(self, handle: int = 0) -> List[int]:
         agent = self.env.agents[handle]
 
-        possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
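+        # choose a stand-in position depending on the agent's status: before departure
+        # the agent has no grid position yet, a DONE agent rests on its target, and a
+        # removed agent yields no observation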
+        if agent.status == RailAgentStatus.READY_TO_DEPART:
+            agent_virtual_position = agent.initial_position
+        elif agent.status == RailAgentStatus.ACTIVE:
+            agent_virtual_position = agent.position
+        elif agent.status == RailAgentStatus.DONE:
+            agent_virtual_position = agent.target
+        else:
+            return None
+
+        possible_transitions = self.env.rail.get_transitions(*agent_virtual_position, agent.direction)
         num_transitions = np.count_nonzero(possible_transitions)
 
         # Start from the current orientation, and see which transitions are available;
@@ -41,7 +51,7 @@ class SingleAgentNavigationObs(ObservationBuilder):
             min_distances = []
             for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
                 if possible_transitions[direction]:
-                    new_position = get_new_position(agent.position, direction)
+                    new_position = get_new_position(agent_virtual_position, direction)
                     min_distances.append(
                         self.env.distance_map.get()[handle, new_position[0], new_position[1], direction])
                 else:
@@ -70,16 +80,19 @@ def test_malfunction_process():
                   obs_builder_object=SingleAgentNavigationObs(),
                   stochastic_data=stochastic_data)
 
-    obs = env.reset()
+    obs = env.reset(False, False, True)
 
     # Check that a initial duration for malfunction was assigned
     assert env.agents[0].malfunction_data['next_malfunction'] > 0
+    for agent in env.agents:
+        agent.status = RailAgentStatus.ACTIVE
 
     agent_halts = 0
     total_down_time = 0
     agent_old_position = env.agents[0].position
     for step in range(100):
         actions = {}
         for i in range(len(obs)):
             actions[i] = np.argmax(obs[i]) + 1
 
@@ -104,7 +117,8 @@ def test_malfunction_process():
         total_down_time += env.agents[0].malfunction_data['malfunction']
 
     # Check that the appropriate number of malfunctions is achieved
-    assert env.agents[0].malfunction_data['nr_malfunctions'] == 21
+    assert env.agents[0].malfunction_data['nr_malfunctions'] == 21, "Actual {}".format(
+        env.agents[0].malfunction_data['nr_malfunctions'])
 
     # Check that 20 stops where performed
     assert agent_halts == 20
@@ -120,8 +134,6 @@ def test_malfunction_process_statistically():
                        'malfunction_rate': 2,
                        'min_duration': 3,
                        'max_duration': 3}
-    np.random.seed(5)
-    random.seed(0)
 
     env = RailEnv(width=20,
                   height=20,
@@ -131,8 +143,9 @@ def test_malfunction_process_statistically():
                   number_of_agents=2,
                   obs_builder_object=SingleAgentNavigationObs(),
                   stochastic_data=stochastic_data)
-
-    env.reset()
+    np.random.seed(5)
+    random.seed(0)
+    env.reset(False, False, True)
     nb_malfunction = 0
     for step in range(100):
         action_dict: Dict[int, RailEnvActions] = {}
@@ -149,9 +162,6 @@ def test_malfunction_process_statistically():
 
 
 def test_initial_malfunction():
-    random.seed(0)
-    np.random.seed(0)
-
     stochastic_data = {'prop_malfunction': 1.,  # Percentage of defective agents
                        'malfunction_rate': 70,  # Rate of malfunction occurence
                        'min_duration': 2,  # Minimal duration of malfunction
@@ -162,7 +172,8 @@ def test_initial_malfunction():
                         1. / 2.: 0.,  # Fast freight train
                         1. / 3.: 0.,  # Slow commuter train
                         1. / 4.: 0.}  # Slow freight train
-
+    np.random.seed(5)
+    random.seed(0)
     env = RailEnv(width=25,
                   height=30,
                   rail_generator=sparse_rail_generator(num_cities=5,
@@ -226,15 +237,15 @@ def test_initial_malfunction():
             )
         ],
         speed=env.agents[0].speed_data['speed'],
-        target=env.agents[0].target
+        target=env.agents[0].target,
+        initial_position=(28, 5),
+        initial_direction=Grid4TransitionsEnum.EAST,
     )
+
     run_replay_config(env, [replay_config])
 
 
 def test_initial_malfunction_stop_moving():
-    random.seed(0)
-    np.random.seed(0)
-
     stochastic_data = {'prop_malfunction': 1.,  # Percentage of defective agents
                        'malfunction_rate': 70,  # Rate of malfunction occurence
                        'min_duration': 2,  # Minimal duration of malfunction
@@ -269,19 +280,21 @@ def test_initial_malfunction_stop_moving():
     replay_config = ReplayConfig(
         replay=[
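+            # the agent is still READY_TO_DEPART (position None) in the first step;
+            # MOVE_FORWARD lets it enter the grid while the forced malfunction starts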
             Replay(
-                position=(28, 5),
+                position=None,
                 direction=Grid4TransitionsEnum.EAST,
-                action=RailEnvActions.DO_NOTHING,
+                action=RailEnvActions.MOVE_FORWARD,
                 set_malfunction=3,
                 malfunction=3,
-                reward=env.step_penalty  # full step penalty when stopped
+                reward=env.step_penalty,  # full step penalty when stopped
+                status=RailAgentStatus.READY_TO_DEPART
             ),
             Replay(
                 position=(28, 5),
                 direction=Grid4TransitionsEnum.EAST,
                 action=RailEnvActions.DO_NOTHING,
                 malfunction=2,
-                reward=env.step_penalty  # full step penalty when stopped
+                reward=env.step_penalty,  # full step penalty when stopped
+                status=RailAgentStatus.ACTIVE
             ),
             # malfunction stops in the next step and we're still at the beginning of the cell
             # --> if we take action STOP_MOVING, agent should restart without moving
@@ -291,7 +304,8 @@ def test_initial_malfunction_stop_moving():
                 direction=Grid4TransitionsEnum.EAST,
                 action=RailEnvActions.STOP_MOVING,
                 malfunction=1,
-                reward=env.step_penalty  # full step penalty while stopped
+                reward=env.step_penalty,  # full step penalty while stopped
+                status=RailAgentStatus.ACTIVE
             ),
             # we have stopped and do nothing --> should stand still
             Replay(
@@ -299,7 +313,8 @@ def test_initial_malfunction_stop_moving():
                 direction=Grid4TransitionsEnum.EAST,
                 action=RailEnvActions.DO_NOTHING,
                 malfunction=0,
-                reward=env.step_penalty  # full step penalty while stopped
+                reward=env.step_penalty,  # full step penalty while stopped
+                status=RailAgentStatus.ACTIVE
             ),
             # we start to move forward --> should go to next cell now
             Replay(
@@ -307,21 +322,24 @@ def test_initial_malfunction_stop_moving():
                 direction=Grid4TransitionsEnum.EAST,
                 action=RailEnvActions.MOVE_FORWARD,
                 malfunction=0,
-                reward=env.start_penalty + env.step_penalty * 1.0  # full step penalty while stopped
+                reward=env.start_penalty + env.step_penalty * 1.0,  # full step penalty while stopped
+                status=RailAgentStatus.ACTIVE
             ),
             Replay(
                 position=(28, 4),
                 direction=Grid4TransitionsEnum.WEST,
                 action=RailEnvActions.MOVE_FORWARD,
                 malfunction=0,
-                reward=env.step_penalty * 1.0  # full step penalty while stopped
+                reward=env.step_penalty * 1.0,  # full step penalty while stopped
+                status=RailAgentStatus.ACTIVE
             )
         ],
         speed=env.agents[0].speed_data['speed'],
-        target=env.agents[0].target
+        target=env.agents[0].target,
+        initial_position=(28, 5),
+        initial_direction=Grid4TransitionsEnum.EAST,
     )
-
-    run_replay_config(env, [replay_config])
+    run_replay_config(env, [replay_config], activate_agents=False)
 
 
 def test_initial_malfunction_do_nothing():
@@ -360,20 +378,23 @@ def test_initial_malfunction_do_nothing():
                   )
     set_penalties_for_replay(env)
     replay_config = ReplayConfig(
-        replay=[Replay(
-            position=(28, 5),
-            direction=Grid4TransitionsEnum.EAST,
-            action=RailEnvActions.DO_NOTHING,
-            set_malfunction=3,
-            malfunction=3,
-            reward=env.step_penalty  # full step penalty while malfunctioning
-        ),
+        replay=[
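+            # the agent starts off the grid (READY_TO_DEPART) and enters it with its first MOVE_FORWARD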
+            Replay(
+                position=None,
+                direction=Grid4TransitionsEnum.EAST,
+                action=RailEnvActions.MOVE_FORWARD,
+                set_malfunction=3,
+                malfunction=3,
+                reward=env.step_penalty,  # full step penalty while malfunctioning
+                status=RailAgentStatus.READY_TO_DEPART
+            ),
             Replay(
                 position=(28, 5),
                 direction=Grid4TransitionsEnum.EAST,
                 action=RailEnvActions.DO_NOTHING,
                 malfunction=2,
-                reward=env.step_penalty  # full step penalty while malfunctioning
+                reward=env.step_penalty,  # full step penalty while malfunctioning
+                status=RailAgentStatus.ACTIVE
             ),
             # malfunction stops in the next step and we're still at the beginning of the cell
             # --> if we take action DO_NOTHING, agent should restart without moving
@@ -383,7 +404,8 @@ def test_initial_malfunction_do_nothing():
                 direction=Grid4TransitionsEnum.EAST,
                 action=RailEnvActions.DO_NOTHING,
                 malfunction=1,
-                reward=env.step_penalty  # full step penalty while stopped
+                reward=env.step_penalty,  # full step penalty while stopped
+                status=RailAgentStatus.ACTIVE
             ),
             # we haven't started moving yet --> stay here
             Replay(
@@ -391,7 +413,8 @@ def test_initial_malfunction_do_nothing():
                 direction=Grid4TransitionsEnum.EAST,
                 action=RailEnvActions.DO_NOTHING,
                 malfunction=0,
-                reward=env.step_penalty  # full step penalty while stopped
+                reward=env.step_penalty,  # full step penalty while stopped
+                status=RailAgentStatus.ACTIVE
             ),
             # we start to move forward --> should go to next cell now
             Replay(
@@ -399,21 +422,25 @@ def test_initial_malfunction_do_nothing():
                 direction=Grid4TransitionsEnum.EAST,
                 action=RailEnvActions.MOVE_FORWARD,
                 malfunction=0,
-                reward=env.start_penalty + env.step_penalty * 1.0  # start penalty + step penalty for speed 1.0
+                reward=env.start_penalty + env.step_penalty * 1.0,  # start penalty + step penalty for speed 1.0
+                status=RailAgentStatus.ACTIVE
             ),
             Replay(
                 position=(28, 4),
                 direction=Grid4TransitionsEnum.WEST,
                 action=RailEnvActions.MOVE_FORWARD,
                 malfunction=0,
-                reward=env.step_penalty * 1.0  # step penalty for speed 1.0
+                reward=env.step_penalty * 1.0,  # step penalty for speed 1.0
+                status=RailAgentStatus.ACTIVE
             )
         ],
         speed=env.agents[0].speed_data['speed'],
-        target=env.agents[0].target
+        target=env.agents[0].target,
+        initial_position=(28, 5),
+        initial_direction=Grid4TransitionsEnum.EAST,
     )
 
-    run_replay_config(env, [replay_config])
+    run_replay_config(env, [replay_config], activate_agents=False)
 
 
 def test_initial_nextmalfunction_not_below_zero():
diff --git a/tests/test_global_observation.py b/tests/test_global_observation.py
index 7213560f9e9873ea4488b96d30223bab8128b37b..f29629ab7c7aabb9ca2989b02a3604239d9e6143 100644
--- a/tests/test_global_observation.py
+++ b/tests/test_global_observation.py
@@ -1,7 +1,7 @@
 import numpy as np
 
 from flatland.envs.observations import GlobalObsForRailEnv
-from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_env import RailEnv, RailEnvActions
 from flatland.envs.rail_generators import sparse_rail_generator
 from flatland.envs.schedule_generators import sparse_schedule_generator
 
@@ -40,7 +40,7 @@ def test_get_global_observation():
                   number_of_agents=number_of_agents, stochastic_data=stochastic_data,  # Malfunction data generator
                   obs_builder_object=GlobalObsForRailEnv())
 
-    obs, all_rewards, done, _ = env.step({0: 0})
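+# all agents act once so they leave READY_TO_DEPART and appear on the grid with a position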
+    obs, all_rewards, done, _ = env.step({i: RailEnvActions.MOVE_FORWARD for i in range(number_of_agents)})
 
     for i in range(len(env.agents)):
         obs_agents_state = obs[i][1]
diff --git a/tests/test_multi_speed.py b/tests/test_multi_speed.py
index b0f274ba4c4b5453140fcc50bc6137e39e8e4f04..3cd0a4c1d9812c728089255eb1150111b783a463 100644
--- a/tests/test_multi_speed.py
+++ b/tests/test_multi_speed.py
@@ -63,7 +63,8 @@ def test_multi_speed_init():
 
     # Set all the different speeds
     # Reset environment and get initial observations for all agents
-    env.reset()
+    env.reset(False, False, True)
+
     # Here you can also further enhance the provided observation by means of normalization
     # See training navigation example in the baseline repository
     old_pos = []
@@ -188,7 +189,9 @@ def test_multispeed_actions_no_malfunction_no_blocking():
             ),
         ],
         target=(3, 0),  # west dead-end
-        speed=0.5
+        speed=0.5,
+        initial_position=(3, 9),  # east dead-end
+        initial_direction=Grid4TransitionsEnum.EAST,
     )
 
     run_replay_config(env, [test_config])
@@ -285,7 +288,10 @@ def test_multispeed_actions_no_malfunction_blocking():
                 )
             ],
             target=(3, 0),  # west dead-end
-            speed=1 / 3),
+            speed=1 / 3,
+            initial_position=(3, 8),
+            initial_direction=Grid4TransitionsEnum.WEST,
+        ),
         ReplayConfig(
             replay=[
                 Replay(
@@ -369,7 +375,9 @@ def test_multispeed_actions_no_malfunction_blocking():
                 ),
             ],
             target=(3, 0),  # west dead-end
-            speed=0.5
+            speed=0.5,
+            initial_position=(3, 9),  # east dead-end
+            initial_direction=Grid4TransitionsEnum.EAST,
         )
 
     ]
@@ -505,7 +513,9 @@ def test_multispeed_actions_malfunction_no_blocking():
 
         ],
         target=(3, 0),  # west dead-end
-        speed=0.5
+        speed=0.5,
+        initial_position=(3, 9),  # east dead-end
+        initial_direction=Grid4TransitionsEnum.EAST,
     )
     run_replay_config(env, [test_config])
 
@@ -587,7 +597,9 @@ def test_multispeed_actions_no_malfunction_invalid_actions():
 
         ],
         target=(3, 0),  # west dead-end
-        speed=0.5
+        speed=0.5,
+        initial_position=(3, 9),  # east dead-end
+        initial_direction=Grid4TransitionsEnum.EAST,
     )
 
     run_replay_config(env, [test_config])
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 903120d868aa65833e7c2393ddfcc821c26da4f6..f5d1cd5c957d1e81ca2f2465fb8265c6b2795342 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -5,7 +5,7 @@ import numpy as np
 from attr import attrs, attrib
 
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
-from flatland.envs.agent_utils import EnvAgent
+from flatland.envs.agent_utils import EnvAgent, RailAgentStatus
 from flatland.envs.rail_env import RailEnvActions, RailEnv
 from flatland.utils.rendertools import RenderTool
 
@@ -18,6 +18,7 @@ class Replay(object):
     malfunction = attrib(default=0, type=int)
     set_malfunction = attrib(default=None, type=Optional[int])
     reward = attrib(default=None, type=Optional[float])
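+    # expected agent status before the step; when None, the status check is skipped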
+    status = attrib(default=None, type=Optional[RailAgentStatus])
 
 
 @attrs
@@ -25,6 +26,8 @@ class ReplayConfig(object):
     replay = attrib(type=List[Replay])
     target = attrib(type=Tuple[int, int])
     speed = attrib(type=float)
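+    # agents start off the grid now, so a replay has to spell out the starting cell and heading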
+    initial_position = attrib(type=Tuple[int, int])
+    initial_direction = attrib(type=Grid4TransitionsEnum)
 
 
 # ensure that env is working correctly with start/stop/invalidaction penalty different from 0
@@ -35,7 +38,7 @@ def set_penalties_for_replay(env: RailEnv):
     env.invalid_action_penalty = -29
 
 
-def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering: bool = False):
+def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering: bool = False, activate_agents=True):
     """
     Runs the replay configs and checks assertions.
 
@@ -47,10 +50,12 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering:
     - position, direction before step are verified
     - optionally, set_malfunction is applied
     - malfunction is verified
+    - status is verified (only if set in the replay)
 
     *After each step*
     - reward is verified after step
 
     Parameters
     ----------
     env
@@ -67,18 +72,20 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering:
     for step in range(len(test_configs[0].replay)):
         if step == 0:
             for a, test_config in enumerate(test_configs):
-                agent: EnvAgent = env.agents[a]
-                replay = test_config.replay[0]
+                agent: EnvAgent = env.agents_static[a]
                 # set the initial position
-                agent.position = replay.position
-                agent.direction = replay.direction
+                agent.initial_position = test_config.initial_position
+                agent.direction = test_config.initial_direction
                 agent.target = test_config.target
                 agent.speed_data['speed'] = test_config.speed
+            env.reset(False, False, activate_agents)
 
         def _assert(a, actual, expected, msg):
-            assert np.allclose(actual, expected), "[{}] agent {} {}:  actual={}, expected={}".format(step, a, msg,
-                                                                                                    actual,
-                                                                                                    expected)
+            print("[{}] verifying {} on agent {}: actual={}, expected={}".format(step, msg, a, actual, expected))
+            assert (actual == expected) or np.allclose(actual, expected), \
+                "[{}] agent {} {}: actual={}, expected={}".format(step, a, msg, actual, expected)
 
         action_dict = {}
 
@@ -88,26 +95,29 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering:
 
             _assert(a, agent.position, replay.position, 'position')
             _assert(a, agent.direction, replay.direction, 'direction')
+            if replay.status is not None:
+                _assert(a, agent.status, replay.status, 'status')
 
             if replay.action is not None:
-                assert info_dict['action_required'][a] == True, "[{}] agent {} expecting action_required={}".format(
+                assert info_dict['action_required'][a] or agent.status == RailAgentStatus.READY_TO_DEPART, \
+                    "[{}] agent {} expecting action_required={} or agent status READY_TO_DEPART".format(
                     step, a, True)
                 action_dict[a] = replay.action
             else:
-                assert info_dict['action_required'][a] == False, "[{}] agent {} expecting action_required={}".format(
-                    step, a, False)
+                assert not info_dict['action_required'][a], \
+                    "[{}] agent {} expecting action_required={}, but found {}".format(
+                        step, a, False, info_dict['action_required'][a])
 
             if replay.set_malfunction is not None:
                 agent.malfunction_data['malfunction'] = replay.set_malfunction
                 agent.malfunction_data['moving_before_malfunction'] = agent.moving
             _assert(a, agent.malfunction_data['malfunction'], replay.malfunction, 'malfunction')
-
         _, rewards_dict, _, info_dict = env.step(action_dict)
         if rendering:
             renderer.render_env(show=True, show_observations=True)
 
         for a, test_config in enumerate(test_configs):
             replay = test_config.replay[step]
-            _assert(a, rewards_dict[a], replay.reward, 'reward')
-
 
+            _assert(a, rewards_dict[a], replay.reward, 'reward')
diff --git a/tests/tests_generators.py b/tests/tests_generators.py
index 4c925789e6560077d637e2a594c736df8850d00a..9bfe3a47d0871cbfbbabb4e35eac0bd6ecaaedf0 100644
--- a/tests/tests_generators.py
+++ b/tests/tests_generators.py
@@ -102,6 +102,7 @@ def test_rail_from_grid_transition_map():
                   schedule_generator=random_schedule_generator(),
                   number_of_agents=n_agents
                   )
+    env.reset(False, False, True)
     nr_rail_elements = np.count_nonzero(env.rail.grid)
 
     # Check if the number of non-empty rail cells is ok