From dece6c1673f53ce3cc40a8b2dceeaca1d7772e6f Mon Sep 17 00:00:00 2001
From: u214892 <u214892@sbb.ch>
Date: Tue, 27 Aug 2019 11:22:26 +0200
Subject: [PATCH] #141 different agent classes

---
 examples/complex_rail_benchmark.py       |   2 +
 examples/custom_observation_example.py   |   9 +-
 examples/debugging_example_DELETE.py     |  11 +-
 examples/simple_example_3.py             |   2 +
 examples/training_example.py             |   3 +
 flatland/cli.py                          |  46 +++---
 flatland/envs/agent_generators.py        | 182 +++++++++++++++++++++++
 flatland/envs/generators.py              |  94 +++---------
 flatland/envs/grid4_generators_utils.py  |  82 +---------
 flatland/envs/rail_env.py                |  37 +++--
 tests/test_distance_map.py               |   2 +
 tests/test_flatland_envs_observations.py |   4 +
 tests/test_flatland_envs_predictions.py  |   4 +
 tests/test_flatland_envs_rail_env.py     |   5 +
 tests/test_flatland_malfunction.py       |   2 +
 tests/test_multi_speed.py                |   3 +
 tests/test_speed_classes.py              |   9 +-
 tests/tests_generators.py                |  18 ++-
 18 files changed, 311 insertions(+), 204 deletions(-)
 create mode 100644 flatland/envs/agent_generators.py

diff --git a/examples/complex_rail_benchmark.py b/examples/complex_rail_benchmark.py
index 44e4b534..624ad669 100644
--- a/examples/complex_rail_benchmark.py
+++ b/examples/complex_rail_benchmark.py
@@ -3,6 +3,7 @@ import random
 
 import numpy as np
 
+from flatland.envs.agent_generators import complex_rail_generator_agents_placer
 from flatland.envs.generators import complex_rail_generator
 from flatland.envs.rail_env import RailEnv
 
@@ -15,6 +16,7 @@ def run_benchmark():
     # Example generate a random rail
     env = RailEnv(width=15, height=15,
                   rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=20, min_dist=12),
+                  agent_generator=complex_rail_generator_agents_placer(),
                   number_of_agents=5)
 
     n_trials = 20
diff --git a/examples/custom_observation_example.py b/examples/custom_observation_example.py
index 723bb110..401ff94a 100644
--- a/examples/custom_observation_example.py
+++ b/examples/custom_observation_example.py
@@ -5,6 +5,7 @@ import numpy as np
 
 from flatland.core.env_observation_builder import ObservationBuilder
 from flatland.core.grid.grid_utils import coordinate_to_position
+from flatland.envs.agent_generators import complex_rail_generator_agents_placer
 from flatland.envs.generators import random_rail_generator, complex_rail_generator
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
@@ -20,6 +21,7 @@ class SimpleObs(ObservationBuilder):
     Simplest observation builder. The object returns observation vectors with 5 identical components,
     all equal to the ID of the respective agent.
     """
+
     def __init__(self):
         self.observation_space = [5]
 
@@ -53,6 +55,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv):
     E.g., if taking the Left branch (if available) is the shortest route to the agent's target, the observation vector
     will be [1, 0, 0].
     """
+
     def __init__(self):
         super().__init__(max_depth=0)
         self.observation_space = [3]
@@ -90,6 +93,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv):
 env = RailEnv(width=7,
               height=7,
               rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999, seed=0),
+              agent_generator=complex_rail_generator_agents_placer(),
               number_of_agents=1,
               obs_builder_object=SingleAgentNavigationObs())
 
@@ -97,8 +101,8 @@ obs = env.reset()
 env_renderer = RenderTool(env, gl="PILSVG")
 env_renderer.render_env(show=True, frames=True, show_observations=True)
 for step in range(100):
-    action = np.argmax(obs[0])+1
-    obs, all_rewards, done, _ = env.step({0:action})
+    action = np.argmax(obs[0]) + 1
+    obs, all_rewards, done, _ = env.step({0: action})
     print("Rewards: ", all_rewards, "  [done=", done, "]")
     env_renderer.render_env(show=True, frames=True, show_observations=True)
     time.sleep(0.1)
@@ -200,6 +204,7 @@ CustomObsBuilder = ObservePredictions(CustomPredictor)
 env = RailEnv(width=10,
               height=10,
               rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=1, min_dist=8, max_dist=99999, seed=0),
+              agent_generator=complex_rail_generator_agents_placer(),
               number_of_agents=3,
               obs_builder_object=CustomObsBuilder)
 
diff --git a/examples/debugging_example_DELETE.py b/examples/debugging_example_DELETE.py
index 2c0f8145..68fdc8ab 100644
--- a/examples/debugging_example_DELETE.py
+++ b/examples/debugging_example_DELETE.py
@@ -3,6 +3,7 @@ import time
 
 import numpy as np
 
+from flatland.envs.agent_generators import complex_rail_generator_agents_placer
 from flatland.envs.generators import complex_rail_generator
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.rail_env import RailEnv
@@ -11,6 +12,7 @@ from flatland.utils.rendertools import RenderTool
 random.seed(1)
 np.random.seed(1)
 
+
 class SingleAgentNavigationObs(TreeObsForRailEnv):
     """
     We derive our bbservation builder from TreeObsForRailEnv, to exploit the existing implementation to compute
@@ -21,6 +23,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv):
     E.g., if taking the Left branch (if available) is the shortest route to the agent's target, the observation vector
     will be [1, 0, 0].
     """
+
     def __init__(self):
         super().__init__(max_depth=0)
         self.observation_space = [3]
@@ -58,6 +61,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv):
 env = RailEnv(width=14,
               height=14,
               rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999, seed=0),
