diff --git a/flatland/action_plan/action_plan.py b/flatland/action_plan/action_plan.py
index b1a56d81839bff62f13a27753a935a19a8d05fe9..96a441299fd68b9d8f0e51e6d3e2b543ec15ba57 100644
--- a/flatland/action_plan/action_plan.py
+++ b/flatland/action_plan/action_plan.py
@@ -150,7 +150,7 @@ class ControllerFromTrainruns():
     def _create_action_plan_for_agent(self, agent_id, trainrun) -> ActionPlan:
         action_plan = []
         agent = self.env.agents[agent_id]
-        minimum_cell_time = agent.speed_counter.max_count
+        minimum_cell_time = agent.speed_counter.max_count + 1
         for path_loop, trainrun_waypoint in enumerate(trainrun):
             trainrun_waypoint: TrainrunWaypoint = trainrun_waypoint
 
diff --git a/flatland/action_plan/action_plan_player.py b/flatland/action_plan/action_plan_player.py
index 074e5590185ff601f9c038e9df4c23fd2f84c455..f9b82ba967392816319a8203b136524a1abba0fa 100644
--- a/flatland/action_plan/action_plan_player.py
+++ b/flatland/action_plan/action_plan_player.py
@@ -30,10 +30,7 @@ class ControllerFromTrainrunsReplayer():
                 assert agent.position == waypoint.position, \
                     "before {}, agent {} at {}, expected {}".format(i, agent_id, agent.position,
                                                                     waypoint.position)
-                if agent_id == 1:
-                    print(env._elapsed_steps, agent.position, agent.state, agent.speed_counter)
             actions = ctl.act(i)
-            print("actions for {}: {}".format(i, actions))
 
             obs, all_rewards, done, _ = env.step(actions)
 
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 456d56a0c58fdbaefa5a2ff4c5e938b74618e1c1..0b5f2a845d525f36456ce3c770fe4453d2c8a0e5 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -221,7 +221,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                                                             agent.direction)],
                                                        num_agents_same_direction=0, num_agents_opposite_direction=0,
                                                        num_agents_malfunctioning=agent.malfunction_data['malfunction'],
-                                                       speed_min_fractional=agent.speed_counter.speed
+                                                       speed_min_fractional=agent.speed_counter.speed,
                                                        num_agents_ready_to_depart=0,
                                                        childs={})
         #print("root node type:", type(root_node_observation))
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 2915e9be2b1d9537631f0639a6f20a9f05955d17..6a766f35be9d26f3d40623a7ba9c314f410751b3 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -366,9 +366,10 @@ class RailEnv(Environment):
             new_position = get_new_position(position, new_direction)
         else:
             new_position, new_direction = position, direction
-        return new_position, direction
+        return new_position, new_direction
     
     def generate_state_transition_signals(self, agent, preprocessed_action, movement_allowed):
+        """ Generate State Transitions Signals used in the state machine """
         st_signals = StateTransitionSignals()
         
         # Malfunction onset - Malfunction starts
@@ -442,9 +443,8 @@ class RailEnv(Environment):
         return action
     
     def clear_rewards_dict(self):
-        """ Reset the step rewards """
-
-        self.rewards_dict = dict()
+        """ Reset the rewards dictionary """
+        self.rewards_dict = {i_agent: 0 for i_agent in range(len(self.agents))}
 
     def get_info_dict(self): # TODO Important : Update this
         info_dict = {
@@ -456,6 +456,22 @@ class RailEnv(Environment):
             'state': {i: agent.state for i, agent in enumerate(self.agents)}
         }
         return info_dict
