diff --git a/examples/introduction_flatland_2_1.py b/examples/introduction_flatland_2_1.py
index 8ea267c18accf0a28f964bff5aecdbc70b50016e..d770fa43ffee5a1a8e4231a8a6a2d9efe7891746 100644
--- a/examples/introduction_flatland_2_1.py
+++ b/examples/introduction_flatland_2_1.py
@@ -4,12 +4,14 @@ import os
 # In Flatland you can use custom observation builders and predicitors
 # Observation builders generate the observation needed by the controller
 # Preditctors can be used to do short time prediction which can help in avoiding conflicts in the network
-from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters
+from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters, ParamMalfunctionGen
+
 from flatland.envs.observations import GlobalObsForRailEnv
 # First of all we import the Flatland rail environment
 from flatland.envs.rail_env import RailEnv
 from flatland.envs.rail_env import RailEnvActions
 from flatland.envs.rail_generators import sparse_rail_generator
+#from flatland.envs.sparse_rail_gen import SparseRailGen
 from flatland.envs.schedule_generators import sparse_schedule_generator
 # We also include a renderer because we want to visualize what is going on in the environment
 from flatland.utils.rendertools import RenderTool, AgentRenderVariant
@@ -46,6 +48,14 @@ rail_generator = sparse_rail_generator(max_num_cities=cities_in_map,
                                        max_rails_in_city=max_rail_in_cities,
                                        )
 
+#rail_generator = SparseRailGen(max_num_cities=cities_in_map,
+#                                       seed=seed,
+#                                       grid_mode=grid_distribution_of_cities,
+#                                       max_rails_between_cities=max_rails_between_cities,
+#                                       max_rails_in_city=max_rail_in_cities,
+#                                       )
+
+
 # The schedule generator can make very basic schedules with a start point, end point and a speed profile for each agent.
 # The speed profiles can be adjusted directly as well as shown later on. We start by introducing a statistical
 # distribution of speed profiles
@@ -80,7 +90,8 @@ env = RailEnv(width=width,
               schedule_generator=schedule_generator,
               number_of_agents=nr_trains,
               obs_builder_object=observation_builder,
-              malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
+              #malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
+              malfunction_generator=ParamMalfunctionGen(stochastic_data),
               remove_agents_at_target=True)
 env.reset()
 
diff --git a/flatland/envs/malfunction_generators.py b/flatland/envs/malfunction_generators.py
index 2f27eca3652b89866432067ec50dda7ab6bdf2ff..62a920805d2b530c5da2f26b9289e3a84e15cc02 100644
--- a/flatland/envs/malfunction_generators.py
+++ b/flatland/envs/malfunction_generators.py
@@ -143,6 +143,44 @@ def malfunction_from_params(parameters: MalfunctionParameters) -> Tuple[Malfunct
                                              max_number_of_steps_broken)
 
 
+class ParamMalfunctionGen(object):
+    def __init__(self, parameters: MalfunctionParameters):
+        self.mean_malfunction_rate = parameters.malfunction_rate
+        self.min_number_of_steps_broken = parameters.min_duration
+        self.max_number_of_steps_broken = parameters.max_duration
+
+    def generate(self, agent: EnvAgent = None, np_random: RandomState = None, reset=False) -> Optional[Malfunction]:
+      
+        # Dummy reset function as we don't implement specific seeding here
+        if reset:
+            return Malfunction(0)
+
+        if agent.malfunction_data['malfunction'] < 1:
+            if np_random.rand() < _malfunction_prob(self.mean_malfunction_rate):
+                num_broken_steps = np_random.randint(self.min_number_of_steps_broken,
+                                                     self.max_number_of_steps_broken + 1) + 1
+                return Malfunction(num_broken_steps)
+        return Malfunction(0)
+
+    def get_process_data(self):
+        return MalfunctionProcessData(
+            self.mean_malfunction_rate, 
+            self.min_number_of_steps_broken,
+            self.max_number_of_steps_broken)
+
+
+class NoMalfunctionGen(ParamMalfunctionGen):
+    def __init__(self):
+        self.mean_malfunction_rate = 0.
+        self.min_number_of_steps_broken = 0
+        self.max_number_of_steps_broken = 0
+
+    def generate(self, agent: EnvAgent = None, np_random: RandomState = None, reset=False) -> Optional[Malfunction]:
+        return Malfunction(0)
+
+    
+
+
 def no_malfunction_generator() -> Tuple[MalfunctionGenerator, MalfunctionProcessData]:
     """
     Malfunction generator which generates no malfunctions
@@ -169,6 +207,7 @@ def no_malfunction_generator() -> Tuple[MalfunctionGenerator, MalfunctionProcess
                                              max_number_of_steps_broken)
 
 
+
 def single_malfunction_generator(earlierst_malfunction: int, malfunction_duration: int) -> Tuple[
     MalfunctionGenerator, MalfunctionProcessData]:
     """