+              agent_generator=complex_rail_generator_agents_placer(),
               number_of_agents=2,
               obs_builder_object=SingleAgentNavigationObs())
 
@@ -67,11 +71,11 @@ env_renderer.render_env(show=True, frames=True, show_observations=False)
 for step in range(100):
     actions = {}
     for i in range(len(obs)):
-        actions[i] = np.argmax(obs[i])+1
+        actions[i] = np.argmax(obs[i]) + 1
 
-    if step%5 == 0:
+    if step % 5 == 0:
         print("Agent halts")
-        actions[0] = 4 # Halt
+        actions[0] = 4  # Halt
 
     obs, all_rewards, done, _ = env.step(actions)
     if env.agents[0].malfunction_data['malfunction'] > 0:
@@ -82,4 +86,3 @@ for step in range(100):
     if done["__all__"]:
         break
 env_renderer.close_window()
-
diff --git a/examples/simple_example_3.py b/examples/simple_example_3.py
index 5aa03d8f..1e20fcca 100644
--- a/examples/simple_example_3.py
+++ b/examples/simple_example_3.py
@@ -2,6 +2,7 @@ import random
 
 import numpy as np
 
+from flatland.envs.agent_generators import complex_rail_generator_agents_placer
 from flatland.envs.generators import complex_rail_generator
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.rail_env import RailEnv
@@ -13,6 +14,7 @@ np.random.seed(1)
 env = RailEnv(width=7,
               height=7,
               rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0),
+              agent_generator=complex_rail_generator_agents_placer(),
               number_of_agents=2,
               obs_builder_object=TreeObsForRailEnv(max_depth=2))
 
diff --git a/examples/training_example.py b/examples/training_example.py
index d125be15..f339d329 100644
--- a/examples/training_example.py
+++ b/examples/training_example.py
@@ -1,5 +1,6 @@
 import numpy as np
 
+from flatland.envs.agent_generators import complex_rail_generator_agents_placer
 from flatland.envs.generators import complex_rail_generator
 from flatland.envs.observations import TreeObsForRailEnv, LocalObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
@@ -16,11 +17,13 @@ LocalGridObs = LocalObsForRailEnv(view_height=10, view_width=2, center=2)
 env = RailEnv(width=20,
               height=20,
               rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0),
+              agent_generator=complex_rail_generator_agents_placer(),
               obs_builder_object=TreeObservation,
               number_of_agents=3)
 
 env_renderer = RenderTool(env, gl="PILSVG", )
 
+
 # Import your own Agent or use RLlib to train agents on Flatland
 # As an example we use a random agent here
 
diff --git a/flatland/cli.py b/flatland/cli.py
index 32e8d9dc..56b2feab 100644
--- a/flatland/cli.py
+++ b/flatland/cli.py
@@ -2,29 +2,33 @@
 
 """Console script for flatland."""
 import sys
+import time
+
 import click
 import numpy as np
-import time
+import redis
+
+from flatland.envs.agent_generators import complex_rail_generator_agents_placer
 from flatland.envs.generators import complex_rail_generator
 from flatland.envs.rail_env import RailEnv
-from flatland.utils.rendertools import RenderTool
 from flatland.evaluators.service import FlatlandRemoteEvaluationService
-import redis
+from flatland.utils.rendertools import RenderTool
 
 
 @click.command()
 def demo(args=None):
     """Demo script to check installation"""
     env = RailEnv(
-            width=15,
-            height=15,
-            rail_generator=complex_rail_generator(
-                                    nr_start_goal=10,
-                                    nr_extra=1,
-                                    min_dist=8,
-                                    max_dist=99999),
-            number_of_agents=5)
-    
+        width=15,
+        height=15,
+        rail_generator=complex_rail_generator(
+            nr_start_goal=10,
+            nr_extra=1,
+            min_dist=8,
+            max_dist=99999),
+        agent_generator=complex_rail_generator_agents_placer(),
+        number_of_agents=5)
+
     env._max_episode_steps = int(15 * (env.width + env.height))
     env_renderer = RenderTool(env)
 
@@ -52,12 +56,12 @@ def demo(args=None):
 
 
 @click.command()
