diff --git a/tests/test_env_step_utils.py b/tests/test_env_step_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..739d3d06e7271d2ce54ded07de54957d96c08022
--- /dev/null
+++ b/tests/test_env_step_utils.py
@@ -0,0 +1,61 @@
+import numpy as np
+import numpy as np
+import os
+
+from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters, ParamMalfunctionGen
+
+from flatland.envs.observations import GlobalObsForRailEnv
+# First of all we import the Flatland rail environment
+from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_env import RailEnvActions
+from flatland.envs.rail_generators import sparse_rail_generator
+#from flatland.envs.sparse_rail_gen import SparseRailGen
+from flatland.envs.schedule_generators import sparse_schedule_generator
+
+
+def get_small_two_agent_env():
+    """Build a small two-city, two-train RailEnv, reset it, and return it.
+
+    Every setting is hard-coded (30x15 grid, seed 42), so each call builds
+    the environment from the same arguments.
+    """
+    random_seed = 42
+
+    # max_num_cities=2, randomly placed (grid_mode=False), at most 2 rails
+    # between cities and at most 6 parallel tracks (3 rail pairs) per city.
+    rail_gen = sparse_rail_generator(
+        max_num_cities=2,
+        seed=random_seed,
+        grid_mode=False,
+        max_rails_between_cities=2,
+        max_rail_pairs_in_city=6 // 2,
+    )
+
+    speed_ratios = {  # four speed classes, each mapped to 0.25
+        1.: 0.25,  # Fast passenger train
+        1. / 2.: 0.25,  # Fast freight train
+        1. / 3.: 0.25,  # Slow commuter train
+        1. / 4.: 0.25,  # Slow freight train
+    }
+    schedule_gen = sparse_schedule_generator(speed_ratios)
+
+    # Malfunction rate of 1/10000, with durations between 15 and 50.
+    malfunction_params = MalfunctionParameters(
+        malfunction_rate=1 / 10000,
+        min_duration=15,
+        max_duration=50,
+    )
+
+    env = RailEnv(
+        width=30,
+        height=15,
+        rail_generator=rail_gen,
+        schedule_generator=schedule_gen,
+        number_of_agents=2,
+        obs_builder_object=GlobalObsForRailEnv(),
+        malfunction_generator=ParamMalfunctionGen(malfunction_params),
+        remove_agents_at_target=True,
+        random_seed=random_seed,
+    )
+    env.reset()
+    return env
\ No newline at end of file
diff --git a/tests/test_state_machine.py b/tests/test_state_machine.py
new file mode 100644
index 0000000000000000000000000000000000000000..266a8f86589b6033ea67523cab0b31b72ac9d32d
--- /dev/null
+++ b/tests/test_state_machine.py
@@ -0,0 +1,115 @@
+from test_env_step_utils import get_small_two_agent_env
+from flatland.envs.rail_env_action import RailEnvActions
+from flatland.envs.step_utils.states import TrainState
+from flatland.envs.malfunction_generators import Malfunction
+
+class NoMalfunctionGenerator:  # test stub: generate() always returns Malfunction(0)
+    def generate(self, np_random):  # np_random is accepted but not used here
+        return Malfunction(0)
+
+class AlwaysThreeStepMalfunction:  # test stub: generate() always returns Malfunction(3)
+    def generate(self, np_random):  # np_random is accepted but not used here
+        return Malfunction(3)
+
+def test_waiting_no_transition():
+    """Agent 0 stays WAITING for the first earliest_departure - 1 steps."""
+    env = get_small_two_agent_env()
+    env.malfunction_generator = NoMalfunctionGenerator()
+    handle = 0
+    for _ in range(env.agents[handle].earliest_departure - 1):
+        env.step({handle: RailEnvActions.MOVE_FORWARD})
+        assert env.agents[handle].state == TrainState.WAITING
+    
+    
+def test_waiting_to_ready_to_depart():
+    """Agent 0 is READY_TO_DEPART after earliest_departure idle steps."""
+    env = get_small_two_agent_env()
+    env.malfunction_generator = NoMalfunctionGenerator()
+    handle = 0
+    for _ in range(env.agents[handle].earliest_departure):
+        env.step({handle: RailEnvActions.DO_NOTHING})
+    assert env.agents[handle].state == TrainState.READY_TO_DEPART
+
+
+def test_ready_to_depart_to_moving():
+    """Agent 0 is MOVING after idling until departure, then MOVE_FORWARD."""
+    env = get_small_two_agent_env()
+    env.malfunction_generator = NoMalfunctionGenerator()
+    handle = 0
+    # Idle for earliest_departure steps before issuing the move.
+    for _ in range(env.agents[handle].earliest_departure):
+        env.step({handle: RailEnvActions.DO_NOTHING})
+    env.step({handle: RailEnvActions.MOVE_FORWARD})
+    assert env.agents[handle].state == TrainState.MOVING
+
+def test_moving_to_stopped():
+    """Agent 0 is STOPPED after a MOVE_FORWARD followed by a STOP_MOVING."""
+    env = get_small_two_agent_env()
+    env.malfunction_generator = NoMalfunctionGenerator()
+    handle = 0
+    # Idle for earliest_departure steps, then move once and stop.
+    for _ in range(env.agents[handle].earliest_departure):
+        env.step({handle: RailEnvActions.DO_NOTHING})
+    for action in (RailEnvActions.MOVE_FORWARD, RailEnvActions.STOP_MOVING):
+        env.step({handle: action})
+    assert env.agents[handle].state == TrainState.STOPPED
+
+def test_stopped_to_moving():
+    """Agent 0 is MOVING again after the sequence move, stop, move."""
+    env = get_small_two_agent_env()
+    env.malfunction_generator = NoMalfunctionGenerator()
+    handle = 0
+    # Idle for earliest_departure steps before issuing any movement.
+    for _ in range(env.agents[handle].earliest_departure):
+        env.step({handle: RailEnvActions.DO_NOTHING})
+    move, stop = RailEnvActions.MOVE_FORWARD, RailEnvActions.STOP_MOVING
+    for action in (move, stop, move):
+        env.step({handle: action})
+    assert env.agents[handle].state == TrainState.MOVING
+
+def test_moving_to_done():
+    """Agent 1 is DONE after idling until departure, then 50 MOVE_FORWARDs."""
+    env = get_small_two_agent_env()
+    env.malfunction_generator = NoMalfunctionGenerator()
+    handle = 1
+    phases = ((env.agents[handle].earliest_departure, RailEnvActions.DO_NOTHING),
+              (50, RailEnvActions.MOVE_FORWARD))
+    for n_steps, action in phases:
+        for _ in range(n_steps):
+            env.step({handle: action})
+    assert env.agents[handle].state == TrainState.DONE
+
+def test_waiting_to_malfunction():
+    """Agent 1 is MALFUNCTION_OFF_MAP after one step with Malfunction(3)."""
+    env = get_small_two_agent_env()
+    env.malfunction_generator = AlwaysThreeStepMalfunction()
+    env.step({1: RailEnvActions.DO_NOTHING})
+    assert env.agents[1].state == TrainState.MALFUNCTION_OFF_MAP
+
+
+def test_ready_to_depart_to_malfunction_off_map():
+    """Agent 1 is MALFUNCTION_OFF_MAP once a malfunction hits after idling."""
+    env = get_small_two_agent_env()
+    env.malfunction_generator = NoMalfunctionGenerator()
+    handle = 1
+    env.step({handle: RailEnvActions.DO_NOTHING})
+    # earliest_departure is read only after the first step has been taken.
+    for _ in range(env.agents[handle].earliest_departure):
+        env.step({handle: RailEnvActions.DO_NOTHING})  # This should get into ready to depart
+    env.malfunction_generator = AlwaysThreeStepMalfunction()
+    env.step({handle: RailEnvActions.DO_NOTHING})
+    assert env.agents[handle].state == TrainState.MALFUNCTION_OFF_MAP
+
+
+def test_malfunction_off_map_to_waiting():
+    """Agent 1 is MALFUNCTION_OFF_MAP once a malfunction hits after idling."""
+    # NOTE(review): despite the name, no step past the malfunction is taken
+    # and WAITING is never asserted -- confirm what this test should cover.
+    env = get_small_two_agent_env()
+    env.malfunction_generator = NoMalfunctionGenerator()
+    env.step({1: RailEnvActions.DO_NOTHING})
+    for _ in range(env.agents[1].earliest_departure):
+        env.step({1: RailEnvActions.DO_NOTHING})  # This should get into ready to depart
+    env.malfunction_generator = AlwaysThreeStepMalfunction()
+    env.step({1: RailEnvActions.DO_NOTHING})
+    assert env.agents[1].state == TrainState.MALFUNCTION_OFF_MAP
\ No newline at end of file