diff --git a/tests/test_environments.py b/tests/test_environments.py index 925874ba59d452410503af47aebeadb2fb5f07ca..c3329a126684ef9462257f51276f7bae94675b28 100644 --- a/tests/test_environments.py +++ b/tests/test_environments.py @@ -3,6 +3,7 @@ from flatland.core.env import RailEnv from flatland.core.transitions import Grid4Transitions +from flatland.core.transitionmap import GridTransitionMap import numpy as np """Tests for `flatland` package.""" @@ -44,7 +45,9 @@ def test_rail_environment_single_agent(): north_west_turn]], dtype=np.uint16) - rail_env = RailEnv(rail_map, number_of_agents=1) + rail = GridTransitionMap(width=3, height=3, transitions=transitions) + rail.grid = rail_map + rail_env = RailEnv(rail, number_of_agents=1) for _ in range(200): _ = rail_env.reset()