-@click.option('--tests', 
+@click.option('--tests',
               type=click.Path(exists=True),
               help="Path to folder containing Flatland tests",
               required=True
               )
-@click.option('--service_id', 
+@click.option('--service_id',
               default="FLATLAND_RL_SERVICE_ID",
               help="Evaluation Service ID. This has to match the service id on the client.",
               required=False
@@ -70,14 +74,14 @@ def evaluator(tests, service_id):
         raise Exception(
             "\nRedis server does not seem to be running on your localhost.\n"
             "Please ensure that you have a redis server running on your localhost"
-            )
-    
+        )
+
     grader = FlatlandRemoteEvaluationService(
-                test_env_folder=tests,
-                flatland_rl_service_id=service_id,
-                visualize=False,
-                verbose=False
-                )
+        test_env_folder=tests,
+        flatland_rl_service_id=service_id,
+        visualize=False,
+        verbose=False
+    )
     grader.run()
 
 
diff --git a/flatland/envs/agent_generators.py b/flatland/envs/agent_generators.py
new file mode 100644
index 00000000..1f769b7d
--- /dev/null
+++ b/flatland/envs/agent_generators.py
@@ -0,0 +1,182 @@
+"""Agent generators (railway undertaking, "EVU")."""
+from typing import Tuple, List, Callable, Mapping, Optional, Any
+
+import msgpack
+import numpy as np
+
+from flatland.core.grid.grid4_utils import get_new_position
+from flatland.core.transition_map import GridTransitionMap
+from flatland.envs.agent_utils import EnvAgentStatic
+
+AgentPosition = Tuple[int, int]
+AgentGeneratorProduct = Tuple[List[AgentPosition], List[AgentPosition], List[AgentPosition], List[float]]
+AgentGenerator = Callable[[GridTransitionMap, int, Optional[Any]], AgentGeneratorProduct]
+
+
+def speed_initialization_helper(nb_agents: int, speed_ratio_map: Mapping[float, float] = None) -> List[float]:
+    """
+    Parameters
+    -------
+    nb_agents : int
+        The number of agents to generate a speed for
+    speed_ratio_map : Mapping[float,float]
+        A map of speeds mappint to their ratio of appearance. The ratios must sum up to 1.
+
+    Returns
+    -------
+    List[float]
+        A list of size nb_agents of speeds with the corresponding probabilistic ratios.
+    """
+    if speed_ratio_map is None:
+        return [1.0] * nb_agents
+
+    nb_classes = len(speed_ratio_map.keys())
+    speed_ratio_map_as_list: List[Tuple[float, float]] = list(speed_ratio_map.items())
+    speed_ratios = list(map(lambda t: t[1], speed_ratio_map_as_list))
+    speeds = list(map(lambda t: t[0], speed_ratio_map_as_list))
+    return list(map(lambda index: speeds[index], np.random.choice(nb_classes, nb_agents, p=speed_ratios)))
+
+
+def complex_rail_generator_agents_placer(speed_ratio_map: Mapping[float, float] = None) -> AgentGenerator:
+    def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None):
+        start_goal = hints['start_goal']
+        start_dir = hints['start_dir']
+        agents_position = [sg[0] for sg in start_goal[:num_agents]]
+        agents_target = [sg[1] for sg in start_goal[:num_agents]]
+        agents_direction = start_dir[:num_agents]
+
+        if speed_ratio_map:
+            speeds = speed_initialization_helper(num_agents, speed_ratio_map)
+        else:
+            speeds = [1.0] * len(agents_position)
+
+        return agents_position, agents_direction, agents_target, speeds
+
+    return generator
+
+
+def get_rnd_agents_pos_tgt_dir_on_rail(speed_ratio_map: Mapping[float, float] = None) -> AgentGenerator:
+    """
+    Given a `rail' GridTransitionMap, return a random placement of agents (initial position, direction and target).
+
+    Parameters
+    -------
+        rail : GridTransitionMap
+            The railway to place agents on.
+        num_agents : int
+            The number of agents to generate a speed for
+        speed_ratio_map : Mapping[float,float]
+            A map of speeds mappint to their ratio of appearance. The ratios must sum up to 1.
+    Returns
+    -------
+        Tuple[List[Tuple[int,int]], List[Tuple[int,int]], List[Tuple[int,int]], List[float]]
+        initial positions, directions, targets speeds
+    """
+
+    def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None):
+        def _path_exists(rail, start, direction, end):
+            # BFS - Check if a path exists between the 2 nodes
+
+            visited = set()
+            stack = [(start, direction)]
+            while stack:
+                node = stack.pop()
+                if node[0][0] == end[0] and node[0][1] == end[1]:
+                    return 1
+                if node not in visited:
+                    visited.add(node)
+                    moves = rail.get_transitions(node[0][0], node[0][1], node[1])
+                    for move_index in range(4):
+                        if moves[move_index]:
+                            stack.append((get_new_position(node[0], move_index),
+                                          move_index))
+
+                    # If cell is a dead-end, append previous node with reversed
+                    # orientation!
+                    nbits = 0
+                    tmp = rail.get_full_transitions(node[0][0], node[0][1])
+                    while tmp > 0:
+                        nbits += (tmp & 1)
+                        tmp = tmp >> 1
+                    if nbits == 1:
+                        stack.append((node[0], (node[1] + 2) % 4))
+
+            return 0
+
+        valid_positions = []
+        for r in range(rail.height):
+            for c in range(rail.width):
+                if rail.get_full_transitions(r, c) > 0:
+                    valid_positions.append((r, c))
+        if len(valid_positions) == 0:
+            return [], [], [], []
+        re_generate = True
+        while re_generate:
+            agents_position = [
+                valid_positions[i] for i in
+                np.random.choice(len(valid_positions), num_agents)]
+            agents_target = [
+                valid_positions[i] for i in
+                np.random.choice(len(valid_positions), num_agents)]
+
+            # agents_direction must be a direction for which a solution is
+            # guaranteed.
+            agents_direction = [0] * num_agents
+            re_generate = False
+            for i in range(num_agents):
+                valid_movements = []
+                for direction in range(4):
+                    position = agents_position[i]
+                    moves = rail.get_transitions(position[0], position[1], direction)
+                    for move_index in range(4):
+                        if moves[move_index]:
+                            valid_movements.append((direction, move_index))
+
+                valid_starting_directions = []
+                for m in valid_movements:
+                    new_position = get_new_position(agents_position[i], m[1])
+                    if m[0] not in valid_starting_directions and _path_exists(rail, new_position, m[0],
+                                                                              agents_target[i]):
+                        valid_starting_directions.append(m[0])
+
+                if len(valid_starting_directions) == 0:
+                    re_generate = True
+                else:
+                    agents_direction[i] = valid_starting_directions[
+                        np.random.choice(len(valid_starting_directions), 1)[0]]
+
+        agents_speed = speed_initialization_helper(num_agents, speed_ratio_map)
+        return agents_position, agents_direction, agents_target, agents_speed
+
+    return generator
+
+
+def agents_from_file(filename) -> AgentGenerator:
+    """
+    Utility to load pickle file
+
+    Parameters
+    -------
+    input_file : Pickle file generated by env.save() or editor
+
+    Returns
+    -------
+    Tuple[List[Tuple[int,int]], List[Tuple[int,int]], List[Tuple[int,int]], List[float]]
+        initial positions, directions, targets speeds
+    """
+
+    def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None):
+        with open(filename, "rb") as file_in:
+            load_data = file_in.read()
+        data = msgpack.unpackb(load_data, use_list=False)
+
+        # agents are always reset as not moving
+        agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data[b"agents_static"]]
+        # setup with loaded data
+        agents_position = [a.position for a in agents_static]
+        agents_direction = [a.direction for a in agents_static]
+        agents_target = [a.target for a in agents_static]
+
+        return agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
+
+    return generator
diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py
index 79e0ac7d..380bf37f 100644
--- a/flatland/envs/generators.py
+++ b/flatland/envs/generators.py
@@ -1,4 +1,5 @@
-from typing import Mapping, Tuple, List, Callable
+"""Rail generators (infrastructure manager, "Infrastrukturbetreiber")."""
+from typing import Callable, Tuple, Any, Optional
 
 import msgpack
 import numpy as np
