diff --git a/examples/temporary_example.py b/examples/temporary_example.py
index 03b5ebd001c25d70c4630db4300816f3a8ddd500..52927160f467173be152d874d1e5a5f00d8eb474 100644
--- a/examples/temporary_example.py
+++ b/examples/temporary_example.py
@@ -10,13 +10,34 @@ random.seed(1)
 np.random.seed(1)
 
 """
+transition_probability = [1.0,  # empty cell - Case 0
+                          3.0,  # Case 1 - straight
+                          1.0,  # Case 2 - simple switch
+                          3.0,  # Case 3 - diamond drossing
+                          2.0,  # Case 4 - single slip
+                          1.0,  # Case 5 - double slip
+                          1.0,  # Case 6 - symmetrical
+                          1.0]  # Case 7 - dead end
+"""
+transition_probability = [1.0,  # empty cell - Case 0
+                          1.0,  # Case 1 - straight
+                          1.0,  # Case 2 - simple switch
+                          1.0,  # Case 3 - diamond drossing
+                          1.0,  # Case 4 - single slip
+                          1.0,  # Case 5 - double slip
+                          1.0,  # Case 6 - symmetrical
+                          1.0]  # Case 7 - dead end
+
 # Example generate a random rail
-env = RailEnv(width=20, height=20, rail_generator=random_rail_generator, number_of_agents=10)
+env = RailEnv(width=20,
+              height=20,
+              rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
+              number_of_agents=10)
 env.reset()
 
 env_renderer = RenderTool(env)
 env_renderer.renderEnv(show=True)
-"""
+
 
 # Example generate a rail given a manual specification,
 # a map of tuples (cell_type, rotation)
@@ -26,7 +47,7 @@ specs = [[(0, 0), (0, 0), (0, 0), (0, 0), (7, 0), (0, 0)],
 env = RailEnv(width=6,
               height=2,
               rail_generator=rail_from_manual_specifications_generator(specs),
-              number_of_agents=2,
+              number_of_agents=1,
               obs_builder_object=TreeObsForRailEnv(max_depth=2))
 
 handle = env.get_agent_handles()
diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py
index 24f213abd3daa1c45a4363058dcbf3baf7794f24..cd8a53094471a8e48e158c301f45d148eebdecce 100644
--- a/flatland/core/env_observation_builder.py
+++ b/flatland/core/env_observation_builder.py
@@ -72,7 +72,7 @@ class TreeObsForRailEnv(ObservationBuilder):
         # Update local lookup table for all agents' target locations
         self.location_has_target = {}
         for loc in self.env.agents_target:
-            self.location_has_target[(loc[0],loc[1])] = 1
+            self.location_has_target[(loc[0], loc[1])] = 1
 
     def _distance_map_walker(self, position, target_nr):
         """
@@ -292,8 +292,6 @@ class TreeObsForRailEnv(ObservationBuilder):
             if position in self.location_has_target:
                 other_target_encountered = True
 
-
-
             # #############################
             # #############################
 
@@ -354,10 +352,8 @@ class TreeObsForRailEnv(ObservationBuilder):
                            0,
                            self.distance_map[handle, position[0], position[1], direction]]
 
-
         # TODO:
 
-
         # #############################
         # #############################
 
@@ -368,9 +364,9 @@ class TreeObsForRailEnv(ObservationBuilder):
                                                                (branch_direction+2) % 4):
                 # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes
                 # it back
-                new_cell = self._new_position(position, (branch_direction+2)%4)
+                new_cell = self._new_position(position, (branch_direction+2) % 4)
 