diff --git a/flatland/envs/persistence.py b/flatland/envs/persistence.py
index 3952b7944449cb37a5135a54794f81d9a2810104..8ed7ec8ab27dbc08746bd2a774dd130fc8eefdce 100644
--- a/flatland/envs/persistence.py
+++ b/flatland/envs/persistence.py
@@ -206,7 +206,7 @@ class RailEnvPersister(object):
         # msgpack cannot persist EnvAgent so use the Agent namedtuple.
         agent_data = [agent.to_agent() for agent in env.agents]
         #print("get_full_state - agent_data:", agent_data)
-        malfunction_data: MalfunctionProcessData = env.malfunction_process_data
+        malfunction_data: mal_gen.MalfunctionProcessData = env.malfunction_process_data
 
         msg_data_dict = {
             "grid": grid_data,
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 30ff30eb80323afa08bb73b60c781dbd2cd5d153..03678dec6ebc7ae8f317357ac4931e1fab06cd84 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -135,6 +135,7 @@ class RailEnv(Environment):
                  number_of_agents=1,
                  obs_builder_object: ObservationBuilder = GlobalObsForRailEnv(),
                  malfunction_generator_and_process_data=None, #mal_gen.no_malfunction_generator(),
+                 malfunction_generator=None,
                  remove_agents_at_target=True,
                  random_seed=1,
                  record_steps=False
@@ -176,9 +177,19 @@ class RailEnv(Environment):
         """
         super().__init__()
 
-        if malfunction_generator_and_process_data is None:
-            malfunction_generator_and_process_data = mal_gen.no_malfunction_generator()
-        self.malfunction_generator, self.malfunction_process_data = malfunction_generator_and_process_data
+        if malfunction_generator_and_process_data is not None:
+            print("DEPRECATED - RailEnv arg: malfunction_and_process_data - use malfunction_generator")
+            self.malfunction_generator, self.malfunction_process_data = malfunction_generator_and_process_data
+        elif malfunction_generator is not None:
+            self.malfunction_generator = malfunction_generator
+            # malfunction_process_data is not used
+            #self.malfunction_generator, self.malfunction_process_data = malfunction_generator_and_process_data
+            self.malfunction_process_data = self.malfunction_generator.get_process_data()
+        # replace default values here because we can't use default args values because of cyclic imports
+        else:
+            self.malfunction_generator = mal_gen.NoMalfunctionGen()
+            self.malfunction_process_data = self.malfunction_generator.get_process_data()
+
         #self.rail_generator: RailGenerator = rail_generator
         if rail_generator is None:
             rail_generator = rail_gen.random_rail_generator()
@@ -315,8 +326,16 @@ class RailEnv(Environment):
 
         optionals = {}
         if regenerate_rail or self.rail is None:
-            rail, optionals = self.rail_generator(self.width, self.height, self.number_of_agents, self.num_resets,
-                                                  self.np_random)
+
+            if "__call__" in dir(self.rail_generator):
+                rail, optionals = self.rail_generator(
+                    self.width, self.height, self.number_of_agents, self.num_resets, self.np_random)
+            elif "generate" in dir(self.rail_generator):
+                rail, optionals = self.rail_generator.generate(
+                    self.width, self.height, self.number_of_agents, self.num_resets, self.np_random)
+            else:
+                raise ValueError("Could not invoke __call__ or generate on rail_generator")
+
 
             self.rail = rail
             self.height, self.width = self.rail.grid.shape
@@ -373,7 +392,10 @@ class RailEnv(Environment):
         self.distance_map.reset(self.agents, self.rail)
 
         # Reset the malfunction generator
-        self.malfunction_generator(reset=True)
+        if "generate" in dir(self.malfunction_generator):
+            self.malfunction_generator.generate(reset=True)
+        else:
+            self.malfunction_generator(reset=True)
 
         # Empty the episode store of agent positions
         self.cur_episode = []
@@ -424,7 +446,12 @@ class RailEnv(Environment):
 
         """
 
-        malfunction: Malfunction = self.malfunction_generator(agent, self.np_random)
+        #malfunction: Malfunction = self.malfunction_generator(agent, self.np_random)
+        if "generate" in dir(self.malfunction_generator):
+            malfunction: mal_gen.Malfunction = self.malfunction_generator.generate(agent, self.np_random)
+        else:
+            malfunction: mal_gen.Malfunction = self.malfunction_generator(agent, self.np_random)
+
         if malfunction.num_broken_steps > 0:
             agent.malfunction_data['malfunction'] = malfunction.num_broken_steps
             agent.malfunction_data['moving_before_malfunction'] = agent.moving
diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py
index 1a73acb93908fda477be25637631a465df8b9183..9232914c41c9a1e4c54c625a6f84ae5912d8642e 100644
--- a/flatland/envs/rail_generators.py
+++ b/flatland/envs/rail_generators.py
@@ -560,31 +560,51 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11, seed=1) -> R
     return generator
 
 
