diff --git a/tests/test_environments.py b/tests/test_environments.py
index 925874ba59d452410503af47aebeadb2fb5f07ca..c3329a126684ef9462257f51276f7bae94675b28 100644
--- a/tests/test_environments.py
+++ b/tests/test_environments.py
@@ -3,6 +3,7 @@
 
 from flatland.core.env import RailEnv
 from flatland.core.transitions import Grid4Transitions
+from flatland.core.transitionmap import GridTransitionMap
 import numpy as np
 
 """Tests for `flatland` package."""
@@ -44,7 +45,9 @@ def test_rail_environment_single_agent():
                           north_west_turn]],
                         dtype=np.uint16)
 
-    rail_env = RailEnv(rail_map, number_of_agents=1)
+    rail = GridTransitionMap(width=3, height=3, transitions=transitions)
+    rail.grid = rail_map
+    rail_env = RailEnv(rail, number_of_agents=1)
     for _ in range(200):
         _ = rail_env.reset()