@@ -7,12 +8,12 @@ from flatland.core.grid.grid4_utils import get_direction, mirror
 from flatland.core.grid.grid_utils import distance_on_rail
 from flatland.core.grid.rail_env_grid import RailEnvTransitions
 from flatland.core.transition_map import GridTransitionMap
-from flatland.envs.agent_utils import EnvAgentStatic
 from flatland.envs.grid4_generators_utils import connect_rail
-from flatland.envs.grid4_generators_utils import get_rnd_agents_pos_tgt_dir_on_rail
 
+RailGenerator = Callable[[int, int, int, int], Tuple[GridTransitionMap, Optional[Any]]]
 
-def empty_rail_generator():
+
+def empty_rail_generator() -> RailGenerator:
     """
     Returns a generator which returns an empty rail mail with no agents.
     Primarily used by the editor
@@ -24,7 +25,7 @@ def empty_rail_generator():
         rail_array = grid_map.grid
         rail_array.fill(0)
 
-        return grid_map, [], [], [], []
+        return [grid_map, None]
 
     return generator
 
@@ -33,8 +34,7 @@ def complex_rail_generator(nr_start_goal=1,
                            nr_extra=100,
                            min_dist=20,
                            max_dist=99999,
-                           seed=0,
-                           speed_initializer: Callable[[int], List[float]] = None):
+                           seed=0) -> RailGenerator:
     """
     Parameters
     -------