-def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_rails_between_cities: int = 4,
+
+class RailGen(object):
+    def __init__(self):
+        pass
+
+    def generate(self):
+        pass
+
+
+def sparse_rail_generator(*args, **kwargs):
+    return SparseRailGen(*args, **kwargs)
+
+class SparseRailGen(RailGen):
+
+    def __init__(self, max_num_cities: int = 5, grid_mode: bool = False, max_rails_between_cities: int = 4,
                           max_rails_in_city: int = 4, seed=0) -> RailGenerator:
-    """
-    Generates railway networks with cities and inner city rails
+        """
+        Generates railway networks with cities and inner city rails
 
-    Parameters
-    ----------
-    max_num_cities : int
-        Max number of cities to build. The generator tries to achieve this numbers given all the parameters
-    grid_mode: Bool
-        How to distribute the cities in the path, either equally in a grid or random
-    max_rails_between_cities: int
-        Max number of rails connecting to a city. This is only the number of connection points at city boarder.
-        Number of tracks drawn inbetween cities can still vary
-    max_rails_in_city: int
-        Number of parallel tracks in the city. This represents the number of tracks in the trainstations
-    seed: int
-        Initiate the seed
+        Parameters
+        ----------
+        max_num_cities : int
+            Max number of cities to build. The generator tries to achieve this numbers given all the parameters
+        grid_mode: Bool
+            How to distribute the cities in the path, either equally in a grid or random
+        max_rails_between_cities: int
+            Max number of rails connecting to a city. This is only the number of connection points at city boarder.
+            Number of tracks drawn inbetween cities can still vary
+        max_rails_in_city: int
+            Number of parallel tracks in the city. This represents the number of tracks in the trainstations
+        seed: int
+            Initiate the seed
 
-    Returns
-    -------
-    Returns the rail generator object to the rail env constructor
-    """
+        Returns
+        -------
+        Returns the rail generator object to the rail env constructor
+        """
+        self.max_num_cities = max_num_cities
+        self.grid_mode = grid_mode
+        self.max_rails_between_cities = max_rails_between_cities
+        self.max_rails_in_city = max_rails_in_city
+        self.seed = seed # TODO: seed in constructor or generate?
 
