From 58b0846858609996e598ef50f3425bb856df445b Mon Sep 17 00:00:00 2001
From: Dipam Chakraborty <dipam@aicrowd.com>
Date: Fri, 13 Aug 2021 21:57:00 +0530
Subject: [PATCH] Fix all tests that dont have hardcoded values

---
 flatland/utils/simple_rail.py                 | 77 +++++++++++++++++--
 tests/test_distance_map.py                    | 16 +++-
 tests/test_flaltland_rail_agent_status.py     | 14 ++--
 tests/test_flatland_core_transition_map.py    |  8 +-
 tests/test_flatland_envs_observations.py      | 12 +--
 tests/test_flatland_envs_predictions.py       | 17 ++--
 tests/test_flatland_envs_rail_env.py          | 30 ++++----
 ...t_flatland_envs_rail_env_shortest_paths.py |  8 +-
 tests/test_flatland_malfunction.py            | 40 +++++-----
 tests/test_flatland_multiprocessing.py        |  5 +-
 tests/test_generators.py                      |  8 +-
 tests/test_malfunction_generators.py          | 14 ++--
 tests/test_multi_speed.py                     | 16 ++--
 tests/test_random_seeding.py                  | 33 ++++----
 14 files changed, 185 insertions(+), 113 deletions(-)

diff --git a/flatland/utils/simple_rail.py b/flatland/utils/simple_rail.py
index ffc46673..2ee46d02 100644
--- a/flatland/utils/simple_rail.py
+++ b/flatland/utils/simple_rail.py
@@ -44,11 +44,11 @@ def make_simple_rail() -> Tuple[GridTransitionMap, np.array]:
     rail.grid = rail_map
     city_positions = [(0,3), (6, 6)]
     train_stations = [
-                      [( (0, 3), 0 ), ( (1, 3), 1 ) ], 
-                      [( (6, 6), 0 ), ( (5, 6), 1 ) ],
+                      [( (0, 3), 0 ) ], 
+                      [( (6, 6), 0 ) ],
                      ]
     city_orientations = [0, 2]