@@ -42,8 +42,6 @@ def complex_rail_generator(nr_start_goal=1,
         The width (number of cells) of the grid to generate.
     height : int
         The height (number of cells) of the grid to generate.
-    speed_initializer : Callable[[int], List[float]]
-        Function that returns a list of speeds for the numer of agents given as argument.
 
     Returns
     -------
@@ -56,8 +54,7 @@ def complex_rail_generator(nr_start_goal=1,
         if num_agents > nr_start_goal:
             num_agents = nr_start_goal
             print("complex_rail_generator: num_agents > nr_start_goal, changing num_agents")
-        rail_trans = RailEnvTransitions()
-        grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans)
+        grid_map = GridTransitionMap(width=width, height=height, transitions=RailEnvTransitions())
         rail_array = grid_map.grid
         rail_array.fill(0)
 
@@ -81,6 +78,7 @@ def complex_rail_generator(nr_start_goal=1,
         # - return transition map + list of [start_pos, start_dir, goal_pos] points
         #
 
+        rail_trans = grid_map.transitions
         start_goal = []
         start_dir = []
         nr_created = 0
@@ -150,15 +148,10 @@ def complex_rail_generator(nr_start_goal=1,
             if len(new_path) >= 2:
                 nr_created += 1
 
-        agents_position = [sg[0] for sg in start_goal[:num_agents]]
-        agents_target = [sg[1] for sg in start_goal[:num_agents]]
-        agents_direction = start_dir[:num_agents]
-
-        if speed_initializer:
-            speeds = speed_initializer(num_agents)
-        else:
-            speeds = [1.0] * len(agents_position)
-        return grid_map, agents_position, agents_direction, agents_target, speeds
+        return grid_map, {'agents_hints': {
+            'start_goal': start_goal,
+            'start_dir': start_dir
+        }}
 
     return generator
 
@@ -202,22 +195,18 @@ def rail_from_manual_specifications_generator(rail_spec):
                 effective_transition_cell = rail_env_transitions.rotate_transition(basic_type_of_cell_, rotation_cell_)
                 rail.set_transitions((r, c), effective_transition_cell)
 
-        agents_position, agents_direction, agents_target = get_rnd_agents_pos_tgt_dir_on_rail(
-            rail,
-            num_agents)
-
-        return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
+        return [rail, None]
 
     return generator
 
 
-def rail_from_file(filename):
+def rail_from_file(filename) -> RailGenerator:
     """
     Utility to load pickle file
 
     Parameters
     -------
-    input_file : Pickle file generated by env.save() or editor
+    filename : Pickle file generated by env.save() or editor
 
     Returns
     -------
@@ -235,26 +224,16 @@ def rail_from_file(filename):
         grid = np.array(data[b"grid"])
         rail = GridTransitionMap(width=np.shape(grid)[1], height=np.shape(grid)[0], transitions=rail_env_transitions)
         rail.grid = grid
-        # agents are always reset as not moving
-        agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data[b"agents_static"]]
-        # setup with loaded data
-        agents_position = [a.position for a in agents_static]
-        agents_direction = [a.direction for a in agents_static]
-        agents_target = [a.target for a in agents_static]
         if b"distance_maps" in data.keys():
             distance_maps = data[b"distance_maps"]
             if len(distance_maps) > 0:
-                return rail, agents_position, agents_direction, agents_target, [1.0] * len(
-                    agents_position), distance_maps
-            else:
-                return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
-        else:
-            return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
+                return rail, {'distance_maps': distance_maps}
+        return [rail, None]
 
     return generator
 
 
-def rail_from_grid_transition_map(rail_map):
+def rail_from_grid_transition_map(rail_map) -> RailGenerator:
     """
     Utility to convert a rail given by a GridTransitionMap map with the correct
     16-bit transitions specifications.
@@ -271,16 +250,12 @@ def rail_from_grid_transition_map(rail_map):
     """
 
     def generator(width, height, num_agents, num_resets=0):
-        agents_position, agents_direction, agents_target = get_rnd_agents_pos_tgt_dir_on_rail(
-            rail_map,
-            num_agents)
-
-        return rail_map, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
+        return [rail_map, None]
 
     return generator
 
 
-def random_rail_generator(cell_type_relative_proportion=[1.0] * 11):
+def random_rail_generator(cell_type_relative_proportion=[1.0] * 11) -> RailGenerator:
     """
     Dummy random level generator:
     - fill in cells at random in [width-2, height-2]
@@ -544,31 +519,6 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11):
         return_rail = GridTransitionMap(width=width, height=height, transitions=t_utils)
         return_rail.grid = tmp_rail
 
-        agents_position, agents_direction, agents_target = get_rnd_agents_pos_tgt_dir_on_rail(
-            return_rail,
-            num_agents)
-
-        return return_rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
+        return [return_rail, None]
 
     return generator
-
-
-def speed_initialization_helper(nb_agents: int, speed_ratio_map: Mapping[float, float]) -> List[float]:
-    """
-    Parameters
-    -------
-    nb_agents : int
-        The number of agents to generate a speed for
-    speed_ratio_map : Mapping[float,float]
-        A map of speeds mappint to their ratio of appearance. The ratios must sum up to 1.
-
-    Returns
-    -------
-    List[float]
-        A list of size nb_agents of speeds with the corresponding probabilistic ratios.
-    """
-    nb_classes = len(speed_ratio_map.keys())
-    speed_ratio_map_as_list: List[Tuple[float, float]] = list(speed_ratio_map.items())
-    speed_ratios = list(map(lambda t: t[1], speed_ratio_map_as_list))
-    speeds = list(map(lambda t: t[0], speed_ratio_map_as_list))
-    return list(map(lambda index: speeds[index], np.random.choice(nb_classes, nb_agents, p=speed_ratios)))
diff --git a/flatland/envs/grid4_generators_utils.py b/flatland/envs/grid4_generators_utils.py
index dedd76b6..8bbd0df7 100644
--- a/flatland/envs/grid4_generators_utils.py
+++ b/flatland/envs/grid4_generators_utils.py
@@ -5,10 +5,8 @@ Generator functions are functions that take width, height and num_resets as argu
 a GridTransitionMap object.
 """
 
-import numpy as np
-
 from flatland.core.grid.grid4_astar import a_star
-from flatland.core.grid.grid4_utils import get_direction, mirror, get_new_position
+from flatland.core.grid.grid4_utils import get_direction, mirror
 
 
 def connect_rail(rail_trans, rail_array, start, end):
@@ -55,81 +53,3 @@ def connect_rail(rail_trans, rail_array, start, end):
 
         current_dir = new_dir
     return path
-
-
-def get_rnd_agents_pos_tgt_dir_on_rail(rail, num_agents):
-    """
-    Given a `rail' GridTransitionMap, return a random placement of agents (initial position, direction and target).
-
-    TODO: add extensive documentation, as users may need this function to simplify their custom level generators.
-    """
-
-    def _path_exists(rail, start, direction, end):
-        # BFS - Check if a path exists between the 2 nodes
-
-        visited = set()
-        stack = [(start, direction)]
-        while stack:
-            node = stack.pop()
-            if node[0][0] == end[0] and node[0][1] == end[1]:
-                return 1
-            if node not in visited:
-                visited.add(node)
-                moves = rail.get_transitions(node[0][0], node[0][1], node[1])
-                for move_index in range(4):
-                    if moves[move_index]:
-                        stack.append((get_new_position(node[0], move_index),
-                                      move_index))
-
-                # If cell is a dead-end, append previous node with reversed
-                # orientation!
-                nbits = 0
-                tmp = rail.get_full_transitions(node[0][0], node[0][1])
-                while tmp > 0:
-                    nbits += (tmp & 1)
-                    tmp = tmp >> 1
-                if nbits == 1:
-                    stack.append((node[0], (node[1] + 2) % 4))
-
-        return 0
-
-    valid_positions = []
-    for r in range(rail.height):
-        for c in range(rail.width):
-            if rail.get_full_transitions(r, c) > 0:
-                valid_positions.append((r, c))
-
-    re_generate = True
-    while re_generate:
-        agents_position = [
-            valid_positions[i] for i in
-            np.random.choice(len(valid_positions), num_agents)]
-        agents_target = [
-            valid_positions[i] for i in
-            np.random.choice(len(valid_positions), num_agents)]
-
-        # agents_direction must be a direction for which a solution is
-        # guaranteed.
-        agents_direction = [0] * num_agents
-        re_generate = False
-        for i in range(num_agents):
-            valid_movements = []
-            for direction in range(4):
-                position = agents_position[i]
-                moves = rail.get_transitions(position[0], position[1], direction)
-                for move_index in range(4):
-                    if moves[move_index]:
-                        valid_movements.append((direction, move_index))
-
-            valid_starting_directions = []
-            for m in valid_movements:
-                new_position = get_new_position(agents_position[i], m[1])
-                if m[0] not in valid_starting_directions and _path_exists(rail, new_position, m[0], agents_target[i]):
-                    valid_starting_directions.append(m[0])
-
-            if len(valid_starting_directions) == 0:
-                re_generate = True
-            else:
-                agents_direction[i] = valid_starting_directions[np.random.choice(len(valid_starting_directions), 1)[0]]
-
-    return agents_position, agents_direction, agents_target
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 22812829..27664403 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -11,8 +11,9 @@ import numpy as np
 
 from flatland.core.env import Environment
 from flatland.core.grid.grid4_utils import get_new_position
+from flatland.envs.agent_generators import get_rnd_agents_pos_tgt_dir_on_rail, AgentGenerator
 from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent
-from flatland.envs.generators import random_rail_generator
+from flatland.envs.generators import random_rail_generator, RailGenerator
 from flatland.envs.observations import TreeObsForRailEnv
 
 m.patch()
@@ -91,7 +92,8 @@ class RailEnv(Environment):
     def __init__(self,
                  width,
                  height,
-                 rail_generator=random_rail_generator(),
+                 rail_generator: RailGenerator = random_rail_generator(),
+                 agent_generator: AgentGenerator = get_rnd_agents_pos_tgt_dir_on_rail(),
                  number_of_agents=1,
                  obs_builder_object=TreeObsForRailEnv(max_depth=2),
                  max_episode_steps=None,
@@ -107,13 +109,12 @@ class RailEnv(Environment):
             height and agents handles of a  rail environment, along with the number of times
             the env has been reset, and returns a GridTransitionMap object and a list of
             starting positions, targets, and initial orientations for agent handle.
-            Implemented functions are:
-                random_rail_generator : generate a random rail of given size
-                rail_from_grid_transition_map(rail_map) : generate a rail from
-                                        a GridTransitionMap object
-                rail_from_manual_sp ecifications_generator(rail_spec) : generate a rail from
-                                        a rail specifications array
-                TODO: generate_rail_from_saved_list or from list of ndarray bitmaps ---
+            The rail_generator can pass a distance map in the hints or information for specific agent_generators.
+            Implementations can be found in flatland/envs/generators.py
+        agent_generator : function
+            The agent_generator function is a function that takes the grid, the number of agents and optional hints
+            and returns a list of starting positions, targets, initial orientations and speed for all agent handles.
+            Implementations can be found in flatland/envs/agent_generators.py
         width : int
             The width of the rail map. Potentially in the future,
             a range of widths to sample from.
@@ -131,7 +132,8 @@ class RailEnv(Environment):
         file_name: you can load a pickle file.
         """
 
-        self.rail_generator = rail_generator
+        self.rail_generator: RailGenerator = rail_generator
+        self.agent_generator: AgentGenerator = agent_generator
         self.rail = None
         self.width = width
         self.height = height
@@ -213,18 +215,21 @@ class RailEnv(Environment):
             if replace_agents then regenerate the agents static.
             Relies on the rail_generator returning agent_static lists (pos, dir, target)
         """
-        tRailAgents = self.rail_generator(self.width, self.height, self.get_num_agents(), self.num_resets)
+        rail, optionals = self.rail_generator(self.width, self.height, self.get_num_agents(), self.num_resets)
 
-        # Check if generator provided a distance map TODO: Make this check safer!
-        if len(tRailAgents) > 5:
-            self.obs_builder.distance_map = tRailAgents[-1]
+        if optionals and 'distance_maps' in optionals:
+            self.obs_builder.distance_map = optionals['distance_maps']
 
         if regen_rail or self.rail is None:
-            self.rail = tRailAgents[0]
+            self.rail = rail
             self.height, self.width = self.rail.grid.shape
 
         if replace_agents:
-            self.agents_static = EnvAgentStatic.from_lists(*tRailAgents[1:5])
+            agents_hints = None
+            if optionals and 'agents_hints' in optionals:
+                agents_hints = optionals['agents_hints']
+            self.agents_static = EnvAgentStatic.from_lists(
+                *self.agent_generator(self.rail, self.get_num_agents(), hints=agents_hints))
 
         self.restart_agents()
 
diff --git a/tests/test_distance_map.py b/tests/test_distance_map.py
index 12e0c092..dbeb2fb4 100644
--- a/tests/test_distance_map.py
+++ b/tests/test_distance_map.py
@@ -2,6 +2,7 @@ import numpy as np
 
 from flatland.core.grid.grid4 import Grid4Transitions
 from flatland.core.transition_map import GridTransitionMap
+from flatland.envs.agent_generators import get_rnd_agents_pos_tgt_dir_on_rail
 from flatland.envs.generators import rail_from_grid_transition_map
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
@@ -34,6 +35,7 @@ def test_walker():
     env = RailEnv(width=rail_map.shape[1],
                   height=rail_map.shape[0],
                   rail_generator=rail_from_grid_transition_map(rail),
+                  agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(),
                   number_of_agents=1,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2,
                                                        predictor=ShortestPathPredictorForRailEnv(max_depth=10)),
diff --git a/tests/test_flatland_envs_observations.py b/tests/test_flatland_envs_observations.py
index 574705c4..c0e28534 100644
--- a/tests/test_flatland_envs_observations.py
+++ b/tests/test_flatland_envs_observations.py
@@ -4,6 +4,7 @@
 import numpy as np
 
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
+from flatland.envs.agent_generators import get_rnd_agents_pos_tgt_dir_on_rail
 from flatland.envs.agent_utils import EnvAgent
 from flatland.envs.generators import rail_from_grid_transition_map
 from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv
@@ -21,6 +22,7 @@ def test_global_obs():
     env = RailEnv(width=rail_map.shape[1],
                   height=rail_map.shape[0],
                   rail_generator=rail_from_grid_transition_map(rail),
+                  agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(),
                   number_of_agents=1,
                   obs_builder_object=GlobalObsForRailEnv())
 
@@ -90,6 +92,7 @@ def test_reward_function_conflict(rendering=False):
     env = RailEnv(width=rail_map.shape[1],
                   height=rail_map.shape[0],
                   rail_generator=rail_from_grid_transition_map(rail),
+                  agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(),
                   number_of_agents=2,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                   )
@@ -168,6 +171,7 @@ def test_reward_function_waiting(rendering=False):
     env = RailEnv(width=rail_map.shape[1],
                   height=rail_map.shape[0],
                   rail_generator=rail_from_grid_transition_map(rail),
+                  agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(),
                   number_of_agents=2,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                   )
diff --git a/tests/test_flatland_envs_predictions.py b/tests/test_flatland_envs_predictions.py
index 7acd58ed..8c67a42a 100644
--- a/tests/test_flatland_envs_predictions.py
+++ b/tests/test_flatland_envs_predictions.py
@@ -5,6 +5,7 @@ import pprint
 import numpy as np
 
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
+from flatland.envs.agent_generators import get_rnd_agents_pos_tgt_dir_on_rail
 from flatland.envs.generators import rail_from_grid_transition_map
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv
@@ -21,6 +22,7 @@ def test_dummy_predictor(rendering=False):
     env = RailEnv(width=rail_map.shape[1],
                   height=rail_map.shape[0],
                   rail_generator=rail_from_grid_transition_map(rail),
+                  agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(),
                   number_of_agents=1,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv(max_depth=10)),
                   )
@@ -111,6 +113,7 @@ def test_shortest_path_predictor(rendering=False):
     env = RailEnv(width=rail_map.shape[1],
                   height=rail_map.shape[0],
                   rail_generator=rail_from_grid_transition_map(rail),
+                  agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(),
                   number_of_agents=1,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                   )
@@ -230,6 +233,7 @@ def test_shortest_path_predictor_conflicts(rendering=False):
     env = RailEnv(width=rail_map.shape[1],
                   height=rail_map.shape[0],
                   rail_generator=rail_from_grid_transition_map(rail),
+                  agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(),
                   number_of_agents=2,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                   )
diff --git a/tests/test_flatland_envs_rail_env.py b/tests/test_flatland_envs_rail_env.py
index 71dc87ce..9108d614 100644
--- a/tests/test_flatland_envs_rail_env.py
+++ b/tests/test_flatland_envs_rail_env.py
@@ -5,6 +5,7 @@ import numpy as np
 from flatland.core.grid.grid4 import Grid4Transitions
 from flatland.core.grid.rail_env_grid import RailEnvTransitions
 from flatland.core.transition_map import GridTransitionMap
+from flatland.envs.agent_generators import get_rnd_agents_pos_tgt_dir_on_rail, complex_rail_generator_agents_placer
 from flatland.envs.agent_utils import EnvAgent
 from flatland.envs.agent_utils import EnvAgentStatic
 from flatland.envs.generators import complex_rail_generator
@@ -27,6 +28,7 @@ def test_load_env():
 def test_save_load():
     env = RailEnv(width=10, height=10,
                   rail_generator=complex_rail_generator(nr_start_goal=2, nr_extra=5, min_dist=6, seed=0),
+                  agent_generator=complex_rail_generator_agents_placer(),
                   number_of_agents=2)
     env.reset()
     agent_1_pos = env.agents_static[0].position
@@ -86,6 +88,7 @@ def test_rail_environment_single_agent():
     rail_env = RailEnv(width=3,
                        height=3,
                        rail_generator=rail_from_grid_transition_map(rail),
+                       agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(),
                        number_of_agents=1,
                        obs_builder_object=GlobalObsForRailEnv())
 
@@ -165,6 +168,7 @@ def test_dead_end():
     rail_env = RailEnv(width=rail_map.shape[1],
                        height=rail_map.shape[0],
                        rail_generator=rail_from_grid_transition_map(rail),
+                       agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(),
                        number_of_agents=1,
                        obs_builder_object=GlobalObsForRailEnv())
 
@@ -209,6 +213,7 @@ def test_dead_end():
     rail_env = RailEnv(width=rail_map.shape[1],
                        height=rail_map.shape[0],
                        rail_generator=rail_from_grid_transition_map(rail),
+                       agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(),
                        number_of_agents=1,
                        obs_builder_object=GlobalObsForRailEnv())
 
diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py
index 67dcd25c..60a15bb7 100644
--- a/tests/test_flatland_malfunction.py
+++ b/tests/test_flatland_malfunction.py
@@ -1,5 +1,6 @@
 import numpy as np
 
+from flatland.envs.agent_generators import complex_rail_generator_agents_placer
 from flatland.envs.generators import complex_rail_generator
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.rail_env import RailEnv
@@ -62,6 +63,7 @@ def test_malfunction_process():
                   height=20,
                   rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999,
                                                         seed=0),
+                  agent_generator=complex_rail_generator_agents_placer(),
                   number_of_agents=2,
                   obs_builder_object=SingleAgentNavigationObs(),
                   stochastic_data=stochastic_data)
diff --git a/tests/test_multi_speed.py b/tests/test_multi_speed.py
index 47aadee7..8703800e 100644
--- a/tests/test_multi_speed.py
+++ b/tests/test_multi_speed.py
@@ -1,10 +1,12 @@
 import numpy as np
 
+from flatland.envs.agent_generators import complex_rail_generator_agents_placer
 from flatland.envs.generators import complex_rail_generator
 from flatland.envs.rail_env import RailEnv
 
 np.random.seed(1)
 
+
 # Use the complex_rail_generator to generate feasible network configurations with corresponding tasks
 # Training on simple small tasks is the best way to get familiar with the environment
 #
@@ -46,6 +48,7 @@ def test_multi_speed_init():
                   height=50,
                   rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999,
                                                         seed=0),
+                  agent_generator=complex_rail_generator_agents_placer(),
                   number_of_agents=5)
     # Initialize the agent with the parameters corresponding to the environment and observation_builder
     agent = RandomAgent(218, 4)
diff --git a/tests/test_speed_classes.py b/tests/test_speed_classes.py
index 6ef600d9..67054c81 100644
--- a/tests/test_speed_classes.py
+++ b/tests/test_speed_classes.py
@@ -1,7 +1,8 @@
 """Test speed initialization by a map of speeds and their corresponding ratios."""
 import numpy as np
 
-from flatland.envs.generators import speed_initialization_helper, complex_rail_generator
+from flatland.envs.agent_generators import speed_initialization_helper, complex_rail_generator_agents_placer
+from flatland.envs.generators import complex_rail_generator
 from flatland.envs.rail_env import RailEnv
 
 
@@ -17,13 +18,11 @@ def test_speed_initialization_helper():
 def test_rail_env_speed_intializer():
     speed_ratio_map = {1: 0.3, 2: 0.4, 3: 0.1, 5: 0.2}
 
-    def my_speed_initializer(nb_agents):
-        return speed_initialization_helper(nb_agents, speed_ratio_map)
-
     env = RailEnv(width=50,
                   height=50,
                   rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999,
-                                                        seed=0, speed_initializer=my_speed_initializer),
+                                                        seed=0),
+                  agent_generator=complex_rail_generator_agents_placer(),
                   number_of_agents=10)
     env.reset()
     actual_speeds = list(map(lambda agent: agent.speed_data['speed'], env.agents))
diff --git a/tests/tests_generators.py b/tests/tests_generators.py
index f97b071e..46109c55 100644
--- a/tests/tests_generators.py
+++ b/tests/tests_generators.py
@@ -3,6 +3,8 @@
 
 import numpy as np
 
+from flatland.envs.agent_generators import get_rnd_agents_pos_tgt_dir_on_rail, complex_rail_generator_agents_placer, \
+    agents_from_file
 from flatland.envs.generators import rail_from_grid_transition_map, rail_from_file, complex_rail_generator, \
     random_rail_generator, empty_rail_generator
 from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv
@@ -58,7 +60,8 @@ def test_complex_rail_generator():
     env = RailEnv(width=x_dim,
                   height=y_dim,
                   number_of_agents=n_agents,
-                  rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist)
+                  rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist),
+                  agent_generator=complex_rail_generator_agents_placer()
                   )
     assert env.get_num_agents() == 2
     assert env.rail.grid.shape == (y_dim, x_dim)
@@ -69,7 +72,8 @@ def test_complex_rail_generator():
     env = RailEnv(width=x_dim,
                   height=y_dim,
                   number_of_agents=n_agents,
-                  rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist)
+                  rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist),
+                  agent_generator=complex_rail_generator_agents_placer()
                   )
     assert env.get_num_agents() == 0
     assert env.rail.grid.shape == (y_dim, x_dim)
@@ -82,7 +86,8 @@ def test_complex_rail_generator():
     env = RailEnv(width=x_dim,
                   height=y_dim,
                   number_of_agents=n_agents,
-                  rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist)
+                  rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist),
+                  agent_generator=complex_rail_generator_agents_placer()
                   )
     assert env.get_num_agents() == n_agents
     assert env.rail.grid.shape == (y_dim, x_dim)
@@ -94,6 +99,7 @@ def test_rail_from_grid_transition_map():
     env = RailEnv(width=rail_map.shape[1],
                   height=rail_map.shape[0],
                   rail_generator=rail_from_grid_transition_map(rail),
+                  agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(),
                   number_of_agents=n_agents
                   )
     nr_rail_elements = np.count_nonzero(env.rail.grid)
@@ -118,6 +124,7 @@ def tests_rail_from_file():
     env = RailEnv(width=rail_map.shape[1],
                   height=rail_map.shape[0],
                   rail_generator=rail_from_grid_transition_map(rail),
+                  agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(),
                   number_of_agents=3,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                   )
@@ -130,6 +137,7 @@ def tests_rail_from_file():
     env = RailEnv(width=1,
                   height=1,
                   rail_generator=rail_from_file(file_name),
+                  agent_generator=agents_from_file(file_name),
                   number_of_agents=1,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                   )
@@ -151,6 +159,7 @@ def tests_rail_from_file():
     env2 = RailEnv(width=rail_map.shape[1],
                    height=rail_map.shape[0],
                    rail_generator=rail_from_grid_transition_map(rail),
+                   agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(),
                    number_of_agents=3,
                    obs_builder_object=GlobalObsForRailEnv(),
                    )
@@ -164,6 +173,7 @@ def tests_rail_from_file():
     env2 = RailEnv(width=1,
                    height=1,
                    rail_generator=rail_from_file(file_name_2),
+                   agent_generator=agents_from_file(file_name_2),
                    number_of_agents=1,
                    obs_builder_object=GlobalObsForRailEnv(),
                    )
@@ -180,6 +190,7 @@ def tests_rail_from_file():
     env3 = RailEnv(width=1,
                    height=1,
                    rail_generator=rail_from_file(file_name),
+                   agent_generator=agents_from_file(file_name),
                    number_of_agents=1,
                    obs_builder_object=GlobalObsForRailEnv(),
                    )
@@ -197,6 +208,7 @@ def tests_rail_from_file():
     env4 = RailEnv(width=1,
                    height=1,
                    rail_generator=rail_from_file(file_name_2),
+                   agent_generator=agents_from_file(file_name_2),
                    number_of_agents=1,
                    obs_builder_object=TreeObsForRailEnv(max_depth=2),
                    )
-- 
GitLab