-    def generator(width: int, height: int, num_agents: int, num_resets: int = 0,
+
+    def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0,
                   np_random: RandomState = None) -> RailGenerator:
         """
 
@@ -616,28 +636,28 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
         # We add 2 cells to avoid that track lenght is to short
         city_padding = 2
         # We use ceil if we get uneven numbers of city radius. This is to guarantee that all rails fit within the city.
-        city_radius = int(np.ceil((max_rails_in_city) / 2)) + city_padding
+        city_radius = int(np.ceil((self.max_rails_in_city) / 2)) + city_padding
         vector_field = np.zeros(shape=(height, width)) - 1.
 
         min_nr_rails_in_city = 2
-        rails_in_city = min_nr_rails_in_city if max_rails_in_city < min_nr_rails_in_city else max_rails_in_city
-        rails_between_cities = rails_in_city if max_rails_between_cities > rails_in_city else max_rails_between_cities
+        rails_in_city = min_nr_rails_in_city if self.max_rails_in_city < min_nr_rails_in_city else self.max_rails_in_city
+        rails_between_cities = rails_in_city if self.max_rails_between_cities > rails_in_city else self.max_rails_between_cities
 
         # Calculate the max number of cities allowed
         # and reduce the number of cities to build to avoid problems
-        max_feasible_cities = min(max_num_cities,
+        max_feasible_cities = min(self.max_num_cities,
                                   ((height - 2) // (2 * (city_radius + 1))) * ((width - 2) // (2 * (city_radius + 1))))
         if max_feasible_cities < 2:
             # sys.exit("[ABORT] Cannot fit more than one city in this map, no feasible environment possible! Aborting.")
             raise ValueError("ERROR: Cannot fit more than one city in this map, no feasible environment possible!")
 
         # Evenly distribute cities
-        if grid_mode:
-            city_positions = _generate_evenly_distr_city_positions(max_feasible_cities, city_radius, width,
+        if self.grid_mode:
+            city_positions = self._generate_evenly_distr_city_positions(max_feasible_cities, city_radius, width,
                                                                    height)
         # Distribute cities randomlz
         else:
-            city_positions = _generate_random_city_positions(max_feasible_cities, city_radius, width, height,
+            city_positions = self._generate_random_city_positions(max_feasible_cities, city_radius, width, height,
                                                              np_random=np_random)
 
         # reduce num_cities if less were generated in random mode
@@ -645,31 +665,31 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
         # If random generation failed just put the cities evenly
         if num_cities < 2:
             warnings.warn("[WARNING] Changing to Grid mode to place at least 2 cities.")
-            city_positions = _generate_evenly_distr_city_positions(max_feasible_cities, city_radius, width,
+            city_positions = self._generate_evenly_distr_city_positions(max_feasible_cities, city_radius, width,
                                                                    height)
         num_cities = len(city_positions)
 
         # Set up connection points for all cities
         inner_connection_points, outer_connection_points, city_orientations, city_cells = \
-            _generate_city_connection_points(
+            self._generate_city_connection_points(
                 city_positions, city_radius, vector_field, rails_between_cities,
                 rails_in_city, np_random=np_random)
 
         # Connect the cities through the connection points
-        inter_city_lines = _connect_cities(city_positions, outer_connection_points, city_cells,
+        inter_city_lines = self._connect_cities(city_positions, outer_connection_points, city_cells,
                                            rail_trans, grid_map)
 
         # Build inner cities
-        free_rails = _build_inner_cities(city_positions, inner_connection_points,
+        free_rails = self._build_inner_cities(city_positions, inner_connection_points,
                                          outer_connection_points,
                                          rail_trans,
                                          grid_map)
 
         # Populate cities
-        train_stations = _set_trainstation_positions(city_positions, city_radius, free_rails)
+        train_stations = self._set_trainstation_positions(city_positions, city_radius, free_rails)
 
         # Fix all transition elements
-        _fix_transitions(city_cells, inter_city_lines, grid_map, vector_field)
+        self._fix_transitions(city_cells, inter_city_lines, grid_map, vector_field)
 
         return grid_map, {'agents_hints': {
             'num_agents': num_agents,
@@ -678,7 +698,7 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
             'city_orientations': city_orientations
         }}
 
-    def _generate_random_city_positions(num_cities: int, city_radius: int, width: int,
+    def _generate_random_city_positions(self, num_cities: int, city_radius: int, width: int,
                                         height: int, np_random: RandomState = None) -> (
         IntVector2DArray, IntVector2DArray):
         """
@@ -713,7 +733,7 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
                 too_close = False
                 # Check distance to cities
                 for city_pos in city_positions:
-                    if _are_cities_overlapping((row, col), city_pos, 2 * (city_radius + 1) + 1):
+                    if self.__class__._are_cities_overlapping((row, col), city_pos, 2 * (city_radius + 1) + 1):
                         too_close = True
 
                 if not too_close:
@@ -726,7 +746,7 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
                     break
         return city_positions
 
