diff --git a/examples/play_model.py b/examples/play_model.py
index 1f654c15c590801fcfcaa518deb8289af9a95d99..e6e81c972c756ae66e4cea176681611e1a798b06 100644
--- a/examples/play_model.py
+++ b/examples/play_model.py
@@ -1,4 +1,5 @@
-from flatland.envs.rail_env import RailEnv, complex_rail_generator
+from flatland.envs.rail_env import RailEnv
+from flatland.envs.generators import complex_rail_generator
 # from flatland.core.env_observation_builder import TreeObsForRailEnv
 from flatland.utils.rendertools import RenderTool
 from flatland.baselines.dueling_double_dqn import Agent
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 8969af3ef989b652fce96045a8fc407780e384c3..eb8786c2333308f73eaadd0c8817e23b248080a1 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -8,737 +8,10 @@ import numpy as np
 
 from flatland.core.env import Environment
 from flatland.core.env_observation_builder import TreeObsForRailEnv
+from flatland.envs.generators import random_rail_generator
 
-from flatland.core.transitions import Grid8Transitions, RailEnvTransitions
-from flatland.core.transition_map import GridTransitionMap
-
-
-class AStarNode():
-    """A node class for A* Pathfinding"""
-
-    def __init__(self, parent=None, pos=None):
-        self.parent = parent
-        self.pos = pos
-        self.g = 0
-        self.h = 0
-        self.f = 0
-
-    def __eq__(self, other):
-        return self.pos == other.pos
-
-    def update_if_better(self, other):
-        if other.g < self.g:
-            self.parent = other.parent
-            self.g = other.g
-            self.h = other.h
-            self.f = other.f
-
-
-def get_direction(pos1, pos2):
-    """
-    Assumes pos1 and pos2 are adjacent location on grid.
-    Returns direction (int) that can be used with transitions.
-    """
-    diff_0 = pos2[0] - pos1[0]
-    diff_1 = pos2[1] - pos1[1]
-    if diff_0 < 0:
-        return 0
-    if diff_0 > 0:
-        return 2
-    if diff_1 > 0:
-        return 1
-    if diff_1 < 0:
-        return 3
-    return 0
-
-
-def mirror(dir):
-    return (dir + 2) % 4
-
-
-def validate_new_transition(rail_trans, rail_array, prev_pos, current_pos, new_pos, end_pos):
-    # start by getting direction used to get to current node
-    # and direction from current node to possible child node
-    new_dir = get_direction(current_pos, new_pos)
-    if prev_pos is not None:
-        current_dir = get_direction(prev_pos, current_pos)
-    else:
-        current_dir = new_dir
-    # create new transition that would go to child
-    new_trans = rail_array[current_pos]
-    if prev_pos is None:
-        if new_trans == 0:
-            # need to flip direction because of how end points are defined
-            new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1)
-        else:
-            # check if matches existing layout
-            new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
-            # new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1)
-            # rail_trans.print(new_trans)
-    else:
-        # set the forward path
-        new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
-        # set the backwards path
-        new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1)
-    if new_pos == end_pos:
-        # need to validate end pos setup as well
-        new_trans_e = rail_array[end_pos]
-        if new_trans_e == 0:
-            # need to flip direction because of how end points are defined
-            new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, mirror(new_dir), 1)
-        else:
-            # check if matches existing layout
-            new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1)
-            # new_trans_e = rail_trans.set_transition(new_trans_e, mirror(new_dir), mirror(new_dir), 1)
-            # print("end:", end_pos, current_pos)
-            # rail_trans.print(new_trans_e)
-
-        # print("========> end trans")
-        # rail_trans.print(new_trans_e)
-        if not rail_trans.is_valid(new_trans_e):
-            # print("end failed", end_pos, current_pos)
-            return False
-        # else:
-        #    print("end ok!", end_pos, current_pos)
-
-    # is transition is valid?
-    # print("=======> trans")
-    # rail_trans.print(new_trans)
-    return rail_trans.is_valid(new_trans)
-
-
-def a_star(rail_trans, rail_array, start, end):
-    """
-    Returns a list of tuples as a path from the given start to end.
-    If no path is found, returns path to closest point to end.
-    """
-    rail_shape = rail_array.shape
-    start_node = AStarNode(None, start)
-    end_node = AStarNode(None, end)
-    open_list = []
-    closed_list = []
-
-    open_list.append(start_node)
-
-    # this could be optimized
-    def is_node_in_list(node, the_list):
-        for o_node in the_list:
-            if node == o_node:
-                return o_node
-        return None
-
-    while len(open_list) > 0:
-        # get node with current shortest est. path (lowest f)
-        current_node = open_list[0]
-        current_index = 0
-        for index, item in enumerate(open_list):
-            if item.f < current_node.f:
-                current_node = item
-                current_index = index
-
-        # pop current off open list, add to closed list
-        open_list.pop(current_index)
-        closed_list.append(current_node)
-
-        # print("a*:", current_node.pos)
-        # for cn in closed_list:
-        #    print("closed:", cn.pos)
-
-        # found the goal
-        if current_node == end_node:
-            path = []
-            current = current_node
-            while current is not None:
-                path.append(current.pos)
-                current = current.parent
-            # return reversed path
-            return path[::-1]
-
-        # generate children
-        children = []
-        if current_node.parent is not None:
-            prev_pos = current_node.parent.pos
-        else:
-            prev_pos = None
-        for new_pos in [(0, -1), (0, 1), (-1, 0), (1, 0)]:
-            node_pos = (current_node.pos[0] + new_pos[0], current_node.pos[1] + new_pos[1])
-            if node_pos[0] >= rail_shape[0] or \
-                    node_pos[0] < 0 or \
-                    node_pos[1] >= rail_shape[1] or \
-                    node_pos[1] < 0:
-                continue
-
-            # validate positions
-            # debug: avoid all current rails
-            # if rail_array.item(node_pos) != 0:
-            #    continue
-
-            # validate positions
-            if not validate_new_transition(rail_trans, rail_array, prev_pos, current_node.pos, node_pos, end_node.pos):
-                # print("A*: transition invalid")
-                continue
-
-            # create new node
-            new_node = AStarNode(current_node, node_pos)
-            children.append(new_node)
-
-        # loop through children
-        for child in children:
-            # already in closed list?
-            closed_node = is_node_in_list(child, closed_list)
-            if closed_node is not None:
-                continue
-
-            # create the f, g, and h values
-            child.g = current_node.g + 1
-            # this heuristic favors diagonal paths
-            # child.h = ((child.pos[0] - end_node.pos[0]) ** 2) + \
-            #           ((child.pos[1] - end_node.pos[1]) ** 2)
-            # this heuristic avoids diagonal paths
-            child.h = abs(child.pos[0] - end_node.pos[0]) + abs(child.pos[1] - end_node.pos[1])
-            child.f = child.g + child.h
-
-            # already in the open list?
-            open_node = is_node_in_list(child, open_list)
-            if open_node is not None:
-                open_node.update_if_better(child)
-                continue
-
-            # add the child to the open list
-            open_list.append(child)
-
-        # no full path found, return partial path
-        if len(open_list) == 0:
-            path = []
-            current = current_node
-            while current is not None:
-                path.append(current.pos)
-                current = current.parent
-            # return reversed path
-            print("partial:", start, end, path[::-1])
-            return path[::-1]
-
-
-def connect_rail(rail_trans, rail_array, start, end):
-    """
-    Creates a new path [start,end] in rail_array, based on rail_trans.
-    """
-    # in the worst case we will need to do a A* search, so we might as well set that up
-    path = a_star(rail_trans, rail_array, start, end)
-    # print("connecting path", path)
-    if len(path) < 2:
-        return
-    current_dir = get_direction(path[0], path[1])
-    end_pos = path[-1]
-    for index in range(len(path) - 1):
-        current_pos = path[index]
-        new_pos = path[index + 1]
-        new_dir = get_direction(current_pos, new_pos)
-
-        new_trans = rail_array[current_pos]
-        if index == 0:
-            if new_trans == 0:
-                # end-point
-                # need to flip direction because of how end points are defined
-                new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1)
-            else:
-                # into existing rail
-                new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
-                # new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1)
-                pass
-        else:
-            # set the forward path
-            new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
-            # set the backwards path
-            new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1)
-        rail_array[current_pos] = new_trans
-
-        if new_pos == end_pos:
-            # setup end pos setup
-            new_trans_e = rail_array[end_pos]
-            if new_trans_e == 0:
-                # end-point
-                new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, mirror(new_dir), 1)
-            else:
-                # into existing rail
-                new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1)
-                # new_trans_e = rail_trans.set_transition(new_trans_e, mirror(new_dir), mirror(new_dir), 1)
-            rail_array[end_pos] = new_trans_e
-
-        current_dir = new_dir
-
-
-def distance_on_rail(pos1, pos2):
-    return abs(pos1[0] - pos2[0]) + abs(pos1[1] - pos2[1])
-
-
-def complex_rail_generator(nr_start_goal=1, min_dist=2, max_dist=99999, seed=0):
-    """
-    Parameters
-    -------
-    width : int
-        The width (number of cells) of the grid to generate.
-    height : int
-        The height (number of cells) of the grid to generate.
-
-    Returns
-    -------
-    numpy.ndarray of type numpy.uint16
-        The matrix with the correct 16-bit bitmaps for each cell.
-    """
-
-    def generator(width, height, num_resets=0):
-        rail_trans = RailEnvTransitions()
-        rail_array = np.zeros(shape=(width, height), dtype=np.uint16)
-
-        np.random.seed(seed + num_resets)
-
-        # generate rail array
-        # step 1:
-        # - generate a list of start and goal positions
-        # - use a min/max distance allowed as input for this
-        # - validate that start/goals are not placed too close to other start/goals
-        #
-        # step 2: (optional)
-        # - place random elements on rails array
-        #   - for instance "train station", etc.
-        #
-        # step 3:
-        # - iterate over all [start, goal] pairs:
-        #   - [first X pairs]
-        #     - draw a rail from [start,goal]
-        #     - draw either vertical or horizontal part first (randomly)
-        #     - if rail crosses existing rail then validate new connection
-        #       - if new connection is invalid turn 90 degrees to left/right
-        #       - possibility that this fails to create a path to goal
-        #         - on failure goto step1 and retry with seed+1
-        #     - [avoid crossing other start,goal positions] (optional)
-        #
-        #   - [after X pairs]
-        #     - find closest rail from start (Pa)
-        #       - iterating outwards in a "circle" from start until an existing rail cell is hit
-        #     - connect [start, Pa]
-        #       - validate crossing rails
-        #     - Do A* from Pa to find closest point on rail (Pb) to goal point
-        #       - Basically normal A* but find point on rail which is closest to goal
-        #       - since full path to goal is unlikely
-        #     - connect [Pb, goal]
-        #       - validate crossing rails
-        #
-        # step 4: (optional)
-        # - add more rails to map randomly
-        #
-        # step 5:
-        # - return transition map + list of [start, goal] points
-        #
-
-        start_goal = []
-        for _ in range(nr_start_goal):
-            sanity_max = 9000
-            for _ in range(sanity_max):
-                start = (np.random.randint(0, width), np.random.randint(0, height))
-                goal = (np.random.randint(0, height), np.random.randint(0, height))
-                # check to make sure start,goal pos is empty?
-                if rail_array[goal] != 0 or rail_array[start] != 0:
-                    continue
-                # check min/max distance
-                dist_sg = distance_on_rail(start, goal)
-                if dist_sg < min_dist:
-                    continue
-                if dist_sg > max_dist:
-                    continue
-                # check distance to existing points
-                sg_new = [start, goal]
-
-                def check_all_dist(sg_new):
-                    for sg in start_goal:
-                        for i in range(2):
-                            for j in range(2):
-                                dist = distance_on_rail(sg_new[i], sg[j])
-                                if dist < 2:
-                                    # print("too close:", dist, sg_new[i], sg[j])
-                                    return False
-                    return True
-
-                if check_all_dist(sg_new):
-                    break
-            start_goal.append([start, goal])
-            connect_rail(rail_trans, rail_array, start, goal)
-
-        print("Created #", len(start_goal), "pairs")
-        # print(start_goal)
-
-        return_rail = GridTransitionMap(width=width, height=height, transitions=rail_trans)
-        return_rail.grid = rail_array
-        # TODO: return start_goal
-        return return_rail
-
-    return generator
-
-
-def rail_from_manual_specifications_generator(rail_spec):
-    """
-    Utility to convert a rail given by manual specification as a map of tuples
-    (cell_type, rotation), to a transition map with the correct 16-bit
-    transitions specifications.
-
-    Parameters
-    -------
-    rail_spec : list of list of tuples
-        List (rows) of lists (columns) of tuples, each specifying a cell for
-        the RailEnv environment as (cell_type, rotation), with rotation being
-        clock-wise and in [0, 90, 180, 270].
-
-    Returns
-    -------
-    function
-        Generator function that always returns a GridTransitionMap object with
-        the matrix of correct 16-bit bitmaps for each cell.
-    """
-
-    def generator(width, height, num_resets=0):
-        t_utils = RailEnvTransitions()
-
-        height = len(rail_spec)
-        width = len(rail_spec[0])
-        rail = GridTransitionMap(width=width, height=height, transitions=t_utils)
-
-        for r in range(height):
-            for c in range(width):
-                cell = rail_spec[r][c]
-                if cell[0] < 0 or cell[0] >= len(t_utils.transitions):
-                    print("ERROR - invalid cell type=", cell[0])
-                    return []
-                rail.set_transitions((r, c), t_utils.rotate_transition(t_utils.transitions[cell[0]], cell[1]))
-
-        return rail
-
-    return generator
-
-
-def rail_from_GridTransitionMap_generator(rail_map):
-    """
-    Utility to convert a rail given by a GridTransitionMap map with the correct
-    16-bit transitions specifications.
-
-    Parameters
-    -------
-    rail_map : GridTransitionMap object
-        GridTransitionMap object to return when the generator is called.
-
-    Returns
-    -------
-    function
-        Generator function that always returns the given `rail_map' object.
-    """
-
-    def generator(width, height, num_resets=0):
-        return rail_map
-
-    return generator
-
-
-def rail_from_list_of_saved_GridTransitionMap_generator(list_of_filenames):
-    """
-    Utility to sequentially and cyclically return GridTransitionMap-s from a list of files, on each environment reset.
-
-    Parameters
-    -------
-    list_of_filenames : list
-        List of filenames with the saved grids to load.
-
-    Returns
-    -------
-    function
-        Generator function that always returns the given `rail_map' object.
-    """
-
-    def generator(width, height, num_resets=0):
-        t_utils = RailEnvTransitions()
-        rail_map = GridTransitionMap(width=width, height=height, transitions=t_utils)
-        rail_map.load_transition_map(list_of_filenames[num_resets % len(list_of_filenames)], override_gridsize=False)
-
-        if rail_map.grid.dtype == np.uint64:
-            rail_map.transitions = Grid8Transitions()
-
-        return rail_map
-
-    return generator
-
-
-"""
-def generate_rail_from_list_of_manual_specifications(list_of_specifications)
-    def generator(width, height, num_resets=0):
-        return generate_rail_from_manual_specifications(list_of_specifications)
-
-    return generator
-"""
-
-
-def random_rail_generator(cell_type_relative_proportion=[1.0] * 8):
-    """
-    Dummy random level generator:
-    - fill in cells at random in [width-2, height-2]
-    - keep filling cells in among the unfilled ones, such that all transitions
-      are legit;  if no cell can be filled in without violating some
-      transitions, pick one among those that can satisfy most transitions
-      (1,2,3 or 4), and delete (+mark to be re-filled) the cells that were
-      incompatible.
-    - keep trying for a total number of insertions
-      (e.g., (W-2)*(H-2)*MAX_REPETITIONS ); if no solution is found, empty the
-      board and try again from scratch.
-    - finally pad the border of the map with dead-ends to avoid border issues.
-
-    Dead-ends are not allowed inside the grid, only at the border; however, if
-    no cell type can be inserted in a given cell (because of the neighboring
-    transitions), deadends are allowed if they solve the problem. This was
-    found to turn most un-genereatable levels into valid ones.
-
-    Parameters
-    -------
-    width : int
-        The width (number of cells) of the grid to generate.
-    height : int
-        The height (number of cells) of the grid to generate.
-
-    Returns
-    -------
-    numpy.ndarray of type numpy.uint16
-        The matrix with the correct 16-bit bitmaps for each cell.
-    """
-
-    def generator(width, height, num_resets=0):
-        t_utils = RailEnvTransitions()
-
-        transition_probability = cell_type_relative_proportion
-
-        transitions_templates_ = []
-        transition_probabilities = []
-        for i in range(len(t_utils.transitions) - 4):  # don't include dead-ends
-            all_transitions = 0
-            for dir_ in range(4):
-                trans = t_utils.get_transitions(t_utils.transitions[i], dir_)
-                all_transitions |= (trans[0] << 3) | \
-                                   (trans[1] << 2) | \
-                                   (trans[2] << 1) | \
-                                   (trans[3])
-
-            template = [int(x) for x in bin(all_transitions)[2:]]
-            template = [0] * (4 - len(template)) + template
-
-            # add all rotations
-            for rot in [0, 90, 180, 270]:
-                transitions_templates_.append((template,
-                                               t_utils.rotate_transition(
-                                                   t_utils.transitions[i],
-                                                   rot)))
-                transition_probabilities.append(transition_probability[i])
-                template = [template[-1]] + template[:-1]
-
-        def get_matching_templates(template):
-            ret = []
-            for i in range(len(transitions_templates_)):
-                is_match = True
-                for j in range(4):
-                    if template[j] >= 0 and template[j] != transitions_templates_[i][0][j]:
-                        is_match = False
-                        break
-                if is_match:
-                    ret.append((transitions_templates_[i][1], transition_probabilities[i]))
-            return ret
-
-        MAX_INSERTIONS = (width - 2) * (height - 2) * 10
-        MAX_ATTEMPTS_FROM_SCRATCH = 10
-
-        attempt_number = 0
-        while attempt_number < MAX_ATTEMPTS_FROM_SCRATCH:
-            cells_to_fill = []
-            rail = []
-            for r in range(height):
-                rail.append([None] * width)
-                if r > 0 and r < height - 1:
-                    cells_to_fill = cells_to_fill + [(r, c) for c in range(1, width - 1)]
-
-            num_insertions = 0
-            while num_insertions < MAX_INSERTIONS and len(cells_to_fill) > 0:
-                # cell = random.sample(cells_to_fill, 1)[0]
-                cell = cells_to_fill[np.random.choice(len(cells_to_fill), 1)[0]]
-                cells_to_fill.remove(cell)
-                row = cell[0]
-                col = cell[1]
-
-                # look at its neighbors and see what are the possible transitions
-                # that can be chosen from, if any.
-                valid_template = [-1, -1, -1, -1]
-
-                for el in [(0, 2, (-1, 0)),
-                           (1, 3, (0, 1)),
-                           (2, 0, (1, 0)),
-                           (3, 1, (0, -1))]:  # N, E, S, W
-                    neigh_trans = rail[row + el[2][0]][col + el[2][1]]
-                    if neigh_trans is not None:
-                        # select transition coming from facing direction el[1] and
-                        # moving to direction el[1]
-                        max_bit = 0
-                        for k in range(4):
-                            max_bit |= t_utils.get_transition(neigh_trans, k, el[1])
-
-                        if max_bit:
-                            valid_template[el[0]] = 1
-                        else:
-                            valid_template[el[0]] = 0
-
-                possible_cell_transitions = get_matching_templates(valid_template)
-
-                if len(possible_cell_transitions) == 0:  # NO VALID TRANSITIONS
-                    # no cell can be filled in without violating some transitions
-                    # can a dead-end solve the problem?
-                    if valid_template.count(1) == 1:
-                        for k in range(4):
-                            if valid_template[k] == 1:
-                                rot = 0
-                                if k == 0:
-                                    rot = 180
-                                elif k == 1:
-                                    rot = 270
-                                elif k == 2:
-                                    rot = 0
-                                elif k == 3:
-                                    rot = 90
-
-                                rail[row][col] = t_utils.rotate_transition(int('0010000000000000', 2), rot)
-                                num_insertions += 1
-
-                                break
-
-                    else:
-                        # can I get valid transitions by removing a single
-                        # neighboring cell?
-                        bestk = -1
-                        besttrans = []
-                        for k in range(4):
-                            tmp_template = valid_template[:]
-                            tmp_template[k] = -1
-                            possible_cell_transitions = get_matching_templates(tmp_template)
-                            if len(possible_cell_transitions) > len(besttrans):
-                                besttrans = possible_cell_transitions
-                                bestk = k
-
-                        if bestk >= 0:
-                            # Replace the corresponding cell with None, append it
-                            # to cells to fill, fill in a transition in the current
-                            # cell.
-                            replace_row = row - 1
-                            replace_col = col
-                            if bestk == 1:
-                                replace_row = row
-                                replace_col = col + 1
-                            elif bestk == 2:
-                                replace_row = row + 1
-                                replace_col = col
-                            elif bestk == 3:
-                                replace_row = row
-                                replace_col = col - 1
-
-                            cells_to_fill.append((replace_row, replace_col))
-                            rail[replace_row][replace_col] = None
-
-                            possible_transitions, possible_probabilities = zip(*besttrans)
-                            possible_probabilities = [p / sum(possible_probabilities) for p in possible_probabilities]
-
-                            rail[row][col] = np.random.choice(possible_transitions,
-                                                              p=possible_probabilities)
-                            num_insertions += 1
-
-                        else:
-                            print('WARNING: still nothing!')
-                            rail[row][col] = int('0000000000000000', 2)
-                            num_insertions += 1
-                            pass
-
-                else:
-                    possible_transitions, possible_probabilities = zip(*possible_cell_transitions)
-                    possible_probabilities = [p / sum(possible_probabilities) for p in possible_probabilities]
-
-                    rail[row][col] = np.random.choice(possible_transitions,
-                                                      p=possible_probabilities)
-                    num_insertions += 1
-
-            if num_insertions == MAX_INSERTIONS:
-                # Failed to generate a valid level; try again for a number of times
-                attempt_number += 1
-            else:
-                break
-
-        if attempt_number == MAX_ATTEMPTS_FROM_SCRATCH:
-            print('ERROR: failed to generate level')
-
-        # Finally pad the border of the map with dead-ends to avoid border issues;
-        # at most 1 transition in the neigh cell
-        for r in range(height):
-            # Check for transitions coming from [r][1] to WEST
-            max_bit = 0
-            neigh_trans = rail[r][1]
-            if neigh_trans is not None:
-                for k in range(4):
-                    neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2 ** 4 - 1)
-                    max_bit = max_bit | (neigh_trans_from_direction & 1)
-            if max_bit:
-                rail[r][0] = t_utils.rotate_transition(int('0010000000000000', 2), 270)
-            else:
-                rail[r][0] = int('0000000000000000', 2)
-
-            # Check for transitions coming from [r][-2] to EAST
-            max_bit = 0
-            neigh_trans = rail[r][-2]
-            if neigh_trans is not None:
-                for k in range(4):
-                    neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2 ** 4 - 1)
-                    max_bit = max_bit | (neigh_trans_from_direction & (1 << 2))
-            if max_bit:
-                rail[r][-1] = t_utils.rotate_transition(int('0010000000000000', 2),
-                                                        90)
-            else:
-                rail[r][-1] = int('0000000000000000', 2)
-
-        for c in range(width):
-            # Check for transitions coming from [1][c] to NORTH
-            max_bit = 0
-            neigh_trans = rail[1][c]
-            if neigh_trans is not None:
-                for k in range(4):
-                    neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2 ** 4 - 1)
-                    max_bit = max_bit | (neigh_trans_from_direction & (1 << 3))
-            if max_bit:
-                rail[0][c] = int('0010000000000000', 2)
-            else:
-                rail[0][c] = int('0000000000000000', 2)
-
-            # Check for transitions coming from [-2][c] to SOUTH
-            max_bit = 0
-            neigh_trans = rail[-2][c]
-            if neigh_trans is not None:
-                for k in range(4):
-                    neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2 ** 4 - 1)
-                    max_bit = max_bit | (neigh_trans_from_direction & (1 << 1))
-            if max_bit:
-                rail[-1][c] = t_utils.rotate_transition(int('0010000000000000', 2), 180)
-            else:
-                rail[-1][c] = int('0000000000000000', 2)
-
-        # For display only, wrong levels
-        for r in range(height):
-            for c in range(width):
-                if rail[r][c] is None:
-                    rail[r][c] = int('0000000000000000', 2)
-
-        tmp_rail = np.asarray(rail, dtype=np.uint16)
-
-        return_rail = GridTransitionMap(width=width, height=height, transitions=t_utils)
-        return_rail.grid = tmp_rail
-        return return_rail
-
-    return generator
+# from flatland.core.transitions import Grid8Transitions, RailEnvTransitions
+# from flatland.core.transition_map import GridTransitionMap
 
 
 class EnvAgentStatic(object):