-                branch_observation = self._explore_branch(handle, new_cell, (branch_direction+2)%4, depth+1)
+                branch_observation = self._explore_branch(handle, new_cell, (branch_direction+2) % 4, depth+1)
                 observation = observation + branch_observation
 
             elif last_isSwitch and self.env.rail.get_transition((position[0], position[1], direction),
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 8b34bf3d4ea51dfd1be57c70297b45bedc7906c3..1664c3155c61ad2ddde8244440edf30d48f3e410 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -84,7 +84,7 @@ def generate_rail_from_list_of_manual_specifications(list_of_specifications)
 """
 
 
-def random_rail_generator(width, height, num_resets=0):
+def random_rail_generator(cell_type_relative_proportion=[1.0]*8):
     """
     Dummy random level generator:
     - fill in cells at random in [width-2, height-2]
@@ -116,234 +116,248 @@ def random_rail_generator(width, height, num_resets=0):
         The matrix with the correct 16-bit bitmaps for each cell.
     """
 
-    t_utils = RailEnvTransitions()
-
-    transitions_templates_ = []
-    for i in range(len(t_utils.transitions)-1):  # don't include dead-ends
-        all_transitions = 0
-        for dir_ in range(4):
-            trans = t_utils.get_transitions(t_utils.transitions[i], dir_)
-            all_transitions |= (trans[0] << 3) | \
-                               (trans[1] << 2) | \
-                               (trans[2] << 1) | \
-                               (trans[3])
-
-        template = [int(x) for x in bin(all_transitions)[2:]]
-        template = [0]*(4-len(template)) + template
-
-        # add all rotations
-        for rot in [0, 90, 180, 270]:
-            transitions_templates_.append((template,
-                                          t_utils.rotate_transition(
-                                           t_utils.transitions[i],
-                                           rot)))
-            template = [template[-1]]+template[:-1]
-
-    def get_matching_templates(template):
-        ret = []
-        for i in range(len(transitions_templates_)):
-            is_match = True
-            for j in range(4):
-                if template[j] >= 0 and \
-                   template[j] != transitions_templates_[i][0][j]:
-                    is_match = False
-                    break
-            if is_match:
-                ret.append(transitions_templates_[i][1])
-        return ret
-
-    MAX_INSERTIONS = (width-2) * (height-2) * 10
-    MAX_ATTEMPTS_FROM_SCRATCH = 10
-
-    attempt_number = 0
-    while attempt_number < MAX_ATTEMPTS_FROM_SCRATCH:
-        cells_to_fill = []
-        rail = []
-        for r in range(height):
-            rail.append([None]*width)
-            if r > 0 and r < height-1:
-                cells_to_fill = cells_to_fill \
-                                + [(r, c) for c in range(1, width-1)]
-
-        num_insertions = 0
-        while num_insertions < MAX_INSERTIONS and len(cells_to_fill) > 0:
-            cell = random.sample(cells_to_fill, 1)[0]
-            cells_to_fill.remove(cell)
-            row = cell[0]
-            col = cell[1]
-
-            # look at its neighbors and see what are the possible transitions
-            # that can be chosen from, if any.
-            valid_template = [-1, -1, -1, -1]
-
-            for el in [(0, 2, (-1, 0)),
-                       (1, 3, (0, 1)),
-                       (2, 0, (1, 0)),
-                       (3, 1, (0, -1))]:  # N, E, S, W
-                neigh_trans = rail[row+el[2][0]][col+el[2][1]]
-                if neigh_trans is not None:
-                    # select transition coming from facing direction el[1] and
-                    # moving to direction el[1]
-                    max_bit = 0
-                    for k in range(4):
-                        max_bit |= \
-                         t_utils.get_transition(neigh_trans, k, el[1])
-
-                    if max_bit:
-                        valid_template[el[0]] = 1
-                    else:
-                        valid_template[el[0]] = 0
-
-            possible_cell_transitions = get_matching_templates(valid_template)
-
-            if len(possible_cell_transitions) == 0:  # NO VALID TRANSITIONS
-                # no cell can be filled in without violating some transitions
-                # can a dead-end solve the problem?
-                if valid_template.count(1) == 1:
-                    for k in range(4):
-                        if valid_template[k] == 1:
-                            rot = 0
-                            if k == 0:
-                                rot = 180
-                            elif k == 1:
-                                rot = 270
-                            elif k == 2:
+    def generator(width, height, num_resets=0):
+        t_utils = RailEnvTransitions()
+
+        transition_probability = cell_type_relative_proportion
+
+        transitions_templates_ = []
+        transition_probabilities = []
+        for i in range(len(t_utils.transitions)-1):  # don't include dead-ends
+            all_transitions = 0
+            for dir_ in range(4):
+                trans = t_utils.get_transitions(t_utils.transitions[i], dir_)
+                all_transitions |= (trans[0] << 3) | \
+                                   (trans[1] << 2) | \
+                                   (trans[2] << 1) | \
+                                   (trans[3])
+
+            template = [int(x) for x in bin(all_transitions)[2:]]
+            template = [0]*(4-len(template)) + template
+
+            # add all rotations
+            for rot in [0, 90, 180, 270]:
+                transitions_templates_.append((template,
+                                              t_utils.rotate_transition(
+                                               t_utils.transitions[i],
+                                               rot)))
+                transition_probabilities.append(transition_probability[i])
+                template = [template[-1]]+template[:-1]
+
+        def get_matching_templates(template):
+            ret = []
+            for i in range(len(transitions_templates_)):
+                is_match = True
+                for j in range(4):
+                    if template[j] >= 0 and \
+                       template[j] != transitions_templates_[i][0][j]:
+                        is_match = False
+                        break
+                if is_match:
+                    ret.append((transitions_templates_[i][1], transition_probabilities[i]))
+            return ret
+
+        MAX_INSERTIONS = (width-2) * (height-2) * 10
+        MAX_ATTEMPTS_FROM_SCRATCH = 10
+
+        attempt_number = 0
+        while attempt_number < MAX_ATTEMPTS_FROM_SCRATCH:
+            cells_to_fill = []
+            rail = []
+            for r in range(height):
+                rail.append([None]*width)
+                if r > 0 and r < height-1:
+                    cells_to_fill = cells_to_fill \
+                                    + [(r, c) for c in range(1, width-1)]
+
+            num_insertions = 0
+            while num_insertions < MAX_INSERTIONS and len(cells_to_fill) > 0:
+                cell = random.sample(cells_to_fill, 1)[0]
+                cells_to_fill.remove(cell)
+                row = cell[0]
+                col = cell[1]
+
+                # look at its neighbors and see what are the possible transitions
+                # that can be chosen from, if any.
+                valid_template = [-1, -1, -1, -1]
+
+                for el in [(0, 2, (-1, 0)),
+                           (1, 3, (0, 1)),
+                           (2, 0, (1, 0)),
+                           (3, 1, (0, -1))]:  # N, E, S, W
+                    neigh_trans = rail[row+el[2][0]][col+el[2][1]]
+                    if neigh_trans is not None:
+                        # select transition coming from facing direction el[1] and
+                        # moving to direction el[1]
+                        max_bit = 0
+                        for k in range(4):
+                            max_bit |= \
+                             t_utils.get_transition(neigh_trans, k, el[1])
+
+                        if max_bit:
+                            valid_template[el[0]] = 1
+                        else:
+                            valid_template[el[0]] = 0
+
+                possible_cell_transitions = get_matching_templates(valid_template)
+
+                if len(possible_cell_transitions) == 0:  # NO VALID TRANSITIONS
+                    # no cell can be filled in without violating some transitions
+                    # can a dead-end solve the problem?
+                    if valid_template.count(1) == 1:
+                        for k in range(4):
+                            if valid_template[k] == 1:
                                 rot = 0
-                            elif k == 3:
-                                rot = 90
+                                if k == 0:
+                                    rot = 180
+                                elif k == 1:
+                                    rot = 270
+                                elif k == 2:
+                                    rot = 0
+                                elif k == 3:
+                                    rot = 90
 
-                            rail[row][col] = t_utils.rotate_transition(
-                                              int('0000000000100000', 2), rot)
+                                rail[row][col] = t_utils.rotate_transition(
+                                                  int('0000000000100000', 2), rot)
+                                num_insertions += 1
+
+                                break
+
+                    else:
+                        # can I get valid transitions by removing a single
+                        # neighboring cell?
+                        bestk = -1
+                        besttrans = []
+                        for k in range(4):
+                            tmp_template = valid_template[:]
+                            tmp_template[k] = -1
+                            possible_cell_transitions = get_matching_templates(
+                                                         tmp_template)
+                            if len(possible_cell_transitions) > len(besttrans):
+                                besttrans = possible_cell_transitions
+                                bestk = k
+
+                        if bestk >= 0:
+                            # Replace the corresponding cell with None, append it
+                            # to cells to fill, fill in a transition in the current
+                            # cell.
+                            replace_row = row - 1
+                            replace_col = col
+                            if bestk == 1:
+                                replace_row = row
+                                replace_col = col + 1
+                            elif bestk == 2:
+                                replace_row = row + 1
+                                replace_col = col
+                            elif bestk == 3:
+                                replace_row = row
+                                replace_col = col - 1
+
+                            cells_to_fill.append((replace_row, replace_col))
+                            rail[replace_row][replace_col] = None
+
+                            possible_transitions, possible_probabilities = zip(*besttrans)
+                            possible_probabilities = \
+                                np.exp(possible_probabilities) / sum(np.exp(possible_probabilities))
+
+                            rail[row][col] = np.random.choice(possible_transitions,
+                                                              p=possible_probabilities)
                             num_insertions += 1
 
-                            break
+                        else:
+                            print('WARNING: still nothing!')
+                            rail[row][col] = int('0000000000000000', 2)
+                            num_insertions += 1
+                            pass
 
                 else:
-                    # can I get valid transitions by removing a single
-                    # neighboring cell?
-                    bestk = -1
-                    besttrans = []
-                    for k in range(4):
-                        tmp_template = valid_template[:]
-                        tmp_template[k] = -1
-                        possible_cell_transitions = get_matching_templates(
-                                                     tmp_template)
-                        if len(possible_cell_transitions) > len(besttrans):
-                            besttrans = possible_cell_transitions
-                            bestk = k
-
-                    if bestk >= 0:
-                        # Replace the corresponding cell with None, append it
-                        # to cells to fill, fill in a transition in the current
-                        # cell.
-                        replace_row = row - 1
-                        replace_col = col
-                        if bestk == 1:
-                            replace_row = row
-                            replace_col = col + 1
-                        elif bestk == 2:
-                            replace_row = row + 1
-                            replace_col = col
-                        elif bestk == 3:
-                            replace_row = row
-                            replace_col = col - 1
+                    possible_transitions, possible_probabilities = zip(*possible_cell_transitions)
+                    possible_probabilities = np.exp(possible_probabilities) / sum(np.exp(possible_probabilities))
 
-                        cells_to_fill.append((replace_row, replace_col))
-                        rail[replace_row][replace_col] = None
+                    rail[row][col] = np.random.choice(possible_transitions,
+                                                      p=possible_probabilities)
+                    num_insertions += 1
 
-                        rail[row][col] = random.sample(
-                                                     besttrans, 1)[0]
-                        num_insertions += 1
+            if num_insertions == MAX_INSERTIONS:
+                # Failed to generate a valid level; try again for a number of times
+                attempt_number += 1
+            else:
+                break
 
-                    else:
-                        print('WARNING: still nothing!')
-                        rail[row][col] = int('0000000000000000', 2)
-                        num_insertions += 1
-                        pass
+        if attempt_number == MAX_ATTEMPTS_FROM_SCRATCH:
+            print('ERROR: failed to generate level')
 
+        # Finally pad the border of the map with dead-ends to avoid border issues;
+        # at most 1 transition in the neigh cell
+        for r in range(height):
+            # Check for transitions coming from [r][1] to WEST
+            max_bit = 0
+            neigh_trans = rail[r][1]
+            if neigh_trans is not None:
+                for k in range(4):
+                    neigh_trans_from_direction = (neigh_trans >> ((3-k) * 4)) \
+                                                 & (2**4-1)
+                    max_bit = max_bit | (neigh_trans_from_direction & 1)
+            if max_bit:
+                rail[r][0] = t_utils.rotate_transition(
+                               int('0000000000100000', 2), 270)
             else:
-                rail[row][col] = random.sample(
-                                             possible_cell_transitions, 1)[0]
-                num_insertions += 1
-
-        if num_insertions == MAX_INSERTIONS:
-            # Failed to generate a valid level; try again for a number of times
-            attempt_number += 1
-        else:
-            break
-
-    if attempt_number == MAX_ATTEMPTS_FROM_SCRATCH:
-        print('ERROR: failed to generate level')
-
-    # Finally pad the border of the map with dead-ends to avoid border issues;
-    # at most 1 transition in the neigh cell
-    for r in range(height):
-        # Check for transitions coming from [r][1] to WEST
-        max_bit = 0
-        neigh_trans = rail[r][1]
-        if neigh_trans is not None:
-            for k in range(4):
-                neigh_trans_from_direction = (neigh_trans >> ((3-k) * 4)) \
-                                             & (2**4-1)
-                max_bit = max_bit | (neigh_trans_from_direction & 1)
-        if max_bit:
-            rail[r][0] = t_utils.rotate_transition(
-                           int('0000000000100000', 2), 270)
-        else:
-            rail[r][0] = int('0000000000000000', 2)
-
-        # Check for transitions coming from [r][-2] to EAST
-        max_bit = 0
-        neigh_trans = rail[r][-2]
-        if neigh_trans is not None:
-            for k in range(4):
-                neigh_trans_from_direction = (neigh_trans >> ((3-k) * 4)) \
-                                             & (2**4-1)
-                max_bit = max_bit | (neigh_trans_from_direction & (1 << 2))
-        if max_bit:
-            rail[r][-1] = t_utils.rotate_transition(int('0000000000100000', 2),
-                                                    90)
-        else:
-            rail[r][-1] = int('0000000000000000', 2)
-
-    for c in range(width):
-        # Check for transitions coming from [1][c] to NORTH
-        max_bit = 0
-        neigh_trans = rail[1][c]
-        if neigh_trans is not None:
-            for k in range(4):
-                neigh_trans_from_direction = (neigh_trans >> ((3-k) * 4)) \
-                                              & (2**4-1)
-                max_bit = max_bit | (neigh_trans_from_direction & (1 << 3))
-        if max_bit:
-            rail[0][c] = int('0000000000100000', 2)
-        else:
-            rail[0][c] = int('0000000000000000', 2)
-
-        # Check for transitions coming from [-2][c] to SOUTH
-        max_bit = 0
-        neigh_trans = rail[-2][c]
-        if neigh_trans is not None:
-            for k in range(4):
-                neigh_trans_from_direction = (neigh_trans >> ((3-k) * 4)) \
-                                             & (2**4-1)
-                max_bit = max_bit | (neigh_trans_from_direction & (1 << 1))
-        if max_bit:
-            rail[-1][c] = t_utils.rotate_transition(
-                            int('0000000000100000', 2), 180)
-        else:
-            rail[-1][c] = int('0000000000000000', 2)
-
-    # For display only, wrong levels
-    for r in range(height):
+                rail[r][0] = int('0000000000000000', 2)
+
+            # Check for transitions coming from [r][-2] to EAST
+            max_bit = 0
+            neigh_trans = rail[r][-2]
+            if neigh_trans is not None:
+                for k in range(4):
+                    neigh_trans_from_direction = (neigh_trans >> ((3-k) * 4)) \
+                                                 & (2**4-1)
+                    max_bit = max_bit | (neigh_trans_from_direction & (1 << 2))
+            if max_bit:
+                rail[r][-1] = t_utils.rotate_transition(int('0000000000100000', 2),
+                                                        90)
+            else:
+                rail[r][-1] = int('0000000000000000', 2)
+
         for c in range(width):
-            if rail[r][c] is None:
-                rail[r][c] = int('0000000000000000', 2)
+            # Check for transitions coming from [1][c] to NORTH
+            max_bit = 0
+            neigh_trans = rail[1][c]
+            if neigh_trans is not None:
+                for k in range(4):
+                    neigh_trans_from_direction = (neigh_trans >> ((3-k) * 4)) \
+                                                  & (2**4-1)
+                    max_bit = max_bit | (neigh_trans_from_direction & (1 << 3))
+            if max_bit:
+                rail[0][c] = int('0000000000100000', 2)
+            else:
+                rail[0][c] = int('0000000000000000', 2)
+
+            # Check for transitions coming from [-2][c] to SOUTH
+            max_bit = 0
+            neigh_trans = rail[-2][c]
+            if neigh_trans is not None:
+                for k in range(4):
+                    neigh_trans_from_direction = (neigh_trans >> ((3-k) * 4)) \
+                                                 & (2**4-1)
+                    max_bit = max_bit | (neigh_trans_from_direction & (1 << 1))
+            if max_bit:
+                rail[-1][c] = t_utils.rotate_transition(
+                                int('0000000000100000', 2), 180)
+            else:
+                rail[-1][c] = int('0000000000000000', 2)
+
+        # For display only, wrong levels
+        for r in range(height):
+            for c in range(width):
+                if rail[r][c] is None:
+                    rail[r][c] = int('0000000000000000', 2)
+
+        tmp_rail = np.asarray(rail, dtype=np.uint16)
+        return_rail = GridTransitionMap(width=width, height=height, transitions=t_utils)
+        return_rail.grid = tmp_rail
+        return return_rail
 
-    tmp_rail = np.asarray(rail, dtype=np.uint16)
-    return_rail = GridTransitionMap(width=width, height=height, transitions=t_utils)
-    return_rail.grid = tmp_rail
-    return return_rail
+    return generator
 
 
 class RailEnv(Environment):