-    def _generate_evenly_distr_city_positions(num_cities: int, city_radius: int, width: int, height: int
+    def _generate_evenly_distr_city_positions(self, num_cities: int, city_radius: int, width: int, height: int
                                               ) -> (IntVector2DArray, IntVector2DArray):
         """
         Distribute the cities in an evenly spaced grid
@@ -773,7 +793,7 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
             city_positions.append((row, col))
         return city_positions
 
-    def _generate_city_connection_points(city_positions: IntVector2DArray, city_radius: int,
+    def _generate_city_connection_points(self, city_positions: IntVector2DArray, city_radius: int,
                                          vector_field: IntVector2DArray, rails_between_cities: int,
                                          rails_in_city: int = 2, np_random: RandomState = None) -> (
         List[List[List[IntVector2D]]],
@@ -824,19 +844,19 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
             neighb_dist = []
             for neighbour_city in city_positions:
                 neighb_dist.append(Vec2dOperations.get_manhattan_distance(city_position, neighbour_city))
-            closest_neighb_idx = argsort(neighb_dist)
+            closest_neighb_idx = self.__class__.argsort(neighb_dist)
 
             # Store the directions to these neighbours and orient city to face closest neighbour
             connection_sides_idx = []
             idx = 1
-            if grid_mode:
+            if self.grid_mode:
                 current_closest_direction = np_random.randint(4)
             else:
                 current_closest_direction = direction_to_point(city_position, city_positions[closest_neighb_idx[idx]])
             connection_sides_idx.append(current_closest_direction)
             connection_sides_idx.append((current_closest_direction + 2) % 4)
             city_orientations.append(current_closest_direction)
-            city_cells.extend(_get_cells_in_city(city_position, city_radius, city_orientations[-1], vector_field))
+            city_cells.extend(self._get_cells_in_city(city_position, city_radius, city_orientations[-1], vector_field))
             # set the number of tracks within a city, at least 2 tracks per city
             connections_per_direction = np.zeros(4, dtype=int)
             nr_of_connection_points = np_random.randint(2, rails_in_city + 1)
@@ -886,7 +906,7 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
             outer_connection_points.append(connection_points_coordinates_outer)
         return inner_connection_points, outer_connection_points, city_orientations, city_cells
 
-    def _connect_cities(city_positions: IntVector2DArray, connection_points: List[List[List[IntVector2D]]],
+    def _connect_cities(self, city_positions: IntVector2DArray, connection_points: List[List[List[IntVector2D]]],
                         city_cells: IntVector2DArray,
                         rail_trans: RailEnvTransitions, grid_map: RailEnvTransitions) -> List[IntVector2DArray]:
         """
@@ -918,10 +938,10 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
                             Grid4TransitionsEnum.WEST]
 
         for current_city_idx in np.arange(len(city_positions)):
-            closest_neighbours = _closest_neighbour_in_grid4_directions(current_city_idx, city_positions)
+            closest_neighbours = self._closest_neighbour_in_grid4_directions(current_city_idx, city_positions)
             for out_direction in grid4_directions:
 
-                neighbour_idx = get_closest_neighbour_for_direction(closest_neighbours, out_direction)
+                neighbour_idx = self.get_closest_neighbour_for_direction(closest_neighbours, out_direction)
 
                 for city_out_connection_point in connection_points[current_city_idx][out_direction]:
 
@@ -944,7 +964,7 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
 
         return all_paths
 
-    def get_closest_neighbour_for_direction(closest_neighbours, out_direction):
+    def get_closest_neighbour_for_direction(self, closest_neighbours, out_direction):
         """
         Given a list of clostest neighbours in each direction this returns the city index of the neighbor in a given
         direction. Direction is a 90 degree cone facing the desired directiont.
@@ -981,7 +1001,7 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
 
         return closest_neighbours[(out_direction + 2) % 4]  # clockwise
 
-    def _build_inner_cities(city_positions: IntVector2DArray, inner_connection_points: List[List[List[IntVector2D]]],
+    def _build_inner_cities(self, city_positions: IntVector2DArray, inner_connection_points: List[List[List[IntVector2D]]],
                             outer_connection_points: List[List[List[IntVector2D]]], rail_trans: RailEnvTransitions,
                             grid_map: GridTransitionMap) -> (List[IntVector2DArray], List[List[List[IntVector2D]]]):
         """
@@ -1051,7 +1071,7 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
                     connect_straight_line_in_grid_map(grid_map, target, target_outer, rail_trans)
         return free_rails
 