diff --git a/flatland/utils/editor.py b/flatland/utils/editor.py
index ebaf905bed2c92993da01f7d7b02c353b1ad593f..5f6625aa77216702e09cc4c2a29a2f37ee753027 100644
--- a/flatland/utils/editor.py
+++ b/flatland/utils/editor.py
@@ -167,26 +167,32 @@ class JupEditor(object):
                 # get the direction index for the 2 transitions
                 liTrans = []
                 for rcTrans in rc2Trans:
+                    # gRCTrans - rcTrans gives an array of vector differences between our rcTrans 
+                    # and the 4 directions stored in gRCTrans.
+                    # Where the vector difference is zero, we have a match...
+                    # np.all detects where the whole row,col vector is zero.
+                    # argwhere gives the index of the zero vector, ie the direction index
                     iTrans = np.argwhere(np.all(self.gRCTrans - rcTrans == 0, axis=1))
                     if len(iTrans) > 0:
                         iTrans = iTrans[0][0]
                         liTrans.append(iTrans)
 
+                # check that we have two transitions 
                 if len(liTrans) == 2:
                     # Set the transition
-                    # oEnv.rail.set_transition((*rcLast, iTransLast), iTrans, True) # does nothing
-                    iValCell = env.rail.transitions.set_transition(
-                        env.rail.grid[tuple(rcMiddle)], liTrans[0], liTrans[1], bTransition)
+                    env.rail.set_transition((*rcMiddle, liTrans[0]), liTrans[1], True)
+                    # iValCell = env.rail.transitions.set_transition(
+                    #    env.rail.grid[tuple(rcMiddle)], liTrans[0], liTrans[1], bTransition)
 
                     # Also set the reverse transition