+    
+    def update_step_rewards(self, i_agent):
+        pass
+
+    def end_of_episode_update(self, have_all_agents_ended):
+        if have_all_agents_ended or \
+           ( (self._max_episode_steps is not None) and (self._elapsed_steps >= self._max_episode_steps)):
+
+            for i_agent, agent in enumerate(self.agents):
+                
+                reward = self._handle_end_reward(agent)
+                self.rewards_dict[i_agent] += reward
+                
+                self.dones[i_agent] = True
+
+            self.dones["__all__"] = True
 
     def step(self, action_dict_: Dict[int, RailEnvActions]):
         """
@@ -520,6 +536,8 @@ class RailEnv(Environment):
             i_agent = agent.handle
             agent_transition_data = temp_transition_data[i_agent]
 
+            old_position = agent.position
+
             ## Update positions
             if agent.malfunction_handler.in_malfunction:
                 movement_allowed = False
@@ -544,30 +562,18 @@ class RailEnv(Environment):
             have_all_agents_ended &= (agent.state == TrainState.DONE)
 
             ## Update rewards
-            # self.update_rewards(i_agent, agent, rail) # TODO : Step Rewards
+            self.update_step_rewards(i_agent)
 
             ## Update counters (malfunction and speed)
-            agent.speed_counter.update_counter(agent.state)
+            agent.speed_counter.update_counter(agent.state, old_position)
             agent.malfunction_handler.update_counter()
 
             # Clear old action when starting in new cell
             if agent.speed_counter.is_cell_entry:
                 agent.action_saver.clear_saved_action()
-
-
-        self.rewards_dict = {i_agent: 0 for i_agent in range(len(self.agents))}
         
-        if ((self._max_episode_steps is not None) and (self._elapsed_steps >= self._max_episode_steps)) \
-            or have_all_agents_ended :
-            
-            for i_agent, agent in enumerate(self.agents):
-                
-                reward = self._handle_end_reward(agent)
-                self.rewards_dict[i_agent] += reward
-                
-                self.dones[i_agent] = True
-
-            self.dones["__all__"] = True
+        # Check if episode has ended and update rewards and dones
+        self.end_of_episode_update(have_all_agents_ended)
 
         return self._get_observations(), self.rewards_dict, self.dones, self.get_info_dict() 
 
diff --git a/flatland/envs/step_utils/speed_counter.py b/flatland/envs/step_utils/speed_counter.py
index 5bde9c20f98b1b7ed26ad4a8ba3d5791786bd84f..272087817439a659298fa12f71aaa7c982b91bf5 100644
--- a/flatland/envs/step_utils/speed_counter.py
+++ b/flatland/envs/step_utils/speed_counter.py
@@ -4,12 +4,13 @@ from flatland.envs.step_utils.states import TrainState
 class SpeedCounter:
     def __init__(self, speed):
         self.speed = speed
-        self.max_count = int(1/speed)
+        self.max_count = int(1/speed) - 1
 
-    def update_counter(self, state):
-        if state == TrainState.MOVING:
+    def update_counter(self, state, old_position):
+        # When coming onto the map, do no update speed counter
+        if state == TrainState.MOVING and old_position is not None:
             self.counter += 1
-            self.counter = self.counter % self.max_count
+            self.counter = self.counter % (self.max_count + 1)
 
     def __repr__(self):
         return f"speed: {self.speed} \
@@ -27,5 +28,5 @@ class SpeedCounter:
 
     @property
     def is_cell_exit(self):
-        return self.counter == self.max_count - 1
+        return self.counter == self.max_count
 
diff --git a/tests/test_action_plan.py b/tests/test_action_plan.py
index 2b062c4e5a892322bcf8c86e3be66e433254b346..9be4fdf6410b6f63455c6df58da8121012778b85 100644
--- a/tests/test_action_plan.py
+++ b/tests/test_action_plan.py
@@ -9,6 +9,7 @@ from flatland.envs.rail_trainrun_data_structures import Waypoint
 from flatland.envs.line_generators import sparse_line_generator
 from flatland.utils.rendertools import RenderTool, AgentRenderVariant
 from flatland.utils.simple_rail import make_simple_rail
+from flatland.envs.step_utils.speed_counter import SpeedCounter
 
 
 def test_action_plan(rendering: bool = False):
@@ -29,7 +30,7 @@ def test_action_plan(rendering: bool = False):
     env.agents[1].initial_position = (3, 8)
     env.agents[1].initial_direction = Grid4TransitionsEnum.WEST
     env.agents[1].target = (0, 3)
-    env.agents[1].speed_data['speed'] = 0.5  # two
+    env.agents[1].speed_counter = SpeedCounter(speed=0.5)
     env.reset(False, False)
     for handle, agent in enumerate(env.agents):
         print("[{}] {} -> {}".format(handle, agent.initial_position, agent.target))