-    def _set_trainstation_positions(city_positions: IntVector2DArray, city_radius: int,
+    def _set_trainstation_positions(self, city_positions: IntVector2DArray, city_radius: int,
                                     free_rails: List[List[List[IntVector2D]]]) -> List[List[Tuple[IntVector2D, int]]]:
         """
         Populate the cities with possible start and end positions. Trainstations are set on the center of each paralell
@@ -1080,7 +1100,7 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
                 train_stations[current_city].append((possible_location, track_nbr))
         return train_stations
 
-    def _fix_transitions(city_cells: IntVector2DArray, inter_city_lines: List[IntVector2DArray],
+    def _fix_transitions(self, city_cells: IntVector2DArray, inter_city_lines: List[IntVector2DArray],
                          grid_map: GridTransitionMap, vector_field):
         """
         Check and fix transitions of all the cells that were modified. This is necessary because we ignore validity
@@ -1117,7 +1137,7 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
         for cell in range(rails_to_fix_cnt):
             grid_map.fix_transitions((rails_to_fix[3 * cell], rails_to_fix[3 * cell + 1]), rails_to_fix[3 * cell + 2])
 
-    def _closest_neighbour_in_grid4_directions(current_city_idx: int, city_positions: IntVector2DArray) -> List[int]:
+    def _closest_neighbour_in_grid4_directions(self, current_city_idx: int, city_positions: IntVector2DArray) -> List[int]:
         """
         Finds the closest city in each direction of the current city
         Parameters
@@ -1152,6 +1172,7 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
 
         return closest_neighbour
 
+    @staticmethod
     def argsort(seq):
         """
         Same as Numpy sort but for lists
@@ -1168,7 +1189,7 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
         # http://stackoverflow.com/questions/3071415/efficient-method-to-calculate-the-rank-vector-of-a-list-in-python
         return sorted(range(len(seq)), key=seq.__getitem__)
 
-    def _get_cells_in_city(center: IntVector2D, radius: int, city_orientation: int,
+    def _get_cells_in_city(self, center: IntVector2D, radius: int, city_orientation: int,
                            vector_field: IntVector2DArray) -> IntVector2DArray:
         """
         Function the collect cells of a city. It also populates the vector field accoring to the orientation of the
@@ -1205,6 +1226,7 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
             vector_field[cell] = align_cell_to_city(center, city_orientation, cell)
         return city_cells
 
+    @staticmethod
     def _are_cities_overlapping(center_1, center_2, radius):
         """
         Check if two cities overlap. That is we check if two squares with certain edge length and position overlap
@@ -1224,4 +1246,3 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
         """
         return np.abs(center_1[0] - center_2[0]) < radius and np.abs(center_1[1] - center_2[1]) < radius
 
-    return generator
diff --git a/flatland/envs/sparse_rail_gen.py b/flatland/envs/sparse_rail_gen.py
new file mode 100644
index 0000000000000000000000000000000000000000..e001754bc2942a9c0fd6ac4e96c7cf85c780b754
--- /dev/null
+++ b/flatland/envs/sparse_rail_gen.py
@@ -0,0 +1,21 @@
+"""Rail generators (infrastructure manager, "Infrastrukturbetreiber")."""
+import sys
+import warnings
+from typing import Callable, Tuple, Optional, Dict, List
+
+import numpy as np
+from numpy.random.mtrand import RandomState
+
+from flatland.core.grid.grid4 import Grid4TransitionsEnum
+from flatland.core.grid.grid4_utils import get_direction, mirror, direction_to_point
+from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
+from flatland.core.grid.grid_utils import distance_on_rail, IntVector2DArray, IntVector2D, \
+    Vec2dOperations
+from flatland.core.grid.rail_env_grid import RailEnvTransitions
+from flatland.core.transition_map import GridTransitionMap
+from flatland.envs.grid4_generators_utils import connect_rail_in_grid_map, connect_straight_line_in_grid_map, \
+    fix_inner_nodes, align_cell_to_city
+from flatland.envs import persistence
+
+from flatland.envs.rail_generators import RailGeneratorProduct, RailGenerator
+
diff --git a/tests/test_eval_timeout.py b/tests/test_eval_timeout.py
index 422f04bd8b3bd6bd8e43daea34f9a36488b7958c..dfc406e3b9d091fc8e9a477ea86fae025e7b1936 100644
--- a/tests/test_eval_timeout.py
+++ b/tests/test_eval_timeout.py
@@ -59,7 +59,7 @@ def my_controller(obs, number_of_agents):
     return _action
 
 
-def test_random_timeouts():
+def __disabled__test_random_timeouts():
     remote_client = FlatlandRemoteClient(verbose=False)
 
     my_observation_builder = CustomObservationBuilder()