-                    iValCell = env.rail.transitions.set_transition(
-                        iValCell,
-                        (liTrans[1] + 2) % 4,
-                        (liTrans[0] + 2) % 4,
-                        bTransition)
+                    # iValCell = env.rail.transitions.set_transition(
+                    #    iValCell,
+                    #    (liTrans[1] + 2) % 4, # use the reversed outbound transition for inbound
+                    #    (liTrans[0] + 2) % 4, # use the reversed inbound transition for outbound
+                    #    bTransition)
 
                     # Write the cell transition value back into the grid
-                    env.rail.grid[tuple(rcMiddle)] = iValCell
+                    # env.rail.grid[tuple(rcMiddle)] = iValCell
             
                 rcHistory.pop(0)  # remove the last-but-one
             
diff --git a/tests/test_env_observation_builder.py b/tests/test_env_observation_builder.py
index 55c229e88e73c311ae8f8f4aeee01218cf1dd4cf..9ec0db0a1ee5ab7dd5b235448298cee601af28f5 100644
--- a/tests/test_env_observation_builder.py
+++ b/tests/test_env_observation_builder.py
@@ -5,7 +5,8 @@ import numpy as np
 
 from flatland.core.env_observation_builder import GlobalObsForRailEnv
 from flatland.core.transition_map import GridTransitionMap, Grid4Transitions