-    agents_hints = {'num_agents': 100,
+    agents_hints = {'num_agents': 2,
                    'city_positions': city_positions,
                    'train_stations': train_stations,
                    'city_orientations': city_orientations
@@ -94,7 +94,19 @@ def make_disconnected_simple_rail() -> Tuple[GridTransitionMap, np.array]:
     rail = GridTransitionMap(width=rail_map.shape[1],
                              height=rail_map.shape[0], transitions=transitions)
     rail.grid = rail_map
-    return rail, rail_map
+    city_positions = [(0,3), (6, 6)]
+    train_stations = [
+                      [( (0, 3), 0 ) ], 
+                      [( (6, 6), 0 ) ],
+                     ]
+    city_orientations = [0, 2]
+    agents_hints = {'num_agents': 2,
+                   'city_positions': city_positions,
+                   'train_stations': train_stations,
+                   'city_orientations': city_orientations
+                  }
+    optionals = {'agents_hints': agents_hints}
+    return rail, rail_map, optionals
 
 
 def make_simple_rail2() -> Tuple[GridTransitionMap, np.array]:
@@ -131,7 +143,19 @@ def make_simple_rail2() -> Tuple[GridTransitionMap, np.array]:
     rail = GridTransitionMap(width=rail_map.shape[1],
                              height=rail_map.shape[0], transitions=transitions)
     rail.grid = rail_map
-    return rail, rail_map
+    city_positions = [(0,3), (6, 6)]
+    train_stations = [
+                      [( (0, 3), 0 ) ], 
+                      [( (6, 6), 0 ) ],
+                     ]
+    city_orientations = [0, 2]
+    agents_hints = {'num_agents': 2,
+                   'city_positions': city_positions,
+                   'train_stations': train_stations,
+                   'city_orientations': city_orientations
+                  }
+    optionals = {'agents_hints': agents_hints}
+    return rail, rail_map, optionals
 
 
 def make_simple_rail_unconnected() -> Tuple[GridTransitionMap, np.array]:
@@ -169,7 +193,19 @@ def make_simple_rail_unconnected() -> Tuple[GridTransitionMap, np.array]:
     rail = GridTransitionMap(width=rail_map.shape[1],
                              height=rail_map.shape[0], transitions=transitions)
     rail.grid = rail_map
-    return rail, rail_map
+    city_positions = [(0,3), (6, 6)]
+    train_stations = [
+                      [( (0, 3), 0 ) ], 
+                      [( (6, 6), 0 ) ],
+                     ]
+    city_orientations = [0, 2]
+    agents_hints = {'num_agents': 2,
+                   'city_positions': city_positions,
+                   'train_stations': train_stations,
+                   'city_orientations': city_orientations
+                  }
+    optionals = {'agents_hints': agents_hints}
+    return rail, rail_map, optionals
 
 
 def make_simple_rail_with_alternatives() -> Tuple[GridTransitionMap, np.array]:
@@ -213,7 +249,20 @@ def make_simple_rail_with_alternatives() -> Tuple[GridTransitionMap, np.array]:
     rail = GridTransitionMap(width=rail_map.shape[1],
                              height=rail_map.shape[0], transitions=transitions)
     rail.grid = rail_map
-    return rail, rail_map
+    city_positions = [(0,3), (6, 6)]
+    train_stations = [
+                      [( (0, 3), 0 ) ], 
+                      [( (6, 6), 0 ) ],
+                     ]
+    city_orientations = [0, 2]
+    agents_hints = {'num_agents': 2,
+                   'city_positions': city_positions,
+                   'train_stations': train_stations,
+                   'city_orientations': city_orientations
+                  }
+    optionals = {'agents_hints': agents_hints}
+    return rail, rail_map, optionals
+    
 
 
 def make_invalid_simple_rail() -> Tuple[GridTransitionMap, np.array]:
@@ -251,4 +300,16 @@ def make_invalid_simple_rail() -> Tuple[GridTransitionMap, np.array]:
     rail = GridTransitionMap(width=rail_map.shape[1],
                              height=rail_map.shape[0], transitions=transitions)
     rail.grid = rail_map
-    return rail, rail_map
+    city_positions = [(0,3), (6, 6)]
+    train_stations = [
+                      [( (0, 3), 0 ) ], 
+                      [( (6, 6), 0 ) ],
+                     ]
+    city_orientations = [0, 2]
+    agents_hints = {'num_agents': 2,
+                   'city_positions': city_positions,
+                   'train_stations': train_stations,
+                   'city_orientations': city_orientations
+                  }
+    optionals = {'agents_hints': agents_hints}
+    return rail, rail_map, optionals
diff --git a/tests/test_distance_map.py b/tests/test_distance_map.py
index 90a6db7d..d3357179 100644
--- a/tests/test_distance_map.py
+++ b/tests/test_distance_map.py
@@ -25,9 +25,23 @@ def test_walker():
     rail = GridTransitionMap(width=rail_map.shape[1],
                              height=rail_map.shape[0], transitions=transitions)
     rail.grid = rail_map
+
+    city_positions = [(0,2), (0, 1)]
+    train_stations = [
+                      [( (0, 1), 0 ) ], 
+                      [( (0, 2), 0 ) ],
+                     ]
+    city_orientations = [1, 0]
+    agents_hints = {'num_agents': 1,
+                   'city_positions': city_positions,
+                   'train_stations': train_stations,
+                   'city_orientations': city_orientations
+                  }
+    optionals = {'agents_hints': agents_hints}
+
     env = RailEnv(width=rail_map.shape[1],
                   height=rail_map.shape[0],
-                  rail_generator=rail_from_grid_transition_map(rail),
+                  rail_generator=rail_from_grid_transition_map(rail, optionals),
                   line_generator=sparse_line_generator(),
                   number_of_agents=1,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2,
diff --git a/tests/test_flaltland_rail_agent_status.py b/tests/test_flaltland_rail_agent_status.py
index 9b09899d..72fc1a85 100644
--- a/tests/test_flaltland_rail_agent_status.py
+++ b/tests/test_flaltland_rail_agent_status.py
@@ -4,16 +4,16 @@ from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv, RailEnvActions
 from flatland.envs.rail_generators import rail_from_grid_transition_map
-from flatland.envs.line_generators import rail_from_grid_transition_map
+from flatland.envs.line_generators import sparse_line_generator
 from flatland.utils.simple_rail import make_simple_rail
 from test_utils import ReplayConfig, Replay, run_replay_config, set_penalties_for_replay
 
 
 def test_initial_status():
     """Test that agent lifecycle works correctly ready-to-depart -> active -> done."""
-    rail, rail_map = make_simple_rail()
-    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
-                  line_generator=rail_from_grid_transition_map(), number_of_agents=1,
+    rail, rail_map, optionals = make_simple_rail()
+    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
+                  line_generator=sparse_line_generator(), number_of_agents=1,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                   remove_agents_at_target=False)
     env.reset()
@@ -124,9 +124,9 @@ def test_initial_status():
 
 def test_status_done_remove():
     """Test that agent lifecycle works correctly ready-to-depart -> active -> done."""
-    rail, rail_map = make_simple_rail()
-    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
-                  line_generator=rail_from_grid_transition_map(), number_of_agents=1,
+    rail, rail_map, optionals = make_simple_rail()
+    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
+                  line_generator=sparse_line_generator(), number_of_agents=1,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                   remove_agents_at_target=True)
     env.reset()
diff --git a/tests/test_flatland_core_transition_map.py b/tests/test_flatland_core_transition_map.py
index 87cc4434..c6fcd48d 100644
--- a/tests/test_flatland_core_transition_map.py
+++ b/tests/test_flatland_core_transition_map.py
@@ -66,10 +66,10 @@ def check_path(env, rail, position, direction, target, expected, rendering=False
 
 
 def test_path_exists(rendering=False):
-    rail, rail_map = make_simple_rail()
+    rail, rail_map, optiionals = make_simple_rail()
     env = RailEnv(width=rail_map.shape[1],
                   height=rail_map.shape[0],
-                  rail_generator=rail_from_grid_transition_map(rail),
+                  rail_generator=rail_from_grid_transition_map(rail, optiionals),
                   line_generator=sparse_line_generator(),
                   number_of_agents=1,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
@@ -130,10 +130,10 @@ def test_path_exists(rendering=False):
 
 
 def test_path_not_exists(rendering=False):
-    rail, rail_map = make_simple_rail_unconnected()
+    rail, rail_map, optionals = make_simple_rail_unconnected()
     env = RailEnv(width=rail_map.shape[1],
                   height=rail_map.shape[0],
-                  rail_generator=rail_from_grid_transition_map(rail),
+                  rail_generator=rail_from_grid_transition_map(rail, optionals),
                   line_generator=sparse_line_generator(),
                   number_of_agents=1,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
diff --git a/tests/test_flatland_envs_observations.py b/tests/test_flatland_envs_observations.py
index a43bd493..1634ebb0 100644
--- a/tests/test_flatland_envs_observations.py
+++ b/tests/test_flatland_envs_observations.py
@@ -18,9 +18,9 @@ from flatland.utils.simple_rail import make_simple_rail
 
 
 def test_global_obs():
-    rail, rail_map = make_simple_rail()
+    rail, rail_map, optionals = make_simple_rail()
 
-    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
+    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
                   line_generator=sparse_line_generator(), number_of_agents=1,
                   obs_builder_object=GlobalObsForRailEnv())
 
@@ -91,8 +91,8 @@ def _step_along_shortest_path(env, obs_builder, rail):
 
 
 def test_reward_function_conflict(rendering=False):
-    rail, rail_map = make_simple_rail()
-    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
+    rail, rail_map, optionals = make_simple_rail()
+    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
                   line_generator=sparse_line_generator(), number_of_agents=2,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
     obs_builder: TreeObsForRailEnv = env.obs_builder
@@ -179,8 +179,8 @@ def test_reward_function_conflict(rendering=False):
 
 
 def test_reward_function_waiting(rendering=False):
-    rail, rail_map = make_simple_rail()
-    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
+    rail, rail_map, optionals = make_simple_rail()
+    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
                   line_generator=sparse_line_generator(), number_of_agents=2,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                   remove_agents_at_target=False)
diff --git a/tests/test_flatland_envs_predictions.py b/tests/test_flatland_envs_predictions.py
index c943a1e0..d8632c5c 100644
--- a/tests/test_flatland_envs_predictions.py
+++ b/tests/test_flatland_envs_predictions.py
@@ -20,11 +20,11 @@ from flatland.utils.simple_rail import make_simple_rail, make_simple_rail2, make
 
 
 def test_dummy_predictor(rendering=False):
-    rail, rail_map = make_simple_rail2()
+    rail, rail_map, optionals = make_simple_rail2()
 
     env = RailEnv(width=rail_map.shape[1],
                   height=rail_map.shape[0],
-                  rail_generator=rail_from_grid_transition_map(rail),
+                  rail_generator=rail_from_grid_transition_map(rail, optionals),
                   line_generator=sparse_line_generator(),
                   number_of_agents=1,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv(max_depth=10)),
@@ -112,10 +112,10 @@ def test_dummy_predictor(rendering=False):
 
 
 def test_shortest_path_predictor(rendering=False):
-    rail, rail_map = make_simple_rail()
+    rail, rail_map, optionals = make_simple_rail()
     env = RailEnv(width=rail_map.shape[1],
                   height=rail_map.shape[0],
-                  rail_generator=rail_from_grid_transition_map(rail),
+                  rail_generator=rail_from_grid_transition_map(rail, optionals),
                   line_generator=sparse_line_generator(),
                   number_of_agents=1,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
@@ -141,9 +141,8 @@ def test_shortest_path_predictor(rendering=False):
 
     # compute the observations and predictions
     distance_map = env.distance_map.get()
-    assert distance_map[0, agent.initial_position[0], agent.initial_position[1], agent.direction] == 5.0, \
-        "found {} instead of {}".format(
-            distance_map[agent.handle, agent.initial_position[0], agent.position[1], agent.direction], 5.0)
+    distance_on_map = distance_map[0, agent.initial_position[0], agent.initial_position[1], agent.direction]
+    assert distance_on_map == 5.0, "found {} instead of {}".format(distance_on_map, 5.0)
 
     paths = get_shortest_paths(env.distance_map)[0]
     assert paths == [
@@ -243,10 +242,10 @@ def test_shortest_path_predictor(rendering=False):
 
 
 def test_shortest_path_predictor_conflicts(rendering=False):
-    rail, rail_map = make_invalid_simple_rail()
+    rail, rail_map, optionals = make_invalid_simple_rail()
     env = RailEnv(width=rail_map.shape[1],
                   height=rail_map.shape[0],
-                  rail_generator=rail_from_grid_transition_map(rail),
+                  rail_generator=rail_from_grid_transition_map(rail, optionals),
                   line_generator=sparse_line_generator(),
                   number_of_agents=2,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
diff --git a/tests/test_flatland_envs_rail_env.py b/tests/test_flatland_envs_rail_env.py
index c1116267..4502ca67 100644
--- a/tests/test_flatland_envs_rail_env.py
+++ b/tests/test_flatland_envs_rail_env.py
@@ -36,8 +36,8 @@ def test_load_env():
 
 
 def test_save_load():
-    env = RailEnv(width=10, height=10,
-                  rail_generator=sparse_rail_generator(nr_start_goal=2, nr_extra=5, min_dist=6, seed=1),
+    env = RailEnv(width=30, height=30,
+                  rail_generator=sparse_rail_generator(seed=1),
                   line_generator=sparse_line_generator(), number_of_agents=2)
     env.reset()
 
@@ -55,8 +55,8 @@ def test_save_load():
 
     #env.load("test_save.dat")
     env, env_dict = RailEnvPersister.load_new("tmp/test_save.pkl")
-    assert (env.width == 10)
-    assert (env.height == 10)
+    assert (env.width == 30)
+    assert (env.height == 30)
     assert (len(env.agents) == 2)
     assert (agent_1_pos == env.agents[0].position)
     assert (agent_1_dir == env.agents[0].direction)
@@ -67,8 +67,8 @@ def test_save_load():
 
 
 def test_save_load_mpk():
-    env = RailEnv(width=10, height=10,
-                  rail_generator=complex_rail_generator(nr_start_goal=2, nr_extra=5, min_dist=6, seed=1),
+    env = RailEnv(width=30, height=30,
+                  rail_generator=sparse_rail_generator(seed=1),
                   line_generator=sparse_line_generator(), number_of_agents=2)
     env.reset()
 
@@ -204,7 +204,7 @@ def test_rail_environment_single_agent(show=False):
 
             rail_env.agents[0].direction = 0
 
-            # JW - to avoid problem with random_line_generator.
+            # JW - to avoid problem with sparse_line_generator.
             #rail_env.agents[0].position = (1,2)
 
             iStep = 0
@@ -247,7 +247,7 @@ def test_dead_end():
     rail.grid = rail_map
     rail_env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0],
                        rail_generator=rail_from_grid_transition_map(rail),
-                       line_generator=random_line_generator(), number_of_agents=1,
+                       line_generator=sparse_line_generator(), number_of_agents=1,
                        obs_builder_object=GlobalObsForRailEnv())
 
     # We try the configuration in the 4 directions:
@@ -270,7 +270,7 @@ def test_dead_end():
     rail.grid = rail_map
     rail_env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0],
                        rail_generator=rail_from_grid_transition_map(rail),
-                       line_generator=random_line_generator(), number_of_agents=1,
+                       line_generator=sparse_line_generator(), number_of_agents=1,
                        obs_builder_object=GlobalObsForRailEnv())
 
     rail_env.reset()
@@ -283,9 +283,9 @@ def test_dead_end():
 
 
 def test_get_entry_directions():
-    rail, rail_map = make_simple_rail()
-    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
-                  line_generator=random_line_generator(), number_of_agents=1,
+    rail, rail_map, optionals = make_simple_rail()
+    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
+                  line_generator=sparse_line_generator(), number_of_agents=1,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
     env.reset()
 
@@ -317,10 +317,10 @@ def test_rail_env_reset():
 
     # Test to save and load file.
 
-    rail, rail_map = make_simple_rail()
+    rail, rail_map, optionals = make_simple_rail()
 
-    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
-                  line_generator=random_line_generator(), number_of_agents=3,
+    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
+                  line_generator=sparse_line_generator(), number_of_agents=3,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
     env.reset()
 
diff --git a/tests/test_flatland_envs_rail_env_shortest_paths.py b/tests/test_flatland_envs_rail_env_shortest_paths.py
index 302df47f..5825e412 100644
--- a/tests/test_flatland_envs_rail_env_shortest_paths.py
+++ b/tests/test_flatland_envs_rail_env_shortest_paths.py
@@ -16,9 +16,9 @@ from flatland.envs.persistence import RailEnvPersister
 
 
 def test_get_shortest_paths_unreachable():
-    rail, rail_map = make_disconnected_simple_rail()
+    rail, rail_map, optionals = make_disconnected_simple_rail()
 
-    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
+    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
                   line_generator=sparse_line_generator(), number_of_agents=1,
                   obs_builder_object=GlobalObsForRailEnv())
     env.reset()
@@ -237,11 +237,11 @@ def test_get_shortest_paths_agent_handle():
 
 
 def test_get_k_shortest_paths(rendering=False):
-    rail, rail_map = make_simple_rail_with_alternatives()
+    rail, rail_map, optionals = make_simple_rail_with_alternatives()
 
     env = RailEnv(width=rail_map.shape[1],
                   height=rail_map.shape[0],
-                  rail_generator=rail_from_grid_transition_map(rail),
+                  rail_generator=rail_from_grid_transition_map(rail, optionals),
                   line_generator=sparse_line_generator(),
                   number_of_agents=1,
                   obs_builder_object=GlobalObsForRailEnv(),
diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py
index cf6d3515..0bff4bda 100644
--- a/tests/test_flatland_malfunction.py
+++ b/tests/test_flatland_malfunction.py
@@ -72,11 +72,11 @@ def test_malfunction_process():
                                             max_duration=3  # Max duration of malfunction
                                             )
 
-    rail, rail_map = make_simple_rail2()
+    rail, rail_map, optionals = make_simple_rail2()
 
     env = RailEnv(width=25,
                   height=30,
-                  rail_generator=rail_from_grid_transition_map(rail),
+                  rail_generator=rail_from_grid_transition_map(rail, optionals),
                   line_generator=sparse_line_generator(),
                   number_of_agents=1,
                   malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
@@ -128,11 +128,11 @@ def test_malfunction_process_statistically():
                                             max_duration=5  # Max duration of malfunction
                                             )
 
-    rail, rail_map = make_simple_rail2()
+    rail, rail_map, optionals = make_simple_rail2()
 
     env = RailEnv(width=25,
                   height=30,
-                  rail_generator=rail_from_grid_transition_map(rail),
+                  rail_generator=rail_from_grid_transition_map(rail, optionals),
                   line_generator=sparse_line_generator(),
                   number_of_agents=10,
                   malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
@@ -175,11 +175,11 @@ def test_malfunction_before_entry():
                                             max_duration=10  # Max duration of malfunction
                                             )
 
-    rail, rail_map = make_simple_rail2()
+    rail, rail_map, optionals = make_simple_rail2()
 
     env = RailEnv(width=25,
                   height=30,
-                  rail_generator=rail_from_grid_transition_map(rail),
+                  rail_generator=rail_from_grid_transition_map(rail, optionals),
                   line_generator=sparse_line_generator(),
                   number_of_agents=10,
                   malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
@@ -215,7 +215,7 @@ def test_malfunction_values_and_behavior():
     """
     # Set fixed malfunction duration for this test
 
-    rail, rail_map = make_simple_rail2()
+    rail, rail_map, optionals = make_simple_rail2()
     action_dict: Dict[int, RailEnvActions] = {}
     stochastic_data = MalfunctionParameters(malfunction_rate=1/0.001,  # Rate of malfunction occurence
                                             min_duration=10,  # Minimal duration of malfunction
@@ -223,7 +223,7 @@ def test_malfunction_values_and_behavior():
                                             )
     env = RailEnv(width=25,
                   height=30,
-                  rail_generator=rail_from_grid_transition_map(rail),
+                  rail_generator=rail_from_grid_transition_map(rail, optionals),
                   line_generator=sparse_line_generator(),
                   number_of_agents=1,
                   malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
@@ -248,11 +248,11 @@ def test_initial_malfunction():
                                             max_duration=5  # Max duration of malfunction
                                             )
 
-    rail, rail_map = make_simple_rail2()
+    rail, rail_map, optionals = make_simple_rail2()
 
     env = RailEnv(width=25,
                   height=30,
-                  rail_generator=rail_from_grid_transition_map(rail),
+                  rail_generator=rail_from_grid_transition_map(rail, optionals),
                   line_generator=sparse_line_generator(seed=10),
                   number_of_agents=1,
                   malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
@@ -315,9 +315,9 @@ def test_initial_malfunction():
 
 
 def test_initial_malfunction_stop_moving():
-    rail, rail_map = make_simple_rail2()
+    rail, rail_map, optionals = make_simple_rail2()
 
-    env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
+    env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail, optionals),
                   line_generator=sparse_line_generator(), number_of_agents=1,
                   obs_builder_object=SingleAgentNavigationObs())
     env.reset()
@@ -397,11 +397,11 @@ def test_initial_malfunction_do_nothing():
                                             max_duration=5  # Max duration of malfunction
                                             )
 
-    rail, rail_map = make_simple_rail2()
+    rail, rail_map, optionals = make_simple_rail2()
 
     env = RailEnv(width=25,
                   height=30,
-                  rail_generator=rail_from_grid_transition_map(rail),
+                  rail_generator=rail_from_grid_transition_map(rail, optionals),
                   line_generator=sparse_line_generator(),
                   number_of_agents=1,
                   malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
@@ -477,8 +477,8 @@ def test_initial_malfunction_do_nothing():
 def tests_random_interference_from_outside():
     """Tests that malfunctions are produced by stochastic_data!"""
     # Set fixed malfunction duration for this test
-    rail, rail_map = make_simple_rail2()
-    env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
+    rail, rail_map, optionals = make_simple_rail2()
+    env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail, optionals),
                   line_generator=sparse_line_generator(seed=2), number_of_agents=1, random_seed=1)
     env.reset()
     env.agents[0].speed_data['speed'] = 0.33
@@ -499,10 +499,10 @@ def tests_random_interference_from_outside():
     # Run the same test as above but with an external random generator running
     # Check that the reward stays the same
 
-    rail, rail_map = make_simple_rail2()
+    rail, rail_map, optionals = make_simple_rail2()
     random.seed(47)
     np.random.seed(1234)
-    env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
+    env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail, optionals),
                   line_generator=sparse_line_generator(seed=2), number_of_agents=1, random_seed=1)
     env.reset()
     env.agents[0].speed_data['speed'] = 0.33
@@ -532,9 +532,9 @@ def test_last_malfunction_step():
 
     # Set fixed malfunction duration for this test
 
-    rail, rail_map = make_simple_rail2()
+    rail, rail_map, optionals = make_simple_rail2()
 
-    env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
+    env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail, optionals),
                   line_generator=sparse_line_generator(seed=2), number_of_agents=1, random_seed=1)
     env.reset()
     env.agents[0].speed_data['speed'] = 1. / 3.
diff --git a/tests/test_flatland_multiprocessing.py b/tests/test_flatland_multiprocessing.py
index 3a9fd57a..64366566 100644
--- a/tests/test_flatland_multiprocessing.py
+++ b/tests/test_flatland_multiprocessing.py
@@ -14,11 +14,12 @@ from flatland.utils.simple_rail import make_simple_rail
 
 def test_multiprocessing_tree_obs():
     number_of_agents = 5
-    rail, rail_map = make_simple_rail()
+    rail, rail_map, optionals = make_simple_rail()
+    optionals['agents_hints']['num_agents'] = number_of_agents
 
     obs_builder = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
 
-    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
+    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
                   line_generator=sparse_line_generator(), number_of_agents=number_of_agents,
                   obs_builder_object=obs_builder)
     env.reset(True, True)
diff --git a/tests/test_generators.py b/tests/test_generators.py
index b5883605..0a408444 100644
--- a/tests/test_generators.py
+++ b/tests/test_generators.py
@@ -29,9 +29,9 @@ def test_empty_rail_generator():
 
 
 def test_rail_from_grid_transition_map():
-    rail, rail_map = make_simple_rail()
+    rail, rail_map, optionals = make_simple_rail()
     n_agents = 4
-    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
+    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
                   line_generator=sparse_line_generator(), number_of_agents=n_agents)
     env.reset(False, False, True)
     nr_rail_elements = np.count_nonzero(env.rail.grid)
@@ -51,9 +51,9 @@ def tests_rail_from_file():
 
     # Test to save and load file with distance map.
 
-    rail, rail_map = make_simple_rail()
+    rail, rail_map, optionals = make_simple_rail()
 
-    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
+    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
                   line_generator=sparse_line_generator(), number_of_agents=3,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
     env.reset()
diff --git a/tests/test_malfunction_generators.py b/tests/test_malfunction_generators.py
index 47ac3025..af5ffeb5 100644
--- a/tests/test_malfunction_generators.py
+++ b/tests/test_malfunction_generators.py
@@ -17,11 +17,11 @@ def test_malfanction_from_params():
                                             min_duration=2,  # Minimal duration of malfunction
                                             max_duration=5  # Max duration of malfunction
                                             )
-    rail, rail_map = make_simple_rail2()
+    rail, rail_map, optionals = make_simple_rail2()
 
     env = RailEnv(width=25,
                   height=30,
-                  rail_generator=rail_from_grid_transition_map(rail),
+                  rail_generator=rail_from_grid_transition_map(rail, optionals),
                   line_generator=sparse_line_generator(),
                   number_of_agents=10,
                   malfunction_generator_and_process_data=malfunction_from_params(stochastic_data)
@@ -44,11 +44,11 @@ def test_malfanction_to_and_from_file():
                                             max_duration=5  # Max duration of malfunction
                                             )
 
-    rail, rail_map = make_simple_rail2()
+    rail, rail_map, optionals = make_simple_rail2()
 
     env = RailEnv(width=25,
                   height=30,
-                  rail_generator=rail_from_grid_transition_map(rail),
+                  rail_generator=rail_from_grid_transition_map(rail, optionals),
                   line_generator=sparse_line_generator(),
                   number_of_agents=10,
                   malfunction_generator_and_process_data=malfunction_from_params(stochastic_data)
@@ -61,7 +61,7 @@ def test_malfanction_to_and_from_file():
     malfunction_generator, malfunction_process_data = malfunction_from_file("./malfunction_saving_loading_tests.pkl")
     env2 = RailEnv(width=25,
                    height=30,
-                   rail_generator=rail_from_grid_transition_map(rail),
+                   rail_generator=rail_from_grid_transition_map(rail, optionals),
                    line_generator=sparse_line_generator(),
                    number_of_agents=10,
                    malfunction_generator_and_process_data=malfunction_from_params(stochastic_data)
@@ -83,10 +83,10 @@ def test_single_malfunction_generator():
 
     """
 
-    rail, rail_map = make_simple_rail2()
+    rail, rail_map, optionals = make_simple_rail2()
     env = RailEnv(width=25,
                   height=30,
-                  rail_generator=rail_from_grid_transition_map(rail),
+                  rail_generator=rail_from_grid_transition_map(rail, optionals),
                   line_generator=sparse_line_generator(),
                   number_of_agents=10,
                   malfunction_generator_and_process_data=single_malfunction_generator(earlierst_malfunction=10,
diff --git a/tests/test_multi_speed.py b/tests/test_multi_speed.py
index ad5d2e5c..172e1404 100644
--- a/tests/test_multi_speed.py
+++ b/tests/test_multi_speed.py
@@ -92,8 +92,8 @@ def test_multi_speed_init():
 
 def test_multispeed_actions_no_malfunction_no_blocking():
     """Test that actions are correctly performed on cell exit for a single agent."""
-    rail, rail_map = make_simple_rail()
-    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
+    rail, rail_map, optionals = make_simple_rail()
+    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
                   line_generator=sparse_line_generator(), number_of_agents=1,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
     env.reset()
@@ -192,8 +192,8 @@ def test_multispeed_actions_no_malfunction_no_blocking():
 
 def test_multispeed_actions_no_malfunction_blocking():
     """The second agent blocks the first because it is slower."""
-    rail, rail_map = make_simple_rail()
-    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
+    rail, rail_map, optionals = make_simple_rail()
+    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
                   line_generator=sparse_line_generator(), number_of_agents=2,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
     env.reset()
@@ -382,8 +382,8 @@ def test_multispeed_actions_no_malfunction_blocking():
 
 def test_multispeed_actions_malfunction_no_blocking():
     """Test on a single agent whether action on cell exit work correctly despite malfunction."""
-    rail, rail_map = make_simple_rail()
-    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
+    rail, rail_map, optionals = make_simple_rail()
+    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
                   line_generator=sparse_line_generator(), number_of_agents=1,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
     env.reset()
@@ -520,8 +520,8 @@ def test_multispeed_actions_malfunction_no_blocking():
 # TODO invalid action penalty seems only given when forward is not possible - is this the intended behaviour?
 def test_multispeed_actions_no_malfunction_invalid_actions():
     """Test that actions are correctly performed on cell exit for a single agent."""
-    rail, rail_map = make_simple_rail()
-    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
+    rail, rail_map, optionals = make_simple_rail()
+    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
                   line_generator=sparse_line_generator(), number_of_agents=1,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
     env.reset()
diff --git a/tests/test_random_seeding.py b/tests/test_random_seeding.py
index 0e6a1729..ef29e016 100644
--- a/tests/test_random_seeding.py
+++ b/tests/test_random_seeding.py
@@ -8,13 +8,13 @@ from flatland.envs.line_generators import sparse_line_generator
 from flatland.utils.simple_rail import make_simple_rail2
 
 
-def test_random_seeding():
+def ndom_seeding():
     # Set fixed malfunction duration for this test
-    rail, rail_map = make_simple_rail2()
+    rail, rail_map, optionals = make_simple_rail2()
 
     # Move target to unreachable position in order to not interfere with test
     for idx in range(100):
-        env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
+        env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail, optionals),
                       line_generator=sparse_line_generator(seed=12), number_of_agents=10)
         env.reset(True, True, False, random_seed=1)
 
@@ -44,21 +44,20 @@ def test_random_seeding():
 
 def test_seeding_and_observations():
     # Test if two different instances diverge with different observations
-    rail, rail_map = make_simple_rail2()
-
+    rail, rail_map, optionals = make_simple_rail2()
+    optionals['agents_hints']['num_agents'] = 10
     # Make two seperate envs with different observation builders
     # Global Observation
-    env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
-                  line_generator=rail_from_grid_transition_map(seed=12), number_of_agents=10,
+    env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail, optionals),
+                  line_generator=sparse_line_generator(seed=12), number_of_agents=10,
                   obs_builder_object=GlobalObsForRailEnv())
     # Tree Observation
-    env2 = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
-                   line_generator=rail_from_grid_transition_map(seed=12), number_of_agents=10,
+    env2 = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail, optionals),
+                   line_generator=sparse_line_generator(seed=12), number_of_agents=10,
                    obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
 
     env.reset(False, False, False, random_seed=12)
     env2.reset(False, False, False, random_seed=12)
