diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index e6c418827d93ae4e647bb3f2cfbcec81d9a9f059..8b4f43fe3535c7e85f6b954b5c58d62acceecc23 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -89,6 +89,15 @@ class RailEnv(Environment):
     For Round 2, they will be passed to the constructor as arguments, to allow for more flexibility.
 
     """
+    alpha = 1.0
+    beta = 1.0
+    # Epsilon to avoid rounding errors
+    epsilon = 0.01
+    invalid_action_penalty = 0  # previously -2; GIACOMO: we decided that invalid actions will carry no penalty
+    step_penalty = -1 * alpha
+    global_reward = 1 * beta
+    stop_penalty = 0  # penalty for stopping a moving agent
+    start_penalty = 0  # penalty for starting a stopped agent
 
     def __init__(self,
                  width,
@@ -252,7 +261,7 @@ class RailEnv(Environment):
 
             agent.malfunction_data['malfunction'] = 0
 
-            self._agent_malfunction(agent)
+            self._agent_malfunction(i_agent, RailEnvActions.DO_NOTHING)
 
         self.num_resets += 1
         self._elapsed_steps = 0
@@ -267,7 +276,9 @@ class RailEnv(Environment):
         # Return the new observation vectors for each agent
         return self._get_observations()
 
-    def _agent_malfunction(self, agent):
+    def _agent_malfunction(self, i_agent, action) -> bool:
+        agent = self.agents[i_agent]
+
         # Decrease counter for next event
         if agent.malfunction_data['malfunction_rate'] > 0:
             agent.malfunction_data['next_malfunction'] -= 1
@@ -291,31 +302,19 @@ class RailEnv(Environment):
                                                      self.max_number_of_steps_broken + 1) + 1
                 agent.malfunction_data['malfunction'] = num_broken_steps
 
-
-
+                return True
+        return False
 
     def step(self, action_dict_):
         self._elapsed_steps += 1
 
-        action_dict = action_dict_.copy()
-
-        alpha = 1.0
-        beta = 1.0
-        # Epsilon to avoid rounding errors
-        epsilon = 0.01
-        invalid_action_penalty = 0  # previously -2; GIACOMO: we decided that invalid actions will carry no penalty
-        step_penalty = -1 * alpha
-        global_reward = 1 * beta
-        stop_penalty = 0  # penalty for stopping a moving agent
-        start_penalty = 0  # penalty for starting a stopped agent
-
         # Reset the step rewards
         self.rewards_dict = dict()
         for i_agent in range(self.get_num_agents()):
             self.rewards_dict[i_agent] = 0
 
         if self.dones["__all__"]:
-            self.rewards_dict = {i: r + global_reward for i, r in self.rewards_dict.items()}
+            self.rewards_dict = {i: r + self.global_reward for i, r in self.rewards_dict.items()}
             info_dict = {
                 'action_required': {i: False for i in range(self.get_num_agents())},
                 'malfunction': {i: 0 for i in range(self.get_num_agents())},
@@ -324,26 +323,71 @@ class RailEnv(Environment):
             return self._get_observations(), self.rewards_dict, self.dones, info_dict
 
         for i_agent in range(self.get_num_agents()):
-            agent = self.agents[i_agent]
-            agent.old_direction = agent.direction
-            agent.old_position = agent.position
-
 
             if self.dones[i_agent]:  # this agent has already completed...
                 continue
 
-            # No action has been supplied for this agent
-            if i_agent not in action_dict:
-                action_dict[i_agent] = RailEnvActions.DO_NOTHING
+            agent = self.agents[i_agent]
+            agent.old_direction = agent.direction
+            agent.old_position = agent.position
 
+            # No action has been supplied for this agent -> set DO_NOTHING as default
+            if i_agent not in action_dict_:
+                action = RailEnvActions.DO_NOTHING
+            else:
+                action = action_dict_[i_agent]
 
-            if action_dict[i_agent] < 0 or action_dict[i_agent] > len(RailEnvActions):
-                print('ERROR: illegal action=', action_dict[i_agent],
+            if action < 0 or action > len(RailEnvActions):
+                print('ERROR: illegal action=', action,
                       'for agent with index=', i_agent,
                       '"DO NOTHING" will be executed instead')
-                action_dict[i_agent] = RailEnvActions.DO_NOTHING
+                action = RailEnvActions.DO_NOTHING
+
+            # Check if agent breaks at this step
+            malfunction = self._agent_malfunction(i_agent, action)
+
+            # If we're at the beginning of the cell, store the action.
+            # As long as the agent is broken down at the beginning of the cell, it can still choose other actions!
+            # This is a design choice made by Erik and Christian.
+
+            # TODO refactor!!!
+            # If the agent can make an action
+            if agent.speed_data['position_fraction'] == 0.0:
+                if action != RailEnvActions.DO_NOTHING and action != RailEnvActions.STOP_MOVING:
+                    cell_free, new_cell_valid, new_direction, new_position, transition_valid = \
+                        self._check_action_on_agent(action, agent)
 
-            action = action_dict[i_agent]
+                    if all([new_cell_valid, transition_valid]):
+                        agent.speed_data['transition_action_on_cellexit'] = action
+                    else:
+                        # But, if the chosen invalid action was LEFT/RIGHT, and the agent is moving,
+                        # try to keep moving forward!
+                        if (action == RailEnvActions.MOVE_LEFT or action == RailEnvActions.MOVE_RIGHT):
+                            cell_free, new_cell_valid, new_direction, new_position, transition_valid = \
+                                self._check_action_on_agent(RailEnvActions.MOVE_FORWARD, agent)
+
+                            if all([new_cell_valid, transition_valid]):
+                                agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.MOVE_FORWARD
+                            else:
+                                # If the agent cannot move due to an invalid transition, we set its state to not moving
+                                self.rewards_dict[i_agent] += self.invalid_action_penalty
+                                self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
+                                self.rewards_dict[i_agent] += self.stop_penalty
+                                agent.moving = False
+                                action = RailEnvActions.DO_NOTHING
+
+                        else:
+                            # If the agent cannot move due to an invalid transition, we set its state to not moving
+                            self.rewards_dict[i_agent] += self.invalid_action_penalty
+                            self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
+                            self.rewards_dict[i_agent] += self.stop_penalty
+                            agent.moving = False
+                            action = RailEnvActions.DO_NOTHING
+                else:
+                    agent.speed_data['transition_action_on_cellexit'] = action
+
+            if malfunction:
+                continue
 
             # The train is broken
             if agent.malfunction_data['malfunction'] > 0:
@@ -352,37 +396,31 @@ class RailEnv(Environment):
                 if agent.malfunction_data['malfunction'] < 2:
                     agent.malfunction_data['malfunction'] -= 1
                     self.agents[i_agent].moving = True
-                    action_dict[i_agent] = RailEnvActions.DO_NOTHING
+                    action = RailEnvActions.DO_NOTHING
 
                 else:
                     agent.malfunction_data['malfunction'] -= 1
 
                     # Broken agents are stopped
-                    self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed']
+                    self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
                     self.agents[i_agent].moving = False
-                    action_dict[i_agent] = RailEnvActions.DO_NOTHING
 
                     # Nothing left to do with broken agent
                     continue
 
-            # Check if agent breaks at this step
-            self._agent_malfunction(agent)
-
-
             if action == RailEnvActions.DO_NOTHING and agent.moving:
                 # Keep moving
                 action = RailEnvActions.MOVE_FORWARD
 
-            if action == RailEnvActions.STOP_MOVING and agent.moving and agent.speed_data[
-                'position_fraction'] <= epsilon:
+            if action == RailEnvActions.STOP_MOVING and agent.moving and agent.speed_data['position_fraction'] == 0.0:
                 # Only allow halting an agent on entering new cells.
                 agent.moving = False
-                self.rewards_dict[i_agent] += stop_penalty
+                self.rewards_dict[i_agent] += self.stop_penalty
 
             if not agent.moving and not (action == RailEnvActions.DO_NOTHING or action == RailEnvActions.STOP_MOVING):
                 # Allow agent to start with any forward or direction action
                 agent.moving = True
-                self.rewards_dict[i_agent] += start_penalty
+                self.rewards_dict[i_agent] += self.start_penalty
 
             # Now perform a movement.
             # If the agent is in an initial position within a new cell (agent.speed_data['position_fraction']<eps)
@@ -394,70 +432,36 @@ class RailEnv(Environment):
             # If the new position fraction is >= 1, reset to 0, and perform the stored
             #   transition_action_on_cellexit
 
-            # If the agent can make an action
-            action_selected = False
-            if agent.speed_data['position_fraction'] <= epsilon:
-                if action != RailEnvActions.DO_NOTHING and action != RailEnvActions.STOP_MOVING:
-                    cell_free, new_cell_valid, new_direction, new_position, transition_valid = \
-                        self._check_action_on_agent(action, agent)
-
-                    if all([new_cell_valid, transition_valid]):
-                        agent.speed_data['transition_action_on_cellexit'] = action
-                        action_selected = True
+            if agent.moving:
 
-                    else:
-                        # But, if the chosen invalid action was LEFT/RIGHT, and the agent is moving,
-                        # try to keep moving forward!
-                        if (action == RailEnvActions.MOVE_LEFT or action == RailEnvActions.MOVE_RIGHT) and agent.moving:
-                            cell_free, new_cell_valid, new_direction, new_position, transition_valid = \
-                                self._check_action_on_agent(RailEnvActions.MOVE_FORWARD, agent)
-
-                            if all([new_cell_valid, transition_valid]):
-                                agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.MOVE_FORWARD
-                                action_selected = True
-
-                            else:
-                                # TODO: an invalid action was chosen after entering the cell. The agent cannot move.
-                                self.rewards_dict[i_agent] += invalid_action_penalty
-                                self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed']
-                                self.rewards_dict[i_agent] += stop_penalty
-                                agent.moving = False
-                                continue
-                        else:
-                            # TODO: an invalid action was chosen after entering the cell. The agent cannot move.
-                            self.rewards_dict[i_agent] += invalid_action_penalty
-                            self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed']
-                            self.rewards_dict[i_agent] += stop_penalty
-                            agent.moving = False
-                            continue
-
-            if agent.moving and (action_selected or agent.speed_data['position_fraction'] > 0.0):
                 agent.speed_data['position_fraction'] += agent.speed_data['speed']
 
-            if agent.speed_data['position_fraction'] >= 1.0:
-
-                # Perform stored action to transition to the next cell as soon as cell is free
-                cell_free, new_cell_valid, new_direction, new_position, transition_valid = \
-                    self._check_action_on_agent(agent.speed_data['transition_action_on_cellexit'], agent)
-
-                if all([new_cell_valid, transition_valid, cell_free]) and agent.malfunction_data['malfunction'] == 0:
-                    agent.position = new_position
-                    agent.direction = new_direction
-                    agent.speed_data['position_fraction'] = 0.0
-                elif not transition_valid or not new_cell_valid:
-                    # If the agent cannot move due to an invalid transition, we set its state to not moving
-                    agent.moving = False
+                if agent.speed_data['position_fraction'] >= 1.0:
+                    # Perform stored action to transition to the next cell as soon as cell is free
+                    # Notice that we've already checked new_cell_valid and transition_valid when we stored the action,
+                    # so we only have to check cell_free now!
+                    if agent.speed_data['transition_action_on_cellexit'] in [RailEnvActions.DO_NOTHING,
+                                                                             RailEnvActions.STOP_MOVING]:
+                        agent.speed_data['position_fraction'] = 0.0
+                    else:
+                        cell_free, new_cell_valid, new_direction, new_position, transition_valid = self._check_action_on_agent(
+                            agent.speed_data['transition_action_on_cellexit'], agent)
+                        assert cell_free == all([cell_free, new_cell_valid, transition_valid])
+                        if cell_free:
+                            agent.position = new_position
+                            agent.direction = new_direction
+                            agent.speed_data['position_fraction'] = 0.0
 
             if np.equal(agent.position, agent.target).all():
                 self.dones[i_agent] = True
                 agent.moving = False
             else:
-                self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed']
+                self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
 
         # Check for end of episode + add global reward to all rewards!
         if np.all([np.array_equal(agent2.position, agent2.target) for agent2 in self.agents]):
             self.dones["__all__"] = True
-            self.rewards_dict = {i: 0 * r + global_reward for i, r in self.rewards_dict.items()}
+            self.rewards_dict = {i: 0 * r + self.global_reward for i, r in self.rewards_dict.items()}
 
         if (self._max_episode_steps is not None) and (self._elapsed_steps >= self._max_episode_steps):
             self.dones["__all__"] = True
@@ -481,6 +485,7 @@ class RailEnv(Environment):
         return self._get_observations(), self.rewards_dict, self.dones, info_dict
 
     def _check_action_on_agent(self, action, agent):
+
         # compute number of possible transitions in the current
         # cell used to check for invalid actions
         new_direction, transition_valid = self.check_action(agent, action)