diff --git a/examples/temporary_example.py b/examples/temporary_example.py
index 6bed439cb21f611353763da36a890739c77e5866..28a94db3754fc21783447cf15057022476189a63 100644
--- a/examples/temporary_example.py
+++ b/examples/temporary_example.py
@@ -54,8 +54,6 @@ env.agents_direction[0] = 1
 # TODO: watch out: if these variables are overridden, the obs_builder object has to be reset, too!
 env.obs_builder.reset()
 """
-
-
 """
 # INFINITE-LOOP TEST
 specs = [[(0, 0), (0, 0), (0, 0), (0, 0), (7, 0), (0, 0)],
@@ -84,6 +82,7 @@ env.obs_builder.reset()
 env = RailEnv(width=7,
               height=7,
               rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
+              # rail_generator=complex_rail_generator(nr_start_goal=2),
               number_of_agents=2)
 
 # Print the distance map of each cell to the target of the first agent
diff --git a/flatland/envs/env_utils.py b/flatland/envs/env_utils.py
index 382bdd7d9c483aa491f9c11b06b62b13b45490a5..ac29408051efd4bc85f8120b02f7dfd59b14dee9 100644
--- a/flatland/envs/env_utils.py
+++ b/flatland/envs/env_utils.py
@@ -5,7 +5,7 @@ Definition of the RailEnv environment and related level-generation functions.
 Generator functions are functions that take width, height and num_resets as arguments and return
 a GridTransitionMap object.
 """
-# import numpy as np
+import numpy as np
 
 # from flatland.core.env import Environment
 # from flatland.core.env_observation_builder import TreeObsForRailEnv
@@ -271,3 +271,93 @@ def connect_rail(rail_trans, rail_array, start, end):
 
 def distance_on_rail(pos1, pos2):
     return abs(pos1[0] - pos2[0]) + abs(pos1[1] - pos2[1])
+
+
+def get_new_position(position, movement):
+    if movement == 0:  # NORTH
+        return (position[0] - 1, position[1])
+    elif movement == 1:  # EAST
+        return (position[0], position[1] + 1)
+    elif movement == 2:  # SOUTH
+        return (position[0] + 1, position[1])
+    elif movement == 3:  # WEST
+        return (position[0], position[1] - 1)
+
+
+def get_rnd_agents_pos_tgt_dir_on_rail(rail, num_agents):
+    """
+    Given a `rail' GridTransitionMap, return a random placement of agents (initial position, direction and target).
+
+    TODO: add extensive documentation, as users may need this function to simplify their custom level generators.
+    """
+
+    def _path_exists(rail, start, direction, end):
+        # BFS - Check if a path exists between the 2 nodes
+
+        visited = set()
+        stack = [(start, direction)]
+        while stack:
+            node = stack.pop()
+            if node[0][0] == end[0] and node[0][1] == end[1]:
+                return 1
+            if node not in visited:
+                visited.add(node)
+                moves = rail.get_transitions((node[0][0], node[0][1], node[1]))
+                for move_index in range(4):
+                    if moves[move_index]:
+                        stack.append((get_new_position(node[0], move_index),
+                                      move_index))
+
+                # If cell is a dead-end, append previous node with reversed
+                # orientation!
+                nbits = 0
+                tmp = rail.get_transitions((node[0][0], node[0][1]))
+                while tmp > 0:
+                    nbits += (tmp & 1)
+                    tmp = tmp >> 1
+                if nbits == 1:
+                    stack.append((node[0], (node[1] + 2) % 4))
+
+        return 0
+
+    valid_positions = []
+    for r in range(rail.height):
+        for c in range(rail.width):
+            if rail.get_transitions((r, c)) > 0:
+                valid_positions.append((r, c))
+
+    re_generate = True
+    while re_generate:
+        agents_position = [
+            valid_positions[i] for i in
+            np.random.choice(len(valid_positions), num_agents)]
+        agents_target = [
+            valid_positions[i] for i in
+            np.random.choice(len(valid_positions), num_agents)]
+
+        # agents_direction must be a direction for which a solution is
+        # guaranteed.
+        agents_direction = [0] * num_agents
+        re_generate = False
+        for i in range(num_agents):
+            valid_movements = []
+            for direction in range(4):
+                position = agents_position[i]
+                moves = rail.get_transitions((position[0], position[1], direction))
+                for move_index in range(4):
+                    if moves[move_index]:
+                        valid_movements.append((direction, move_index))
+
+            valid_starting_directions = []
+            for m in valid_movements:
+                new_position = get_new_position(agents_position[i], m[1])
+                if m[0] not in valid_starting_directions and \
+                   _path_exists(rail, new_position, m[0], agents_target[i]):
+                    valid_starting_directions.append(m[0])
+
+            if len(valid_starting_directions) == 0:
+                re_generate = True
+            else:
+                agents_direction[i] = valid_starting_directions[np.random.choice(len(valid_starting_directions), 1)[0]]
+
+    return agents_position, agents_direction, agents_target
diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py
index 021c63f31a958790c04e17f49c7a1da57777cb9c..2c26076dd527c9e315366471274b20f021502881 100644
--- a/flatland/envs/generators.py
+++ b/flatland/envs/generators.py
@@ -5,7 +5,7 @@ import numpy as np
 
 from flatland.core.transitions import Grid8Transitions, RailEnvTransitions
 from flatland.core.transition_map import GridTransitionMap
-from flatland.envs.env_utils import distance_on_rail, connect_rail
+from flatland.envs.env_utils import distance_on_rail, connect_rail, get_rnd_agents_pos_tgt_dir_on_rail
 
 
 def complex_rail_generator(nr_start_goal=1, min_dist=2, max_dist=99999, seed=0):
@@ -23,7 +23,7 @@ def complex_rail_generator(nr_start_goal=1, min_dist=2, max_dist=99999, seed=0):
         The matrix with the correct 16-bit bitmaps for each cell.
     """
 
-    def generator(width, height, num_resets=0):
+    def generator(width, height, agents_handles, num_resets=0):
         rail_trans = RailEnvTransitions()
         rail_array = np.zeros(shape=(width, height), dtype=np.uint16)
 
@@ -106,8 +106,12 @@ def complex_rail_generator(nr_start_goal=1, min_dist=2, max_dist=99999, seed=0):
 
         return_rail = GridTransitionMap(width=width, height=height, transitions=rail_trans)
         return_rail.grid = rail_array
-        # TODO: return start_goal
-        return return_rail
+
+        # TODO: return agents_position, agents_direction and agents_target!
+        # NOTE: the initial direction must be such that the target can be reached.
+        # See env_utils.get_rnd_agents_pos_tgt_dir_on_rail() for hints, if required.
+
+        return return_rail, [], [], []
 
     return generator
 
@@ -132,7 +136,7 @@ def rail_from_manual_specifications_generator(rail_spec):
         the matrix of correct 16-bit bitmaps for each cell.
     """
 
-    def generator(width, height, num_resets=0):
+    def generator(width, height, agents_handles, num_resets=0):
         t_utils = RailEnvTransitions()
 
         height = len(rail_spec)
@@ -147,7 +151,11 @@ def rail_from_manual_specifications_generator(rail_spec):
                     return []
                 rail.set_transitions((r, c), t_utils.rotate_transition(t_utils.transitions[cell[0]], cell[1]))
 
-        return rail
+        agents_position, agents_direction, agents_target = get_rnd_agents_pos_tgt_dir_on_rail(
+            rail,
+            len(agents_handles))
+
+        return rail, agents_position, agents_direction, agents_target
 
     return generator
 
@@ -168,8 +176,12 @@ def rail_from_GridTransitionMap_generator(rail_map):
         Generator function that always returns the given `rail_map' object.
     """
 
-    def generator(width, height, num_resets=0):
-        return rail_map
+    def generator(width, height, agents_handles, num_resets=0):
+        agents_position, agents_direction, agents_target = get_rnd_agents_pos_tgt_dir_on_rail(
+            rail_map,
+            len(agents_handles))
+
+        return rail_map, agents_position, agents_direction, agents_target
 
     return generator
 
@@ -189,7 +201,7 @@ def rail_from_list_of_saved_GridTransitionMap_generator(list_of_filenames):
         Generator function that always returns the given `rail_map' object.
     """
 
-    def generator(width, height, num_resets=0):
+    def generator(width, height, agents_handles, num_resets=0):
         t_utils = RailEnvTransitions()
         rail_map = GridTransitionMap(width=width, height=height, transitions=t_utils)
         rail_map.load_transition_map(list_of_filenames[num_resets % len(list_of_filenames)], override_gridsize=False)
@@ -197,7 +209,11 @@ def rail_from_list_of_saved_GridTransitionMap_generator(list_of_filenames):
         if rail_map.grid.dtype == np.uint64:
             rail_map.transitions = Grid8Transitions()
 
-        return rail_map
+        agents_position, agents_direction, agents_target = get_rnd_agents_pos_tgt_dir_on_rail(
+            rail_map,
+            len(agents_handles))
+
+        return rail_map, agents_position, agents_direction, agents_target
 
     return generator
 
@@ -243,7 +259,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8):
         The matrix with the correct 16-bit bitmaps for each cell.
     """
 
-    def generator(width, height, num_resets=0):
+    def generator(width, height, agents_handles, num_resets=0):
         t_utils = RailEnvTransitions()
 
         transition_probability = cell_type_relative_proportion
@@ -472,6 +488,11 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8):
 
         return_rail = GridTransitionMap(width=width, height=height, transitions=t_utils)
         return_rail.grid = tmp_rail
-        return return_rail
+
+        agents_position, agents_direction, agents_target = get_rnd_agents_pos_tgt_dir_on_rail(
+            return_rail,
+            len(agents_handles))
+
+        return return_rail, agents_position, agents_direction, agents_target
 
     return generator
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index cbb11c2db70dece00199bfae6bb051504f64a9f2..98abf81f469e1ea329db39e86c3cfe0a7756df28 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -9,6 +9,7 @@ import numpy as np
 from flatland.core.env import Environment
 from flatland.core.env_observation_builder import TreeObsForRailEnv
 from flatland.envs.generators import random_rail_generator
+from flatland.envs.env_utils import get_new_position
 
 # from flatland.core.transitions import Grid8Transitions, RailEnvTransitions
 # from flatland.core.transition_map import GridTransitionMap
@@ -74,9 +75,10 @@ class RailEnv(Environment):
         Parameters
         -------
         rail_generator : function
-            The rail_generator function is a function that takes the width and
-            height of a  rail map along with the number of times the env has
-            been reset, and returns a GridTransitionMap object.
+            The rail_generator function is a function that takes the width,
+            height and agents handles of a  rail environment, along with the number of times
+            the env has been reset, and returns a GridTransitionMap object and a list of
+            starting positions, targets, and initial orientations for agent handle.
             Implemented functions are:
                 random_rail_generator : generate a random rail of given size
                 rail_from_GridTransitionMap_generator(rail_map) : generate a rail from
@@ -130,107 +132,16 @@ class RailEnv(Environment):
     def get_agent_handles(self):
         return self.agents_handles
 
-    def fill_valid_positions(self):
-        ''' Populate the valid_positions list for the current TransitionMap.
-        '''
-        self.valid_positions = valid_positions = []
-        for r in range(self.height):
-            for c in range(self.width):
-                if self.rail.get_transitions((r, c)) > 0:
-                    valid_positions.append((r, c))
-
-    def check_agent_lists(self):
-        ''' Check that the agent_handles, position and direction lists are all of length
-            number_of_agents.
-            (Suggest this is replaced with a single list of Agent objects :)
-        '''
-        for lAgents, name in zip(
-                [self.agents_handles, self.agents_position, self.agents_direction],
-                ["handles", "positions", "directions"]):
-            assert self.number_of_agents == len(lAgents), "Inconsistent agent list:" + name
-
-    def check_agent_locdirpath(self, iAgent):
-        ''' Check that agent iAgent has a valid location and direction,
-            with a path to its target.
-            (Not currently used?)
-        '''
-        valid_movements = []
-        for direction in range(4):
-            position = self.agents_position[iAgent]
-            moves = self.rail.get_transitions((position[0], position[1], direction))
-            for move_index in range(4):
-                if moves[move_index]:
-                    valid_movements.append((direction, move_index))
-
-        valid_starting_directions = []
-        for m in valid_movements:
-            new_position = self._new_position(self.agents_position[iAgent], m[1])
-            if m[0] not in valid_starting_directions and \
-                    self._path_exists(new_position, m[0], self.agents_target[iAgent]):
-                valid_starting_directions.append(m[0])
-
-        if len(valid_starting_directions) == 0:
-            return False
-        else:
-            return True
-
-    def pick_agent_direction(self, rcPos, rcTarget):
-        """ Pick and return a valid direction index (0..3) for an agent starting at
-            row,col rcPos with target rcTarget.
-            Return None if no path exists.
-            Picks random direction if more than one exists (uniformly).
+    def reset(self, regen_rail=True, replace_agents=True):
         """
-        valid_movements = []
-        for direction in range(4):
-            moves = self.rail.get_transitions((*rcPos, direction))
-            for move_index in range(4):
-                if moves[move_index]:
-                    valid_movements.append((direction, move_index))
-        # print("pos", rcPos, "targ", rcTarget, "valid movements", valid_movements)
-
-        valid_starting_directions = []
-        for m in valid_movements:
-            new_position = self._new_position(rcPos, m[1])
-            if m[0] not in valid_starting_directions and self._path_exists(new_position, m[0], rcTarget):
-                valid_starting_directions.append(m[0])
-
-        if len(valid_starting_directions) == 0:
-            return None
-        else:
-            return valid_starting_directions[np.random.choice(len(valid_starting_directions), 1)[0]]
-
-    def add_agent(self, rcPos=None, rcTarget=None, iDir=None):
-        """ Add a new agent at position rcPos with target rcTarget and
-            initial direction index iDir.
-            Should also store this initial position etc as environment "meta-data"
-            but this does not yet exist.
+        TODO: replace_agents is ignored at the moment; agents will always be replaced.
         """
-        self.check_agent_lists()
-
-        if rcPos is None:
-            rcPos = np.random.choice(len(self.valid_positions))
-
-        iAgent = self.number_of_agents
-
-        if iDir is None:
-            iDir = self.pick_agent_direction(rcPos, rcTarget)
-        if iDir is None:
-            print("Error picking agent direction at pos:", rcPos)
-            return None
-
-        self.agents_position.append(tuple(rcPos))  # ensure it's a tuple not a list
-        self.agents_handles.append(max(self.agents_handles + [-1]) + 1)  # max(handles) + 1, starting at 0
-        self.agents_direction.append(iDir)
-        self.agents_target.append(rcPos)  # set the target to the origin initially
-        self.number_of_agents += 1
-        self.check_agent_lists()
-        return iAgent
-
-    def reset(self, regen_rail=True, replace_agents=True):
         if regen_rail or self.rail is None:
-            # TODO: Import not only rail information but also start and goal positions
-            self.rail = self.rail_generator(self.width, self.height, self.num_resets)
-            self.fill_valid_positions()
+            self.rail, self.agents_position, self.agents_direction, self.agents_target = self.rail_generator(
+                self.width,
+                self.height,
+                self.agents_handles,
+                self.num_resets)
 
         self.num_resets += 1
 
@@ -238,37 +149,6 @@ class RailEnv(Environment):
         for handle in self.agents_handles:
             self.dones[handle] = False
 
-        # Use a TreeObsForRailEnv to compute distance maps to each agent's target, to sample initial
-        # agent's orientations that allow a valid solution.
-        # TODO: Possibility ot fill valid positions from list of goals and start
-        self.fill_valid_positions()
-
-        if replace_agents:
-            re_generate = True
-            while re_generate:
-
-                # self.agents_position = random.sample(valid_positions,
-                #                                     self.number_of_agents)
-                self.agents_position = [
-                    self.valid_positions[i] for i in
-                    np.random.choice(len(self.valid_positions), self.number_of_agents)]
-                self.agents_target = [
-                    self.valid_positions[i] for i in
-                    np.random.choice(len(self.valid_positions), self.number_of_agents)]
-
-                # agents_direction must be a direction for which a solution is
-                # guaranteed.
-                self.agents_direction = [0] * self.number_of_agents
-                re_generate = False
-
-                for i in range(self.number_of_agents):
-                    direction = self.pick_agent_direction(self.agents_position[i], self.agents_target[i])
-                    if direction is None:
-                        re_generate = True
-                        break
-                    else:
-                        self.agents_direction[i] = direction
-
         # Reset the state of the observation builder with the new environment
         self.obs_builder.reset()
 
@@ -342,7 +222,7 @@ class RailEnv(Environment):
                         movement = np.argmax(possible_transitions)
                         transition_isValid = True
 
-                new_position = self._new_position(pos, movement)
+                new_position = get_new_position(pos, movement)
                 # Is it a legal move?  1) transition allows the movement in the
                 # cell,  2) the new cell is not empty (case 0),  3) the cell is
                 # free, i.e., no agent is currently in that cell
@@ -401,45 +281,6 @@ class RailEnv(Environment):
         self.actions = [0] * self.number_of_agents
         return self._get_observations(), self.rewards_dict, self.dones, {}
 
-    def _new_position(self, position, movement):
-        if movement == 0:  # NORTH
-            return (position[0] - 1, position[1])
-        elif movement == 1:  # EAST
-            return (position[0], position[1] + 1)
-        elif movement == 2:  # SOUTH
-            return (position[0] + 1, position[1])
-        elif movement == 3:  # WEST
-            return (position[0], position[1] - 1)
-
-    def _path_exists(self, start, direction, end):
-        # BFS - Check if a path exists between the 2 nodes
-
-        visited = set()
-        stack = [(start, direction)]
-        while stack:
-            node = stack.pop()
-            if node[0][0] == end[0] and node[0][1] == end[1]:
-                return 1
-            if node not in visited:
-                visited.add(node)
-                moves = self.rail.get_transitions((node[0][0], node[0][1], node[1]))
-                for move_index in range(4):
-                    if moves[move_index]:
-                        stack.append((self._new_position(node[0], move_index),
-                                      move_index))
-
-                # If cell is a dead-end, append previous node with reversed
-                # orientation!
-                nbits = 0
-                tmp = self.rail.get_transitions((node[0][0], node[0][1]))
-                while tmp > 0:
-                    nbits += (tmp & 1)
-                    tmp = tmp >> 1
-                if nbits == 1:
-                    stack.append((node[0], (node[1] + 2) % 4))
-
-        return 0
-
     def _get_observations(self):
         self.obs_dict = {}
         for handle in self.agents_handles: