diff --git a/examples/temporary_example.py b/examples/temporary_example.py
index 2ea68cfde0d555b1bfc4b021cec7c325f122bc0b..8b53b1ddb7e3558d8a712e14d768245c56032d9a 100644
--- a/examples/temporary_example.py
+++ b/examples/temporary_example.py
@@ -9,32 +9,36 @@ from flatland.utils.rendertools import *
 random.seed(1)
 np.random.seed(1)
 
-
 # Example generate a random rail
-rail = generate_random_rail(20, 20)
-
-env = RailEnv(rail, number_of_agents=10)
+env = RailEnv(width=20, height=20, rail_generator=generate_random_rail, number_of_agents=10)
 env.reset()
 
 env_renderer = RenderTool(env)
 env_renderer.renderEnv(show=True)
 
-
 # Example generate a rail given a manual specification,
 # a map of tuples (cell_type, rotation)
 specs = [[(0, 0), (0, 0), (0, 0), (0, 0), (7, 0), (0, 0)],
          [(7, 270), (1, 90), (1, 90), (1, 90), (2, 90), (7, 90)]]
 
-rail = generate_rail_from_manual_specifications(specs)
-env = RailEnv(rail, number_of_agents=1)
+env = RailEnv(width=6,
+              height=2,
+              rail_generator=generate_rail_from_manual_specifications(specs),
+              number_of_agents=1)
 
 handle = env.get_agent_handles()
 
-env.reset()
+obs = env.reset()
 
 env.agents_position = [[1, 4]]
 env.agents_target = [[1, 1]]
 env.agents_direction = [1]
+# TODO: watch out: if these variables are overridden, the obs_builder object has to be reset, too!
+env.obs_builder.reset()
+
+# TODO: delete next line
+#print(env.obs_builder.distance_map[0,:,:])
+#print(env.obs_builder.max_dist)
 
 env_renderer = RenderTool(env)
 env_renderer.renderEnv(show=True)
diff --git a/flatland/core/env.py b/flatland/core/env.py
index a7e63fd4e2697996a4dbe0a684735fd46153fb1b..02d912a3aba8c6ee84a2159b248e0631e194d2de 100644
--- a/flatland/core/env.py
+++ b/flatland/core/env.py
@@ -6,6 +6,7 @@ The base Environment class is adapted from rllib.env.MultiAgentEnv
 import random
 
 from .env_observation_builder import TreeObsForRailEnv
+from flatland.utils.rail_env_generator import generate_random_rail
 
 
 class Environment:
@@ -121,35 +122,56 @@ class RailEnv:
     """
 
     def __init__(self,
-                 rail,
+                 width,
+                 height,
+                 rail_generator=generate_random_rail,
                  number_of_agents=1,
-                 custom_observation_builder=TreeObsForRailEnv):
+                 obs_builder_object=TreeObsForRailEnv(max_depth=2)):
         """
         Environment init.
 
         Parameters
         -------
-        rail : numpy.ndarray of type numpy.uint16
-            The transition matrix that defines the environment.
+        rail_generator : function
+            The rail_generator function is a function that takes the width and
+            height of a  rail map along with the number of times the env has
+            been reset, and returns a GridTransitionMap object.
+            Implemented functions are:
+                generate_random_rail : generate a random rail of given size
+                TODO: generate_rail_from_saved_list ---
+        width : int
+            The width of the rail map. Potentially in the future,
+            a range of widths to sample from.
+        height : int
+            The height of the rail map. Potentially in the future,
+            a range of heights to sample from.
         number_of_agents : int
-            Number of agents to spawn on the map.
-        custom_observation_builder: ObservationBuilder object
-            ObservationBuilder-derived object that takes this env object
-            as input as provides observation vectors for each agent.
+            Number of agents to spawn on the map. Potentially in the future,
+            a range of number of agents to sample from.
+        obs_builder_object: ObservationBuilder object
+            ObservationBuilder-derived object that takes builds observation
+            vectors for each agent.
         """
 
-        self.rail = rail
-        self.width = rail.width
-        self.height = rail.height
+        self.rail_generator = rail_generator
+        self.num_resets = 0
+        self.rail = None
+        self.width = width
+        self.height = height
 
         self.number_of_agents = number_of_agents
 
-        self.obs_builder = custom_observation_builder(env=self)
+        self.obs_builder = obs_builder_object
+        self.obs_builder.set_env(self)
 
         self.actions = [0]*self.number_of_agents
         self.rewards = [0]*self.number_of_agents
         self.done = False
 
+        self.agents_position = []
+        self.agents_target = []
+        self.agents_direction = []
+
         self.dones = {"__all__": False}
         self.obs_dict = {}
         self.rewards_dict = {}
@@ -160,6 +182,9 @@ class RailEnv:
         return self.agents_handles
 
     def reset(self):
+        self.rail = self.rail_generator(self.width, self.height, self.num_resets)
+        self.num_resets += 1
+
         self.dones = {"__all__": False}
         for handle in self.agents_handles:
             self.dones[handle] = False
diff --git a/flatland/utils/rail_env_generator.py b/flatland/utils/rail_env_generator.py
index 69e5b831be67c6a808e22b5413255789acf90f27..ab11ac5e7df8281294c02fd4ef8192519b061e4d 100644
--- a/flatland/utils/rail_env_generator.py
+++ b/flatland/utils/rail_env_generator.py
@@ -24,28 +24,62 @@ def generate_rail_from_manual_specifications(rail_spec):
 
     Returns
     -------
-    numpy.ndarray of type numpy.uint16
-        The matrix with the correct 16-bit bitmaps for each cell.
+    function
+        Generator function that always returns a GridTransitionMap object with
+        the matrix of correct 16-bit bitmaps for each cell.
     """