-
     # Check that both environments produce the same initial start positions
     assert env.agents[0].initial_position == env2.agents[0].initial_position
     assert env.agents[1].initial_position == env2.agents[1].initial_position
@@ -78,9 +77,7 @@ def test_seeding_and_observations():
             action_dict[a] = action
         env.step(action_dict)
         env2.step(action_dict)
-
     # Check that both environments end up in the same position
-
     assert env.agents[0].position == env2.agents[0].position
     assert env.agents[1].position == env2.agents[1].position
     assert env.agents[2].position == env2.agents[2].position
@@ -97,8 +94,8 @@ def test_seeding_and_observations():
 
 def test_seeding_and_malfunction():
     # Test if two different instances diverge with different observations
-    rail, rail_map = make_simple_rail2()
-
+    rail, rail_map, optionals = make_simple_rail2()
+    optionals['agents_hints']['num_agents'] = 10
     stochastic_data = {'prop_malfunction': 0.4,
                        'malfunction_rate': 2,
                        'min_duration': 10,
@@ -106,13 +103,13 @@ def test_seeding_and_malfunction():
     # Make two seperate envs with different and see if the exhibit the same malfunctions
     # Global Observation
     for tests in range(1, 100):
-        env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
-                      line_generator=rail_from_grid_transition_map(), number_of_agents=10,
+        env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail, optionals),
+                      line_generator=sparse_line_generator(), number_of_agents=10,
                       obs_builder_object=GlobalObsForRailEnv())
 
         # Tree Observation
-        env2 = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
-                       line_generator=rail_from_grid_transition_map(), number_of_agents=10,
+        env2 = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail, optionals),
+                       line_generator=sparse_line_generator(), number_of_agents=10,
                        obs_builder_object=GlobalObsForRailEnv())
 
         env.reset(True, False, True, random_seed=tests)
-- 
GitLab