-from flatland.envs.rail_env import RailEnv, rail_from_GridTransitionMap_generator
+from flatland.envs.rail_env import RailEnv
+from flatland.envs.generators import rail_from_GridTransitionMap_generator
 
 """Tests for `flatland` package."""
 
diff --git a/tests/test_environments.py b/tests/test_environments.py
index b46bb38828285401a85dee9000fd5935819c9342..a10fb0619eae4d27867c9008c27618fe059d52d2 100644
--- a/tests/test_environments.py
+++ b/tests/test_environments.py
@@ -2,11 +2,13 @@
 # -*- coding: utf-8 -*-
 import numpy as np
 
-from flatland.envs.rail_env import RailEnv, rail_from_GridTransitionMap_generator
+from flatland.envs.rail_env import RailEnv
+from flatland.envs.generators import rail_from_GridTransitionMap_generator
 from flatland.core.transitions import Grid4Transitions
 from flatland.core.transition_map import GridTransitionMap
 from flatland.core.env_observation_builder import GlobalObsForRailEnv
 
+
 """Tests for `flatland` package."""
 
 
diff --git a/tests/test_transitions.py b/tests/test_transitions.py
index 2ebfc462cd62bee167b5c9f742d159e8567ed8b4..69a7953c5b50bc5411c52b58c33448228c91620f 100644
--- a/tests/test_transitions.py
+++ b/tests/test_transitions.py
@@ -3,7 +3,8 @@
 
 """Tests for `flatland` package."""
 from flatland.core.transitions import RailEnvTransitions, Grid8Transitions
-from flatland.envs.rail_env import validate_new_transition
+# from flatland.envs.rail_env import validate_new_transition
+from flatland.envs.env_utils import validate_new_transition
 import numpy as np