-    t_utils = RailEnvTransitions()
+    def generator(width, height, num_resets=0):
+        t_utils = RailEnvTransitions()
 
-    height = len(rail_spec)
-    width = len(rail_spec[0])
-    rail = GridTransitionMap(width=width, height=height, transitions=t_utils)
+        height = len(rail_spec)
+        width = len(rail_spec[0])
+        rail = GridTransitionMap(width=width, height=height, transitions=t_utils)
 
-    for r in range(height):
-        for c in range(width):
-            cell = rail_spec[r][c]
-            if cell[0] < 0 or cell[0] >= len(t_utils.transitions):
-                print("ERROR - invalid cell type=", cell[0])
-                return []
-            rail.set_transitions((r, c), t_utils.rotate_transition(
-                          t_utils.transitions[cell[0]], cell[1]))
+        for r in range(height):
+            for c in range(width):
+                cell = rail_spec[r][c]
+                if cell[0] < 0 or cell[0] >= len(t_utils.transitions):
+                    print("ERROR - invalid cell type=", cell[0])
+                    return []
+                rail.set_transitions((r, c), t_utils.rotate_transition(
+                              t_utils.transitions[cell[0]], cell[1]))
 
-    return rail
+        return rail
+
+    return generator
+
+
+def generate_rail_from_GridTransitionMap(rail_map):
+    """
+    Utility to convert a rail given by a GridTransitionMap map with the correct
+    16-bit transitions specifications.
+
+    Parameters
+    -------
+    rail_map : GridTransitionMap object
+        GridTransitionMap object to return when the generator is called.
+
+    Returns
+    -------
+    function
+        Generator function that always returns the given `rail_map' object.
+    """
+    def generator(width, height, num_resets=0):
+        return rail_map
+
+    return generator
+
+
+"""
+def generate_rail_from_list_of_manual_specifications(list_of_specifications)
+    def generator(width, height, num_resets=0):
+        return generate_rail_from_manual_specifications(list_of_specifications)
+
+    return generator
+"""
 
 
-def generate_random_rail(width, height):
+def generate_random_rail(width, height, num_resets=0):
     """
     Dummy random level generator:
     - fill in cells at random in [width-2, height-2]
diff --git a/tests/test_environments.py b/tests/test_environments.py
index 03544b08196a609bc6d4ed92c393f0a72bfbab8c..ce9fbd4f010413140ef09a843db0c7657005524b 100644
--- a/tests/test_environments.py
+++ b/tests/test_environments.py
@@ -4,6 +4,7 @@
 from flatland.core.env import RailEnv
 from flatland.core.transitions import Grid4Transitions
 from flatland.core.transition_map import GridTransitionMap
+from flatland.utils.rail_env_generator import generate_rail_from_GridTransitionMap
 import numpy as np
 
 """Tests for `flatland` package."""
@@ -46,7 +47,7 @@ def test_rail_environment_single_agent():
 
     rail = GridTransitionMap(width=3, height=3, transitions=transitions)
     rail.grid = rail_map
-    rail_env = RailEnv(rail, number_of_agents=1)
+    rail_env = RailEnv(width=3, height=3, rail_generator=generate_rail_from_GridTransitionMap(rail), number_of_agents=1)
     for _ in range(200):
         _ = rail_env.reset()
 
@@ -118,7 +119,10 @@ def test_dead_end():
                              transitions=transitions)
 
     rail.grid = rail_map
-    rail_env = RailEnv(rail, number_of_agents=1)
+    rail_env = RailEnv(width=rail_map.shape[1],
+                       height=rail_map.shape[0],
+                       rail_generator=generate_rail_from_GridTransitionMap(rail),
+                       number_of_agents=1)
 
     def check_consistency(rail_env):
         # We run step to check that trains do not move anymore
@@ -164,7 +168,10 @@ def test_dead_end():
                              transitions=transitions)
 
     rail.grid = rail_map
-    rail_env = RailEnv(rail, number_of_agents=1)
+    rail_env = RailEnv(width=rail_map.shape[1],
+                       height=rail_map.shape[0],
+                       rail_generator=generate_rail_from_GridTransitionMap(rail),
+                       number_of_agents=1)
 
     rail_env.reset()
     rail_env.agents_target[0] = [0, 0]
@@ -177,11 +184,3 @@ def test_dead_end():
     rail_env.agents_position[0] = [2, 0]
     rail_env.agents_direction[0] = 0
     check_consistency(rail_env)
-
-
-
-
-
-
-test_dead_end()
-
diff --git a/tests/test_rendertools.py b/tests/test_rendertools.py
index ae9a9e1867a9e9d65877e81e73370550a233d206..5fecd085c646e8a19e7be553eea258105865a8f4 100644
--- a/tests/test_rendertools.py
+++ b/tests/test_rendertools.py
@@ -37,8 +37,7 @@ def checkFrozenImage(sFileImage):
 
 def test_render_env():
     random.seed(100)
-    oRail = rail_env_generator.generate_random_rail(10, 10)
-    oEnv = RailEnv(oRail, number_of_agents=2)
+    oEnv = RailEnv(width=10, height=10, rail_generator=rail_env_generator.generate_random_rail, number_of_agents=2)
     oEnv.reset()
     oRT = rt.RenderTool(oEnv)
     plt.figure(figsize=(10, 10))