diff --git a/flatland/action_plan/action_plan.py b/flatland/action_plan/action_plan.py
index 249c4c0ee12bb8a79c06842a59108bd4f3ce6c5c..b1a56d81839bff62f13a27753a935a19a8d05fe9 100644
--- a/flatland/action_plan/action_plan.py
+++ b/flatland/action_plan/action_plan.py
@@ -150,7 +150,7 @@ class ControllerFromTrainruns():
     def _create_action_plan_for_agent(self, agent_id, trainrun) -> ActionPlan:
         action_plan = []
         agent = self.env.agents[agent_id]
-        minimum_cell_time = int(np.ceil(1.0 / agent.speed_data['speed']))
+        minimum_cell_time = agent.speed_counter.max_count
         for path_loop, trainrun_waypoint in enumerate(trainrun):
             trainrun_waypoint: TrainrunWaypoint = trainrun_waypoint
 
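Reviewer note: the hunk above replaces the inline `int(np.ceil(1.0 / speed))` with `agent.speed_counter.max_count`. For readers following the migration, here is a minimal sketch of the `SpeedCounter` interface the patch leans on. Only the names that actually appear in this diff (`speed`, `max_count`, `is_cell_entry`, and the `SpeedCounter(speed=...)` constructor) are taken from the patch; the counting logic is an illustrative assumption, written so that `max_count` matches the `ceil(1/speed)` expression it replaces.

```python
# Illustrative sketch only, not the real flatland class. Attribute names
# (speed, max_count, is_cell_entry) come from their usages in this patch;
# the counting logic is an assumption chosen to match the replaced expression.
import numpy as np


class SpeedCounter:
    def __init__(self, speed=1.0):
        self.speed = speed
        # Steps needed to traverse one cell, assumed equal to the
        # ceil(1/speed) this patch replaces (e.g. speed 0.25 -> 4 steps).
        self.max_count = int(np.ceil(1.0 / speed))
        self.counter = 0

    def update_counter(self):
        # Advance within the current cell, wrapping on cell exit.
        self.counter = (self.counter + 1) % self.max_count

    @property
    def is_cell_entry(self):
        # True exactly on the step where the agent enters a new cell; this is
        # what replaces the old position_fraction ~= 0.0 check in rail_env.py.
        return self.counter == 0
```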
diff --git a/flatland/action_plan/action_plan_player.py b/flatland/action_plan/action_plan_player.py
index f3deee133d8c99ffc5993005f1500e227be87b7e..074e5590185ff601f9c038e9df4c23fd2f84c455 100644
--- a/flatland/action_plan/action_plan_player.py
+++ b/flatland/action_plan/action_plan_player.py
@@ -30,6 +30,8 @@ class ControllerFromTrainrunsReplayer():
                 assert agent.position == waypoint.position, \
                     "before {}, agent {} at {}, expected {}".format(i, agent_id, agent.position,
                                                                     waypoint.position)
+                if agent_id == 1:
+                    print(env._elapsed_steps, agent.position, agent.state, agent.speed_counter)
             actions = ctl.act(i)
             print("actions for {}: {}".format(i, actions))
 
diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py
index 4dee6dde0f5a938d81e5cd970332223a9f6b841b..6dff63e18e505d6ff7cb8280b53f63178c3f1921 100644
--- a/flatland/envs/agent_utils.py
+++ b/flatland/envs/agent_utils.py
@@ -1,5 +1,6 @@
 from flatland.envs.rail_trainrun_data_structures import Waypoint
 import numpy as np
+import warnings
 
 from typing import Tuple, Optional, NamedTuple, List
 
@@ -21,7 +22,6 @@ Agent = NamedTuple('Agent', [('initial_position', Tuple[int, int]),
                              ('moving', bool),
                              ('earliest_departure', int),
                              ('latest_arrival', int),
-                             ('speed_data', dict),
                              ('malfunction_data', dict),
                              ('handle', int),
                              ('position', Tuple[int, int]),
@@ -49,13 +49,6 @@ class EnvAgent:
     earliest_departure = attrib(default=None, type=int)  # default None during _from_line()
     latest_arrival = attrib(default=None, type=int)  # default None during _from_line()
 
-    # speed_data: speed is added to position_fraction on each moving step, until position_fraction>=1.0,
-    # after which 'transition_action_on_cellexit' is executed (equivalent to executing that action in the previous
-    # cell if speed=1, as default)
-    # N.B. we need to use factory since default arguments are not recreated on each call!
-    speed_data = attrib(
-        default=Factory(lambda: dict({'position_fraction': 0.0, 'speed': 1.0, 'transition_action_on_cellexit': 0})))
-
     # if broken>0, the agent's actions are ignored for 'broken' steps
     # number of time the agent had to stop, since the last time it broke down
     malfunction_data = attrib(
@@ -67,7 +60,7 @@ class EnvAgent:
     # INIT TILL HERE IN _from_line()
 
     # Env step facelift
-    speed_counter = attrib(default = None, type=SpeedCounter)
+    speed_counter = attrib(default = Factory(lambda: SpeedCounter(1.0)), type=SpeedCounter)
     action_saver = attrib(default = Factory(lambda: ActionSaver()), type=ActionSaver)
     state_machine = attrib(default= Factory(lambda: TrainStateMachine(initial_state=TrainState.WAITING)) , 
                            type=TrainStateMachine)
@@ -94,10 +87,6 @@ class EnvAgent:
         self.old_direction = None
         self.moving = False
 
-        # Reset agent values for speed
-        self.speed_data['position_fraction'] = 0.
-        self.speed_data['transition_action_on_cellexit'] = 0.
-
         # Reset agent malfunction values
         self.malfunction_data['malfunction'] = 0
         self.malfunction_data['nr_malfunctions'] = 0
@@ -115,7 +104,6 @@ class EnvAgent:
                      moving=self.moving,
                      earliest_departure=self.earliest_departure, 
                      latest_arrival=self.latest_arrival, 
-                     speed_data=self.speed_data,
                      malfunction_data=self.malfunction_data, 
                      handle=self.handle, 
                      state=self.state,
@@ -137,7 +125,7 @@ class EnvAgent:
             distance = len(shortest_path)
         else:
             distance = 0
-        speed = self.speed_data['speed']
+        speed = self.speed_counter.speed
         return int(np.ceil(distance / speed))
 
     def get_time_remaining_until_latest_arrival(self, elapsed_steps: int) -> int:
@@ -161,11 +149,6 @@ class EnvAgent:
         agent_list = []
         for i_agent in range(num_agents):
             speed = line.agent_speeds[i_agent] if line.agent_speeds is not None else 1.0
-
-            speed_data = {'position_fraction': 0.0,
-                           'speed': speed,
-                           'transition_action_on_cellexit': 0
-                          }
             
             if line.agent_malfunction_rates is not None:
                 malfunction_rate = line.agent_malfunction_rates[i_agent]
@@ -177,7 +160,6 @@ class EnvAgent:
                                 'next_malfunction': 0,
                                 'nr_malfunctions': 0
                                }
-            
             agent = EnvAgent(initial_position = line.agent_positions[i_agent],
                             initial_direction = line.agent_directions[i_agent],
                             direction = line.agent_directions[i_agent],
@@ -185,7 +167,6 @@ class EnvAgent:
                             moving = False, 
                             earliest_departure = None,
                             latest_arrival = None,
-                            speed_data = speed_data,
                             malfunction_data = malfunction_data,
                             handle = i_agent,
                             speed_counter = SpeedCounter(speed=speed))
@@ -195,6 +176,7 @@ class EnvAgent:
 
     @classmethod
     def load_legacy_static_agent(cls, static_agents_data: Tuple):
+        raise NotImplementedError("Not implemented for Flatland 3")
         agents = []
         for i, static_agent in enumerate(static_agents_data):
             if len(static_agent) >= 6:
@@ -205,16 +187,31 @@
                 agent = EnvAgent(initial_position=static_agent[0], initial_direction=static_agent[1],
                                 direction=static_agent[1], target=static_agent[2], 
                                 moving=False,
-                                speed_data={"speed":1., "position_fraction":0., "transition_action_on_cell_exit":0.},
                                 malfunction_data={
                                             'malfunction': 0,
                                             'nr_malfunctions': 0,
                                             'moving_before_malfunction': False
                                         },
+                                speed_counter=SpeedCounter(1.0),
                                 handle=i)
             agents.append(agent)
         return agents
     
+    def _set_state(self, state):
+        warnings.warn("Setting the state directly is not recommended; use this only when strictly necessary")
+        self.state_machine.set_state(state)
+
+    def __str__(self):
+        return f"\n \
+                 handle(agent index): {self.handle} \n \
+                 initial_position: {self.initial_position}   initial_direction: {self.initial_direction} \n \
+                 position: {self.position}  direction: {self.direction}  target: {self.target} \n \
+                 earliest_departure: {self.earliest_departure}  latest_arrival: {self.latest_arrival} \n \
+                 state: {str(self.state)} \n \
+                 malfunction_data: {self.malfunction_data} \n \
+                 action_saver: {self.action_saver} \n \
+                 speed_counter: {self.speed_counter}"
+
     @property
     def state(self):
         return self.state_machine.state
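Reviewer note: with `speed_data` removed, `speed_counter` now defaults to `Factory(lambda: SpeedCounter(1.0))` instead of `None`, so each `EnvAgent` owns a fresh counter rather than sharing a mutable default (the same reason the deleted `speed_data` attrib used a `Factory`). A hypothetical usage sketch follows; the positions and targets are dummies, and the `SpeedCounter` import path is an assumption, not confirmed by this patch.

```python
# Hypothetical usage of the new defaults; dummy positions/targets, and the
# SpeedCounter import path is assumed rather than taken from this patch.
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.step_utils.speed_counter import SpeedCounter  # assumed path

fast = EnvAgent(initial_position=(0, 0), initial_direction=0,
                direction=0, target=(5, 5), handle=0)
slow = EnvAgent(initial_position=(1, 0), initial_direction=0,
                direction=0, target=(5, 5), handle=1,
                speed_counter=SpeedCounter(speed=0.25))

print(fast.speed_counter.speed)  # 1.0, from the per-instance Factory default
print(slow.speed_counter.speed)  # 0.25
assert fast.speed_counter is not slow.speed_counter  # no shared mutable default
```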
diff --git a/flatland/envs/line_generators.py b/flatland/envs/line_generators.py
index 74d01e6f23856e9f14d2fbe70eb2bdbfb85175be..8b412783ca999c5383e102804928888d43aee32a 100644
--- a/flatland/envs/line_generators.py
+++ b/flatland/envs/line_generators.py
@@ -189,7 +189,7 @@ def line_from_file(filename, load_from_package=None) -> LineGenerator:
         #agents_direction = [a.direction for a in agents]
         agents_direction = [a.initial_direction for a in agents]
         agents_target = [a.target for a in agents]
-        agents_speed = [a.speed_data['speed'] for a in agents]
+        agents_speed = [a.speed_counter.speed for a in agents]
 
         # Malfunctions from here are not used.  They have their own generator.
         #agents_malfunction = [a.malfunction_data['malfunction_rate'] for a in agents]
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 1fc0a2e52faf3b228b46b2fd896852ba4c411f26..456d56a0c58fdbaefa5a2ff4c5e938b74618e1c1 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -98,7 +98,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                 _agent.position:
                 self.location_has_agent[tuple(_agent.position)] = 1
                 self.location_has_agent_direction[tuple(_agent.position)] = _agent.direction
-                self.location_has_agent_speed[tuple(_agent.position)] = _agent.speed_data['speed']
+                self.location_has_agent_speed[tuple(_agent.position)] = _agent.speed_counter.speed
                 self.location_has_agent_malfunction[tuple(_agent.position)] = _agent.malfunction_data[
                     'malfunction']
 
@@ -221,7 +221,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                                                             agent.direction)],
                                                        num_agents_same_direction=0, num_agents_opposite_direction=0,
                                                        num_agents_malfunctioning=agent.malfunction_data['malfunction'],
-                                                       speed_min_fractional=agent.speed_data['speed'],
+                                                       speed_min_fractional=agent.speed_counter.speed,
                                                        num_agents_ready_to_depart=0,
                                                        childs={})
         #print("root node type:", type(root_node_observation))
@@ -275,7 +275,7 @@ class TreeObsForRailEnv(ObservationBuilder):
 
         visited = OrderedSet()
         agent = self.env.agents[handle]
-        time_per_cell = np.reciprocal(agent.speed_data["speed"])
+        time_per_cell = np.reciprocal(agent.speed_counter.speed)
         own_target_encountered = np.inf
         other_agent_encountered = np.inf
         other_target_encountered = np.inf
@@ -604,7 +604,7 @@ class GlobalObsForRailEnv(ObservationBuilder):
                 if i != handle:
                     obs_agents_state[other_agent.position][1] = other_agent.direction
                 obs_agents_state[other_agent.position][2] = other_agent.malfunction_data['malfunction']
-                obs_agents_state[other_agent.position][3] = other_agent.speed_data['speed']
+                obs_agents_state[other_agent.position][3] = other_agent.speed_counter.speed
             # fifth channel: all ready to depart on this position
             if other_agent.state.is_off_map_state():
                 obs_agents_state[other_agent.initial_position][4] += 1
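Reviewer note: the observation changes are mechanical (`agent.speed_data['speed']` becomes `agent.speed_counter.speed`), but third-party observation builders that still read the dict will now raise `AttributeError`. The sketch below is NOT part of this patch; it is a hypothetical transition shim one could apply while downstream code catches up.

```python
# Hypothetical migration shim, NOT in this patch. It re-exposes the removed
# speed_data dict (read-only, 'speed' key only) with a deprecation warning
# so legacy observation builders keep working during the transition.
import warnings

from flatland.envs.agent_utils import EnvAgent


def _legacy_speed_data(self):
    warnings.warn("speed_data was removed; read agent.speed_counter.speed",
                  DeprecationWarning, stacklevel=2)
    return {'speed': self.speed_counter.speed}


EnvAgent.speed_data = property(_legacy_speed_data)
```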
diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py
index 4b097083fa98c5121df644b8f9d34b27fdc34a4b..8f6a191a7eec5ba0dfb44b1f8671f9841b01ff5b 100644
--- a/flatland/envs/predictions.py
+++ b/flatland/envs/predictions.py
@@ -141,7 +141,7 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
                 continue
 
             agent_virtual_direction = agent.direction
-            agent_speed = agent.speed_data["speed"]
+            agent_speed = agent.speed_counter.speed
             times_per_cell = int(np.reciprocal(agent_speed))
             prediction = np.zeros(shape=(self.max_depth + 1, 5))
             prediction[0] = [0, *agent_virtual_position, agent_virtual_direction, 0]
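Reviewer note: the two hunks above round differently. `observations.py` keeps the raw reciprocal (`time_per_cell = np.reciprocal(speed)`) while `predictions.py` truncates it with `int(...)`. For the customary flatland speed fractions (1, 1/2, 1/3, 1/4) the reciprocal is an exact integer, so truncation and ceiling agree; they diverge only for non-standard speeds. A quick check with made-up speeds:

```python
# Sanity check with made-up speeds: truncated reciprocal vs ceiling.
import numpy as np

for s in (1.0, 0.5, 1.0 / 3.0, 0.25, 0.4):
    print(f"speed={s:.3f}  int(1/s)={int(np.reciprocal(s))}  ceil(1/s)={int(np.ceil(1.0 / s))}")
# The first four agree; speed 0.4 yields 2 vs 3, so the truncation in
# predictions.py would undercount the steps per cell for such a speed.
```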
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 46876ac953535f4c49b57036045b405c6b986cc3..2915e9be2b1d9537631f0639a6f20a9f05955d17 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -261,8 +261,7 @@ class RailEnv(Environment):
         False: Agent cannot provide an action
         """
         return agent.state == TrainState.READY_TO_DEPART or \
-               (agent.state.is_on_map_state() and \
-                fast_isclose(agent.speed_data['position_fraction'], 0.0, rtol=1e-03) )
+               (agent.state.is_on_map_state() and agent.speed_counter.is_cell_entry)
 
     def reset(self, regenerate_rail: bool = True, regenerate_schedule: bool = True, *,
               random_seed: bool = None) -> Tuple[Dict, Dict]:
@@ -344,19 +343,6 @@ class RailEnv(Environment):
         # Reset agents to initial states
         self.reset_agents()
 
-        # for agent in self.agents:
-        #     # Induce malfunctions
-        #     if activate_agents:
-        #         self.set_agent_active(agent)
-
-        #     self._break_agent(agent)
-
-        #     if agent.malfunction_data["malfunction"] > 0:
-        #         agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.DO_NOTHING
-
-        #     # Fix agents that finished their malfunction
-        #     self._fix_agent_after_malfunction(agent)
-
         self.num_resets += 1
         self._elapsed_steps = 0
 
@@ -369,14 +355,7 @@ class RailEnv(Environment):
         # Empty the episode store of agent positions
         self.cur_episode = []
 
-        info_dict: Dict = {
-            'action_required': {i: self.action_required(agent) for i, agent in enumerate(self.agents)},
-            'malfunction': {
-                i: agent.malfunction_data['malfunction'] for i, agent in enumerate(self.agents)
-            },
-            'speed': {i: agent.speed_data['speed'] for i, agent in enumerate(self.agents)},
-            'state': {i: agent.state for i, agent in enumerate(self.agents)}
-        }
+        info_dict = self.get_info_dict()
         # Return the new observation vectors for each agent
         observation_dict: Dict = self._get_observations()
         return observation_dict, info_dict
@@ -469,10 +448,12 @@ class RailEnv(Environment):
 
     def get_info_dict(self): # TODO Important : Update this
         info_dict = {
-            "action_required": {},
-            "malfunction": {},
-            "speed": {},
-            "status": {},
+            'action_required': {i: self.action_required(agent) for i, agent in enumerate(self.agents)},
+            'malfunction': {
+                i: agent.malfunction_data['malfunction'] for i, agent in enumerate(self.agents)
+            },
+            'speed': {i: agent.speed_counter.speed for i, agent in enumerate(self.agents)},
+            'state': {i: agent.state for i, agent in enumerate(self.agents)}
         }
         return info_dict
 
diff --git a/flatland/envs/timetable_generators.py b/flatland/envs/timetable_generators.py
index b7876d742f61db830883f828faaf99a39a48bc65..d93c09199b315c488177febe4d1aa423b7a87894 100644
--- a/flatland/envs/timetable_generators.py
+++ b/flatland/envs/timetable_generators.py
@@ -57,7 +57,7 @@ def timetable_generator(agents: List[EnvAgent], distance_map: DistanceMap,
     shortest_paths_lengths = [len_handle_none(v) for k,v in shortest_paths.items()]
 
     # Find mean_shortest_path_time
-    agent_speeds = [agent.speed_data['speed'] for agent in agents]
+    agent_speeds = [agent.speed_counter.speed for agent in agents]
     agent_shortest_path_times = np.array(shortest_paths_lengths)/ np.array(agent_speeds)
     mean_shortest_path_time = np.mean(agent_shortest_path_times)
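Reviewer note: the timetable hunk is the same mechanical substitution and the surrounding arithmetic is untouched. A worked instance with made-up numbers: path lengths [10, 20] at speeds [1.0, 0.5] give per-agent shortest-path times [10, 40], hence a mean of 25.

```python
# Worked instance of the computation in this hunk, using made-up values.
import numpy as np

shortest_paths_lengths = [10, 20]
agent_speeds = [1.0, 0.5]  # would come from agent.speed_counter.speed
agent_shortest_path_times = np.array(shortest_paths_lengths) / np.array(agent_speeds)

print(agent_shortest_path_times)           # [10. 40.]
print(np.mean(agent_shortest_path_times))  # 25.0
```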