diff --git a/tests/test_eval_timeout.py b/tests/test_eval_timeout.py
index dfc406e3b9d091fc8e9a477ea86fae025e7b1936..6c92db298b3c87ca8597ab113b56ab1c8f208cde 100644
--- a/tests/test_eval_timeout.py
+++ b/tests/test_eval_timeout.py
@@ -8,8 +8,6 @@ import time
 
 from flatland.core.env import Environment
 from flatland.core.env_observation_builder import ObservationBuilder
-from flatland.core.env_prediction_builder import PredictionBuilder
-from flatland.envs.agent_utils import RailAgentStatus, EnvAgent
 
 
 class CustomObservationBuilder(ObservationBuilder):
diff --git a/tests/test_flaltland_rail_agent_status.py b/tests/test_flaltland_rail_agent_status.py
index e3f1ced759fd755db58749cf0215a121a7b13026..82a2089f17cf1d25eb8bb28bd58e6918537035d2 100644
--- a/tests/test_flaltland_rail_agent_status.py
+++ b/tests/test_flaltland_rail_agent_status.py
@@ -1,5 +1,4 @@
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
-from flatland.envs.agent_utils import RailAgentStatus
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv, RailEnvActions
@@ -7,7 +6,7 @@ from flatland.envs.rail_generators import rail_from_grid_transition_map
 from flatland.envs.line_generators import sparse_line_generator
 from flatland.utils.simple_rail import make_simple_rail
 from test_utils import ReplayConfig, Replay, run_replay_config, set_penalties_for_replay
-
+from flatland.envs.step_utils.states import TrainState
 
 def test_initial_status():
     """Test that agent lifecycle works correctly ready-to-depart -> active -> done."""
@@ -30,7 +29,7 @@ def test_initial_status():
             Replay(
                 position=None,  # not entered grid yet
                 direction=Grid4TransitionsEnum.EAST,
-                status=RailAgentStatus.READY_TO_DEPART,
+                state=TrainState.READY_TO_DEPART,
                 action=RailEnvActions.DO_NOTHING,
                 reward=env.step_penalty * 0.5,
 
@@ -38,35 +37,35 @@ def test_initial_status():
             Replay(
                 position=None,  # not entered grid yet before step
                 direction=Grid4TransitionsEnum.EAST,
-                status=RailAgentStatus.READY_TO_DEPART,
+                state=TrainState.READY_TO_DEPART,
                 action=RailEnvActions.MOVE_LEFT,
                 reward=env.step_penalty * 0.5,  # auto-correction left to forward without penalty!
             ),
             Replay(
                 position=(3, 9),
                 direction=Grid4TransitionsEnum.EAST,
-                status=RailAgentStatus.ACTIVE,
+                state=TrainState.MOVING,
                 action=RailEnvActions.MOVE_LEFT,
                 reward=env.start_penalty + env.step_penalty * 0.5,  # running at speed 0.5
             ),
             Replay(
                 position=(3, 9),
                 direction=Grid4TransitionsEnum.EAST,
-                status=RailAgentStatus.ACTIVE,
+                state=TrainState.MOVING,
                 action=None,
                 reward=env.step_penalty * 0.5,  # running at speed 0.5
             ),
             Replay(
                 position=(3, 8),
                 direction=Grid4TransitionsEnum.WEST,
-                status=RailAgentStatus.ACTIVE,
+                state=TrainState.MOVING,
                 action=RailEnvActions.MOVE_FORWARD,
                 reward=env.step_penalty * 0.5,  # running at speed 0.5
             ),
             Replay(
                 position=(3, 8),
                 direction=Grid4TransitionsEnum.WEST,
-                status=RailAgentStatus.ACTIVE,
+                state=TrainState.MOVING,
                 action=None,
                 reward=env.step_penalty * 0.5,  # running at speed 0.5
 
@@ -76,28 +75,28 @@ def test_initial_status():
                 direction=Grid4TransitionsEnum.WEST,
                 action=RailEnvActions.MOVE_FORWARD,
                 reward=env.step_penalty * 0.5,  # running at speed 0.5
-                status=RailAgentStatus.ACTIVE
+                state=TrainState.MOVING
             ),
             Replay(
                 position=(3, 7),
                 direction=Grid4TransitionsEnum.WEST,
                 action=None,
                 reward=env.step_penalty * 0.5,  # wrong action is corrected to forward without penalty!
-                status=RailAgentStatus.ACTIVE
+                state=TrainState.MOVING
             ),
             Replay(
                 position=(3, 6),
                 direction=Grid4TransitionsEnum.WEST,
                 action=RailEnvActions.MOVE_RIGHT,
                 reward=env.step_penalty * 0.5,  #
-                status=RailAgentStatus.ACTIVE
+                state=TrainState.MOVING
             ),
             Replay(
                 position=(3, 6),
                 direction=Grid4TransitionsEnum.WEST,
                 action=None,
                 reward=env.global_reward,  #
-                status=RailAgentStatus.ACTIVE
+                state=TrainState.MOVING
             ),
             # Replay(
             #     position=(3, 5),
@@ -122,7 +121,7 @@ def test_initial_status():
     )
 
     run_replay_config(env, [test_config], activate_agents=False, skip_reward_check=True)
-    assert env.agents[0].status == RailAgentStatus.DONE
+    assert env.agents[0].state == TrainState.DONE
 
 
 def test_status_done_remove():
@@ -146,7 +145,7 @@ def test_status_done_remove():
             Replay(
                 position=None,  # not entered grid yet
                 direction=Grid4TransitionsEnum.EAST,
-                status=RailAgentStatus.READY_TO_DEPART,
+                state=TrainState.READY_TO_DEPART,
                 action=RailEnvActions.DO_NOTHING,
                 reward=env.step_penalty * 0.5,
 
@@ -154,35 +153,35 @@ def test_status_done_remove():
             Replay(
                 position=None,  # not entered grid yet before step
                 direction=Grid4TransitionsEnum.EAST,
-                status=RailAgentStatus.READY_TO_DEPART,
+                state=TrainState.READY_TO_DEPART,
                 action=RailEnvActions.MOVE_LEFT,
                 reward=env.step_penalty * 0.5,  # auto-correction left to forward without penalty!
             ),
             Replay(
                 position=(3, 9),
                 direction=Grid4TransitionsEnum.EAST,
-                status=RailAgentStatus.ACTIVE,
+                state=TrainState.MOVING,
                 action=RailEnvActions.MOVE_FORWARD,
                 reward=env.start_penalty + env.step_penalty * 0.5,  # running at speed 0.5
             ),
             Replay(
                 position=(3, 9),
                 direction=Grid4TransitionsEnum.EAST,
-                status=RailAgentStatus.ACTIVE,
+                state=TrainState.MOVING,
                 action=None,
                 reward=env.step_penalty * 0.5,  # running at speed 0.5
             ),
             Replay(
                 position=(3, 8),
                 direction=Grid4TransitionsEnum.WEST,
-                status=RailAgentStatus.ACTIVE,
+                state=TrainState.MOVING,
                 action=RailEnvActions.MOVE_FORWARD,
                 reward=env.step_penalty * 0.5,  # running at speed 0.5
             ),
             Replay(
                 position=(3, 8),
                 direction=Grid4TransitionsEnum.WEST,
-                status=RailAgentStatus.ACTIVE,
+                state=TrainState.MOVING,
                 action=None,
                 reward=env.step_penalty * 0.5,  # running at speed 0.5
 
@@ -192,28 +191,28 @@ def test_status_done_remove():
                 direction=Grid4TransitionsEnum.WEST,
                 action=RailEnvActions.MOVE_RIGHT,
                 reward=env.step_penalty * 0.5,  # running at speed 0.5
-                status=RailAgentStatus.ACTIVE
+                state=TrainState.MOVING
             ),
             Replay(
                 position=(3, 7),
                 direction=Grid4TransitionsEnum.WEST,
                 action=None,
                 reward=env.step_penalty * 0.5,  # wrong action is corrected to forward without penalty!
-                status=RailAgentStatus.ACTIVE
+                state=TrainState.MOVING
             ),
             Replay(
                 position=(3, 6),
                 direction=Grid4TransitionsEnum.WEST,
                 action=RailEnvActions.MOVE_FORWARD,
                 reward=env.step_penalty * 0.5,  # done
-                status=RailAgentStatus.ACTIVE
+                state=TrainState.MOVING
             ),
             Replay(
                 position=(3, 6),
                 direction=Grid4TransitionsEnum.WEST,
                 action=None,
                 reward=env.global_reward,  # already done
-                status=RailAgentStatus.ACTIVE
+                state=TrainState.MOVING
             ),
             # Replay(
             #     position=None,
@@ -238,4 +237,5 @@ def test_status_done_remove():
     )
 
     run_replay_config(env, [test_config], activate_agents=False, skip_reward_check=True)
-    assert env.agents[0].status == RailAgentStatus.DONE_REMOVED
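+    # TrainState has a single terminal state, DONE; the old DONE_REMOVED status has no counterpart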
+    assert env.agents[0].state == TrainState.DONE
diff --git a/tests/test_flatland_envs_observations.py b/tests/test_flatland_envs_observations.py
index 2658813a95d20dac683c94a1fc827fd74eadbdfb..aee47c4009ded6cd4da38a33970a1cf51e08f5b8 100644
--- a/tests/test_flatland_envs_observations.py
+++ b/tests/test_flatland_envs_observations.py
@@ -5,7 +5,6 @@ import numpy as np
 
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
 from flatland.core.grid.grid4_utils import get_new_position
-from flatland.envs.agent_utils import EnvAgent, RailAgentStatus
 from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv, RailEnvActions
@@ -13,6 +12,7 @@ from flatland.envs.rail_generators import rail_from_grid_transition_map
 from flatland.envs.line_generators import sparse_line_generator
 from flatland.utils.rendertools import RenderTool
 from flatland.utils.simple_rail import make_simple_rail
+from flatland.envs.step_utils.states import TrainState
 
 """Tests for `flatland` package."""
 
@@ -106,7 +106,7 @@ def test_reward_function_conflict(rendering=False):
     agent.initial_direction = 0  # north
     agent.target = (3, 9)  # east dead-end
     agent.moving = True
-    agent.status = RailAgentStatus.ACTIVE
+    agent._set_state(TrainState.MOVING)
 
     agent = env.agents[1]
     agent.position = (3, 8)  # east dead-end
@@ -115,13 +115,13 @@ def test_reward_function_conflict(rendering=False):
     agent.initial_direction = 3  # west
     agent.target = (6, 6)  # south dead-end
     agent.moving = True
-    agent.status = RailAgentStatus.ACTIVE
+    agent._set_state(TrainState.MOVING)
 
     env.reset(False, False)
     env.agents[0].moving = True
     env.agents[1].moving = True
-    env.agents[0].status = RailAgentStatus.ACTIVE
-    env.agents[1].status = RailAgentStatus.ACTIVE
+    env.agents[0]._set_state(TrainState.MOVING)
+    env.agents[1]._set_state(TrainState.MOVING)
     env.agents[0].position = (5, 6)
     env.agents[1].position = (3, 8)
     print("\n")
@@ -195,7 +195,7 @@ def test_reward_function_waiting(rendering=False):
     agent.initial_direction = 3  # west
     agent.target = (3, 1)  # west dead-end
     agent.moving = True
-    agent.status = RailAgentStatus.ACTIVE
+    agent._set_state(TrainState.MOVING)
 
     agent = env.agents[1]
     agent.initial_position = (5, 6)  # south dead-end
@@ -204,13 +204,14 @@ def test_reward_function_waiting(rendering=False):
     agent.initial_direction = 0  # north
     agent.target = (3, 8)  # east dead-end
     agent.moving = True
-    agent.status = RailAgentStatus.ACTIVE
+    agent._set_state(TrainState.MOVING)
 
     env.reset(False, False)
     env.agents[0].moving = True
     env.agents[1].moving = True
-    env.agents[0].status = RailAgentStatus.ACTIVE
-    env.agents[1].status = RailAgentStatus.ACTIVE
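+    # env.reset() re-initialises the agents, so the desired states and positions are re-applied below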
+    env.agents[0]._set_state(TrainState.MOVING)
+    env.agents[1]._set_state(TrainState.MOVING)
     env.agents[0].position = (3, 8)
     env.agents[1].position = (5, 6)
 
diff --git a/tests/test_flatland_envs_predictions.py b/tests/test_flatland_envs_predictions.py
index ad2187be4bad2df2b7a85438079aa7d1f2bb8a0e..399ec957c155715e30e2868f5bcc51a0c275bee3 100644
--- a/tests/test_flatland_envs_predictions.py
+++ b/tests/test_flatland_envs_predictions.py
@@ -5,7 +5,6 @@ import pprint
 import numpy as np
 
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
-from flatland.envs.agent_utils import RailAgentStatus
 from flatland.envs.observations import TreeObsForRailEnv, Node
 from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
@@ -16,6 +15,7 @@ from flatland.envs.line_generators import sparse_line_generator
 from flatland.utils.rendertools import RenderTool
 from flatland.utils.simple_rail import make_simple_rail, make_simple_rail2, make_invalid_simple_rail
 from flatland.envs.rail_env_action import RailEnvActions
+from flatland.envs.step_utils.states import TrainState
 
 """Test predictions for `flatland` package."""
 
@@ -135,7 +135,7 @@ def test_shortest_path_predictor(rendering=False):
     agent.initial_direction = 0  # north
     agent.target = (3, 9)  # east dead-end
     agent.moving = True
-    agent.status = RailAgentStatus.ACTIVE
+    agent._set_state(TrainState.MOVING)
 
     env.reset(False, False)
     env.distance_map._compute(env.agents, env.rail)
@@ -269,7 +269,7 @@ def test_shortest_path_predictor_conflicts(rendering=False):
     env.agents[0].initial_direction = 0  # north
     env.agents[0].target = (3, 9)  # east dead-end
     env.agents[0].moving = True
-    env.agents[0].status = RailAgentStatus.ACTIVE
+    env.agents[0]._set_state(TrainState.MOVING)
 
     env.agents[1].initial_position = (3, 8)  # east dead-end
     env.agents[1].position = (3, 8)  # east dead-end
@@ -277,7 +277,7 @@ def test_shortest_path_predictor_conflicts(rendering=False):
     env.agents[1].initial_direction = 3  # west
     env.agents[1].target = (6, 6)  # south dead-end
     env.agents[1].moving = True
-    env.agents[1].status = RailAgentStatus.ACTIVE
+    env.agents[1]._set_state(TrainState.MOVING)
 
     observations, info = env.reset(False, False)
 
@@ -285,8 +285,9 @@
     env.agent_positions[env.agents[0].position] = 0
     env.agents[1].position = (3, 8)  # east dead-end
     env.agent_positions[env.agents[1].position] = 1
-    env.agents[0].status = RailAgentStatus.ACTIVE
-    env.agents[1].status = RailAgentStatus.ACTIVE
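+    # _set_state writes the agent's state directly, bypassing the usual transition checks (test-only setup)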
+    env.agents[0]._set_state(TrainState.MOVING)
+    env.agents[1]._set_state(TrainState.MOVING)
 
     observations = env._get_observations()
 
diff --git a/tests/test_flatland_envs_sparse_rail_generator.py b/tests/test_flatland_envs_sparse_rail_generator.py
index 5c12336a1a612cccd3df8beab42a8dcdfe9cdb59..358839f9e7b7368ead6861a00efdb4f36c9e090c 100644
--- a/tests/test_flatland_envs_sparse_rail_generator.py
+++ b/tests/test_flatland_envs_sparse_rail_generator.py
@@ -1315,8 +1315,8 @@ def test_rail_env_action_required_info():
             if step == 0 or info_only_if_action_required['action_required'][a]:
                 action_dict_only_if_action_required.update({a: action})
             else:
-                print("[{}] not action_required {}, speed_data={}".format(step, a,
-                                                                          env_always_action.agents[a].speed_data))
+                print("[{}] not action_required {}, speed_counter={}".format(step, a,
+                                                                          env_always_action.agents[a].speed_counter))
 
         obs_always_action, rewards_always_action, done_always_action, info_always_action = env_always_action.step(
             action_dict_always_action)
@@ -1375,7 +1375,7 @@ def test_rail_env_malfunction_speed_info():
         for a in range(env.get_num_agents()):
             assert info['malfunction'][a] >= 0
             assert info['speed'][a] >= 0 and info['speed'][a] <= 1
-            assert info['speed'][a] == env.agents[a].speed_data['speed']
+            assert info['speed'][a] == env.agents[a].speed_counter.speed
 
         env_renderer.render_env(show=True, show_observations=False, show_predictions=False)
 
diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py
index e32e8d9f21120d7566cc027d7f9fa6cb36ded7be..d633351ed3624499aa2e30df9f09031b0b4cf581 100644
--- a/tests/test_flatland_malfunction.py
+++ b/tests/test_flatland_malfunction.py
@@ -6,14 +6,14 @@ import numpy as np
 from flatland.core.env_observation_builder import ObservationBuilder
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
 from flatland.core.grid.grid4_utils import get_new_position
-from flatland.envs.agent_utils import RailAgentStatus
 from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters
 from flatland.envs.rail_env import RailEnv, RailEnvActions
 from flatland.envs.rail_generators import rail_from_grid_transition_map
 from flatland.envs.line_generators import sparse_line_generator
 from flatland.utils.simple_rail import make_simple_rail2
 from test_utils import Replay, ReplayConfig, run_replay_config, set_penalties_for_replay
-
+from flatland.envs.step_utils.states import TrainState
+from flatland.envs.step_utils.speed_counter import SpeedCounter
 
 class SingleAgentNavigationObs(ObservationBuilder):
     """
@@ -32,11 +32,11 @@ class SingleAgentNavigationObs(ObservationBuilder):
     def get(self, handle: int = 0) -> List[int]:
         agent = self.env.agents[handle]
 
-        if agent.status == RailAgentStatus.READY_TO_DEPART:
+        if agent.state.is_off_map_state():
             agent_virtual_position = agent.initial_position
-        elif agent.status == RailAgentStatus.ACTIVE:
+        elif agent.state.is_on_map_state():
             agent_virtual_position = agent.position
-        elif agent.status == RailAgentStatus.DONE:
+        elif agent.state == TrainState.DONE:
             agent_virtual_position = agent.target
         else:
             return None
@@ -85,7 +85,7 @@ def test_malfunction_process():
     obs, info = env.reset(False, False, random_seed=10)
     for a_idx in range(len(env.agents)):
         env.agents[a_idx].position =  env.agents[a_idx].initial_position
-        env.agents[a_idx].status = RailAgentStatus.ACTIVE
+        env.agents[a_idx]._set_state(TrainState.MOVING)
 
     agent_halts = 0
     total_down_time = 0
@@ -297,7 +297,7 @@ def test_initial_malfunction():
                 reward=env.step_penalty  # running at speed 1.0
             )
         ],
-        speed=env.agents[0].speed_data['speed'],
+        speed=env.agents[0].speed_counter.speed,
         target=env.agents[0].target,
         initial_position=(3, 2),
         initial_direction=Grid4TransitionsEnum.EAST,
@@ -315,7 +315,7 @@ def test_initial_malfunction_stop_moving():
     
     env._max_episode_steps = 1000
 
-    print(env.agents[0].initial_position, env.agents[0].direction, env.agents[0].position, env.agents[0].status)
+    print(env.agents[0].initial_position, env.agents[0].direction, env.agents[0].position, env.agents[0].state)
 
     set_penalties_for_replay(env)
     replay_config = ReplayConfig(
@@ -327,7 +327,7 @@ def test_initial_malfunction_stop_moving():
                 set_malfunction=3,
                 malfunction=3,
                 reward=env.step_penalty,  # full step penalty when stopped
-                status=RailAgentStatus.READY_TO_DEPART
+                state=TrainState.READY_TO_DEPART
             ),
             Replay(
                 position=(3, 2),
@@ -335,7 +335,7 @@ def test_initial_malfunction_stop_moving():
                 action=RailEnvActions.DO_NOTHING,
                 malfunction=2,
                 reward=env.step_penalty,  # full step penalty when stopped
-                status=RailAgentStatus.ACTIVE
+                state=TrainState.MALFUNCTION
             ),
             # malfunction stops in the next step and we're still at the beginning of the cell
             # --> if we take action STOP_MOVING, agent should restart without moving
@@ -346,7 +346,7 @@ def test_initial_malfunction_stop_moving():
                 action=RailEnvActions.STOP_MOVING,
                 malfunction=1,
                 reward=env.step_penalty,  # full step penalty while stopped
-                status=RailAgentStatus.ACTIVE
+                state=TrainState.MALFUNCTION
             ),
             # we have stopped and do nothing --> should stand still
             Replay(
@@ -355,7 +355,7 @@ def test_initial_malfunction_stop_moving():
                 action=RailEnvActions.DO_NOTHING,
                 malfunction=0,
                 reward=env.step_penalty,  # full step penalty while stopped
-                status=RailAgentStatus.ACTIVE
+                state=TrainState.STOPPED
             ),
             # we start to move forward --> should go to next cell now
             Replay(
@@ -364,7 +364,7 @@ def test_initial_malfunction_stop_moving():
                 action=RailEnvActions.MOVE_FORWARD,
                 malfunction=0,
                 reward=env.start_penalty + env.step_penalty * 1.0,  # full step penalty while stopped
-                status=RailAgentStatus.ACTIVE
+                state=TrainState.STOPPED
             ),
             Replay(
                 position=(3, 3),
@@ -372,10 +372,10 @@ def test_initial_malfunction_stop_moving():
                 action=RailEnvActions.MOVE_FORWARD,
                 malfunction=0,
                 reward=env.step_penalty * 1.0,  # full step penalty while stopped
-                status=RailAgentStatus.ACTIVE
+                state=TrainState.MOVING
             )
         ],
-        speed=env.agents[0].speed_data['speed'],
+        speed=env.agents[0].speed_counter.speed,
         target=env.agents[0].target,
         initial_position=(3, 2),
         initial_direction=Grid4TransitionsEnum.EAST,
@@ -412,7 +412,7 @@ def test_initial_malfunction_do_nothing():
                 set_malfunction=3,
                 malfunction=3,
                 reward=env.step_penalty,  # full step penalty while malfunctioning
-                status=RailAgentStatus.READY_TO_DEPART
+                state=TrainState.READY_TO_DEPART
             ),
             Replay(
                 position=(3, 2),
@@ -420,7 +420,7 @@ def test_initial_malfunction_do_nothing():
                 action=RailEnvActions.DO_NOTHING,
                 malfunction=2,
                 reward=env.step_penalty,  # full step penalty while malfunctioning
-                status=RailAgentStatus.ACTIVE
+                state=TrainState.MALFUNCTION
             ),
             # malfunction stops in the next step and we're still at the beginning of the cell
             # --> if we take action DO_NOTHING, agent should restart without moving
@@ -431,7 +431,7 @@ def test_initial_malfunction_do_nothing():
                 action=RailEnvActions.DO_NOTHING,
                 malfunction=1,
                 reward=env.step_penalty,  # full step penalty while stopped
-                status=RailAgentStatus.ACTIVE
+                state=TrainState.MALFUNCTION
             ),
             # we haven't started moving yet --> stay here
             Replay(
@@ -440,7 +440,7 @@ def test_initial_malfunction_do_nothing():
                 action=RailEnvActions.DO_NOTHING,
                 malfunction=0,
                 reward=env.step_penalty,  # full step penalty while stopped
-                status=RailAgentStatus.ACTIVE
+                state=TrainState.STOPPED
             ),
 
             Replay(
@@ -449,7 +449,7 @@ def test_initial_malfunction_do_nothing():
                 action=RailEnvActions.MOVE_FORWARD,
                 malfunction=0,
                 reward=env.start_penalty + env.step_penalty * 1.0,  # start penalty + step penalty for speed 1.0
-                status=RailAgentStatus.ACTIVE
+                state=TrainState.STOPPED
             ),  # we start to move forward --> should go to next cell now
             Replay(
                 position=(3, 3),
@@ -457,10 +457,10 @@ def test_initial_malfunction_do_nothing():
                 action=RailEnvActions.MOVE_FORWARD,
                 malfunction=0,
                 reward=env.step_penalty * 1.0,  # step penalty for speed 1.0
-                status=RailAgentStatus.ACTIVE
+                state=TrainState.MOVING
             )
         ],
-        speed=env.agents[0].speed_data['speed'],
+        speed=env.agents[0].speed_counter.speed,
         target=env.agents[0].target,
         initial_position=(3, 2),
         initial_direction=Grid4TransitionsEnum.EAST,
@@ -475,7 +475,7 @@ def tests_random_interference_from_outside():
     env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail, optionals),
                   line_generator=sparse_line_generator(seed=2), number_of_agents=1, random_seed=1)
     env.reset()
-    env.agents[0].speed_data['speed'] = 0.33
+    env.agents[0].speed_counter = SpeedCounter(speed=0.33)
     env.reset(False, False, random_seed=10)
     env_data = []
 
@@ -501,7 +501,7 @@ def tests_random_interference_from_outside():
     env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail, optionals),
                   line_generator=sparse_line_generator(seed=2), number_of_agents=1, random_seed=1)
     env.reset()
-    env.agents[0].speed_data['speed'] = 0.33
+    env.agents[0].speed_counter = SpeedCounter(speed=0.33)
     env.reset(False, False, random_seed=10)
 
     dummy_list = [1, 2, 6, 7, 8, 9, 4, 5, 4]
@@ -536,7 +536,7 @@ def test_last_malfunction_step():
     env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail, optionals),
                   line_generator=sparse_line_generator(seed=2), number_of_agents=1, random_seed=1)
     env.reset()
-    env.agents[0].speed_data['speed'] = 1. / 3.
+    env.agents[0].speed_counter = SpeedCounter(speed=1./3.)
     env.agents[0].initial_position = (6, 6)
     env.agents[0].initial_direction = 2
     env.agents[0].target = (0, 3)
@@ -546,7 +546,7 @@ def test_last_malfunction_step():
     env.reset(False, False)
     for a_idx in range(len(env.agents)):
         env.agents[a_idx].position =  env.agents[a_idx].initial_position
-        env.agents[a_idx].status = RailAgentStatus.ACTIVE
+        env.agents[a_idx]._set_state(TrainState.MOVING)
     # Force malfunction to be off at beginning and next malfunction to happen in 2 steps
     env.agents[0].malfunction_data['next_malfunction'] = 2
     env.agents[0].malfunction_data['malfunction'] = 0
@@ -565,13 +565,14 @@ def test_last_malfunction_step():
         if env.agents[0].malfunction_data['malfunction'] < 1:
             agent_can_move = True
         # Store the position before and after the step
-        pre_position = env.agents[0].speed_data['position_fraction']
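+        # speed_counter.counter replaces speed_data['position_fraction'] as the intra-cell progress counter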
+        pre_position = env.agents[0].speed_counter.counter
         _, reward, _, _ = env.step(action_dict)
         # Check if the agent is still allowed to move in this step
 
         if env.agents[0].malfunction_data['malfunction'] > 0:
             agent_can_move = False
-        post_position = env.agents[0].speed_data['position_fraction']
+        post_position = env.agents[0].speed_counter.counter
         # Assert that the agent moved while it was still allowed
         if agent_can_move:
             assert pre_position != post_position
diff --git a/tests/test_generators.py b/tests/test_generators.py
index 67f883746f2767bc98a090285428d7d377c905a1..7d91bce89bd2d840f433de9f895b29e5a822cf3d 100644
--- a/tests/test_generators.py
+++ b/tests/test_generators.py
@@ -10,7 +10,7 @@ from flatland.envs.rail_generators import rail_from_grid_transition_map, rail_fr
 from flatland.envs.line_generators import sparse_line_generator, line_from_file
 from flatland.utils.simple_rail import make_simple_rail
 from flatland.envs.persistence import RailEnvPersister
-from flatland.envs.agent_utils import RailAgentStatus
+from flatland.envs.step_utils.states import TrainState
 
 
 def test_empty_rail_generator():
@@ -35,7 +35,8 @@ def test_rail_from_grid_transition_map():
 
     for a_idx in range(len(env.agents)):
         env.agents[a_idx].position =  env.agents[a_idx].initial_position
-        env.agents[a_idx].status = RailAgentStatus.ACTIVE
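+        # put the agent on its initial position and force it into the MOVING state for the test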
+        env.agents[a_idx]._set_state(TrainState.MOVING)
 
     nr_rail_elements = np.count_nonzero(env.rail.grid)
 
diff --git a/tests/test_global_observation.py b/tests/test_global_observation.py
index 851d849d1246773d7d06b5f38ed0eef820f74a56..1ea959a251e9dd672db4a71a11e3bd76bfced433 100644
--- a/tests/test_global_observation.py
+++ b/tests/test_global_observation.py
@@ -1,10 +1,11 @@
 import numpy as np
 
-from flatland.envs.agent_utils import EnvAgent, RailAgentStatus
+from flatland.envs.agent_utils import EnvAgent
 from flatland.envs.observations import GlobalObsForRailEnv
 from flatland.envs.rail_env import RailEnv, RailEnvActions
 from flatland.envs.rail_generators import sparse_rail_generator
 from flatland.envs.line_generators import sparse_line_generator
+from flatland.envs.step_utils.states import TrainState
 
 
 def test_get_global_observation():
@@ -37,7 +38,7 @@ def test_get_global_observation():
     obs, all_rewards, done, _ = env.step({i: RailEnvActions.MOVE_FORWARD for i in range(number_of_agents)})
     for i in range(len(env.agents)):
         agent: EnvAgent = env.agents[i]
-        print("[{}] status={}, position={}, target={}, initial_position={}".format(i, agent.status, agent.position,
+        print("[{}] state={}, position={}, target={}, initial_position={}".format(i, agent.state, agent.position,
                                                                                    agent.target,
                                                                                    agent.initial_position))
 
@@ -65,19 +66,19 @@ def test_get_global_observation():
         # test first channel of obs_agents_state: direction at own position
         for r in range(env.height):
             for c in range(env.width):
-                if (agent.status == RailAgentStatus.ACTIVE or agent.status == RailAgentStatus.DONE) and (
+                if (agent.state.is_on_map_state() or agent.state == TrainState.DONE) and (
                     r, c) == agent.position:
                     assert np.isclose(obs_agents_state[(r, c)][0], agent.direction), \
-                        "agent {} in status {} at {} expected to contain own direction {}, found {}" \
-                            .format(i, agent.status, (r, c), agent.direction, obs_agents_state[(r, c)][0])
-                elif (agent.status == RailAgentStatus.READY_TO_DEPART) and (r, c) == agent.initial_position:
+                        "agent {} in state {} at {} expected to contain own direction {}, found {}" \
+                            .format(i, agent.state, (r, c), agent.direction, obs_agents_state[(r, c)][0])
+                elif (agent.state == TrainState.READY_TO_DEPART) and (r, c) == agent.initial_position:
                     assert np.isclose(obs_agents_state[(r, c)][0], agent.direction), \
-                        "agent {} in status {} at {} expected to contain own direction {}, found {}" \
-                            .format(i, agent.status, (r, c), agent.direction, obs_agents_state[(r, c)][0])
+                        "agent {} in state {} at {} expected to contain own direction {}, found {}" \
+                            .format(i, agent.state, (r, c), agent.direction, obs_agents_state[(r, c)][0])
                 else:
                     assert np.isclose(obs_agents_state[(r, c)][0], -1), \
-                        "agent {} in status {} at {} expected contain -1 found {}" \
-                            .format(i, agent.status, (r, c), obs_agents_state[(r, c)][0])
+                        "agent {} in state {} at {} expected contain -1 found {}" \
+                            .format(i, agent.state, (r, c), obs_agents_state[(r, c)][0])
 
         # test second channel of obs_agents_state: direction at other agents position
         for r in range(env.height):
@@ -86,45 +87,46 @@
                 for other_i, other_agent in enumerate(env.agents):
                     if i == other_i:
                         continue
-                    if other_agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE] and (
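+                    # MOVING, MALFUNCTION and STOPPED are the on-map states; DONE marks a finished agent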
+                    if other_agent.state in [TrainState.MOVING, TrainState.MALFUNCTION, TrainState.STOPPED, TrainState.DONE] and (
                         r, c) == other_agent.position:
                         assert np.isclose(obs_agents_state[(r, c)][1], other_agent.direction), \
-                            "agent {} in status {} at {} should see other agent with direction {}, found = {}" \
-                                .format(i, agent.status, (r, c), other_agent.direction, obs_agents_state[(r, c)][1])
+                            "agent {} in state {} at {} should see other agent with direction {}, found = {}" \
+                                .format(i, agent.state, (r, c), other_agent.direction, obs_agents_state[(r, c)][1])
                     has_agent = True
                 if not has_agent:
                     assert np.isclose(obs_agents_state[(r, c)][1], -1), \
-                        "agent {} in status {} at {} should see no other agent direction (-1), found = {}" \
-                            .format(i, agent.status, (r, c), obs_agents_state[(r, c)][1])
+                        "agent {} in state {} at {} should see no other agent direction (-1), found = {}" \
+                            .format(i, agent.state, (r, c), obs_agents_state[(r, c)][1])
 
         # test third and fourth channel of obs_agents_state: malfunction and speed of own or other agent in the grid
         for r in range(env.height):
             for c in range(env.width):
                 has_agent = False
                 for other_i, other_agent in enumerate(env.agents):
-                    if other_agent.status in [RailAgentStatus.ACTIVE,
-                                              RailAgentStatus.DONE] and other_agent.position == (r, c):
+                    if other_agent.state in [TrainState.MOVING, TrainState.MALFUNCTION, TrainState.STOPPED,
+                                              TrainState.DONE] and other_agent.position == (r, c):
                         assert np.isclose(obs_agents_state[(r, c)][2], other_agent.malfunction_data['malfunction']), \
-                            "agent {} in status {} at {} should see agent malfunction {}, found = {}" \
-                                .format(i, agent.status, (r, c), other_agent.malfunction_data['malfunction'],
+                            "agent {} in state {} at {} should see agent malfunction {}, found = {}" \
+                                .format(i, agent.state, (r, c), other_agent.malfunction_data['malfunction'],
                                         obs_agents_state[(r, c)][2])
-                        assert np.isclose(obs_agents_state[(r, c)][3], other_agent.speed_data['speed'])
+                        assert np.isclose(obs_agents_state[(r, c)][3], other_agent.speed_counter.speed)
                         has_agent = True
                 if not has_agent:
                     assert np.isclose(obs_agents_state[(r, c)][2], -1), \
-                        "agent {} in status {} at {} should see no agent malfunction (-1), found = {}" \
-                            .format(i, agent.status, (r, c), obs_agents_state[(r, c)][2])
+                        "agent {} in state {} at {} should see no agent malfunction (-1), found = {}" \
+                            .format(i, agent.state, (r, c), obs_agents_state[(r, c)][2])
                     assert np.isclose(obs_agents_state[(r, c)][3], -1), \
-                        "agent {} in status {} at {} should see no agent speed (-1), found = {}" \
-                            .format(i, agent.status, (r, c), obs_agents_state[(r, c)][3])
+                        "agent {} in state {} at {} should see no agent speed (-1), found = {}" \
+                            .format(i, agent.state, (r, c), obs_agents_state[(r, c)][3])
 
         # test fifth channel of obs_agents_state: number of agents ready to depart in to this cell
         for r in range(env.height):
             for c in range(env.width):
                 count = 0
                 for other_i, other_agent in enumerate(env.agents):
-                    if other_agent.status == RailAgentStatus.READY_TO_DEPART and other_agent.initial_position == (r, c):
+                    if other_agent.state == TrainState.READY_TO_DEPART and other_agent.initial_position == (r, c):
                         count += 1
                 assert np.isclose(obs_agents_state[(r, c)][4], count), \
-                    "agent {} in status {} at {} should see {} agents ready to depart, found{}" \
-                        .format(i, agent.status, (r, c), count, obs_agents_state[(r, c)][4])
+                    "agent {} in state {} at {} should see {} agents ready to depart, found{}" \
+                        .format(i, agent.state, (r, c), count, obs_agents_state[(r, c)][4])
diff --git a/tests/test_multi_speed.py b/tests/test_multi_speed.py
index 561057d81b431dfbb87b904f7a57e6fcbf84f84e..50565e96bc5716f032af189e25b79622d3ca3586 100644
--- a/tests/test_multi_speed.py
+++ b/tests/test_multi_speed.py
@@ -8,7 +8,8 @@ from flatland.envs.rail_generators import sparse_rail_generator, rail_from_grid_
 from flatland.envs.line_generators import sparse_line_generator
 from flatland.utils.simple_rail import make_simple_rail
 from test_utils import ReplayConfig, Replay, run_replay_config, set_penalties_for_replay
-from flatland.envs.agent_utils import RailAgentStatus
+from flatland.envs.step_utils.states import TrainState
+from flatland.envs.step_utils.speed_counter import SpeedCounter
 
 
 # Use the sparse_rail_generator to generate feasible network configurations with corresponding tasks
@@ -65,13 +66,14 @@ def test_multi_speed_init():
 
     for a_idx in range(len(env.agents)):
         env.agents[a_idx].position =  env.agents[a_idx].initial_position
-        env.agents[a_idx].status = RailAgentStatus.ACTIVE
+        env.agents[a_idx]._set_state(TrainState.MOVING)
 
     # Here you can also further enhance the provided observation by means of normalization
     # See training navigation example in the baseline repository
     old_pos = []
     for i_agent in range(env.get_num_agents()):
-        env.agents[i_agent].speed_data['speed'] = 1. / (i_agent + 1)
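+        # a SpeedCounter fixes its speed at construction, so assign a new instance instead of mutating it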
+        env.agents[i_agent].speed_counter = SpeedCounter(speed=1. / (i_agent + 1))
         old_pos.append(env.agents[i_agent].position)
         print(env.agents[i_agent].position)
     # Run episode
diff --git a/tests/test_speed_classes.py b/tests/test_speed_classes.py
index 3cfe1b1c7f58786cf0caacde629fa3a6c704230d..66f1fbf06eaeb70ed39ac8aa35c93f0fa11c6a32 100644
--- a/tests/test_speed_classes.py
+++ b/tests/test_speed_classes.py
@@ -23,7 +23,8 @@ def test_rail_env_speed_intializer():
                   rail_generator=sparse_rail_generator(), line_generator=sparse_line_generator(),
                   number_of_agents=10)
     env.reset()
-    actual_speeds = list(map(lambda agent: agent.speed_data['speed'], env.agents))
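+    # each agent's speed is exposed via its SpeedCounter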
+    actual_speeds = list(map(lambda agent: agent.speed_counter.speed, env.agents))
 
     expected_speed_set = set(speed_ratio_map.keys())
 
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 4b72679ed6a1ceac1f266760d1871c6fc405e6dc..85e6a2755ac66ffeb15a7a8b2d0f4c9de9652e80 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -5,13 +5,15 @@ import numpy as np
 from attr import attrs, attrib
 
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
-from flatland.envs.agent_utils import EnvAgent, RailAgentStatus
+from flatland.envs.agent_utils import EnvAgent
 from flatland.envs.malfunction_generators import MalfunctionParameters, malfunction_from_params
 from flatland.envs.rail_env import RailEnvActions, RailEnv
 from flatland.envs.rail_generators import RailGenerator
 from flatland.envs.line_generators import LineGenerator
 from flatland.utils.rendertools import RenderTool
 from flatland.envs.persistence import RailEnvPersister
+from flatland.envs.step_utils.states import TrainState
+from flatland.envs.step_utils.speed_counter import SpeedCounter
 
 @attrs
 class Replay(object):
@@ -21,7 +23,7 @@ class Replay(object):
     malfunction = attrib(default=0, type=int)
     set_malfunction = attrib(default=None, type=Optional[int])
     reward = attrib(default=None, type=Optional[float])
-    status = attrib(default=None, type=Optional[RailAgentStatus])
+    state = attrib(default=None, type=Optional[TrainState])
 
 
 @attrs
@@ -86,12 +88,12 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering:
                 agent.initial_direction = test_config.initial_direction
                 agent.direction = test_config.initial_direction
                 agent.target = test_config.target
-                agent.speed_data['speed'] = test_config.speed
+                agent.speed_counter = SpeedCounter(speed=test_config.speed)
             env.reset(False, False)
             if activate_agents:
                 for a_idx in range(len(env.agents)):
                     env.agents[a_idx].position =  env.agents[a_idx].initial_position
-                    env.agents[a_idx].status = RailAgentStatus.ACTIVE
+                    env.agents[a_idx]._set_state(TrainState.MOVING)
 
         def _assert(a, actual, expected, msg):
             print("[{}] verifying {} on agent {}: actual={}, expected={}".format(step, msg, a, actual, expected))
@@ -108,12 +110,13 @@
 
             _assert(a, agent.position, replay.position, 'position')
             _assert(a, agent.direction, replay.direction, 'direction')
-            if replay.status is not None:
-                _assert(a, agent.status, replay.status, 'status')
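+            # only verify the state when the replay specifies one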
+            if replay.state is not None:
+                _assert(a, agent.state, replay.state, 'state')
 
             if replay.action is not None:
                 assert info_dict['action_required'][
-                           a] == True or agent.status == RailAgentStatus.READY_TO_DEPART, "[{}] agent {} expecting action_required={} or agent status READY_TO_DEPART".format(
+                           a] == True or agent.state == TrainState.READY_TO_DEPART, "[{}] agent {} expecting action_required={} or agent state READY_TO_DEPART".format(
                     step, a, True)
                 action_dict[a] = replay.action
             else: