diff --git a/examples/play_model.py b/examples/play_model.py
index 2c18c3e3fbf5e54320f3382ae158f542a2130080..8f0df7cdd7957c432515d04503281d45e3bbdbc2 100644
--- a/examples/play_model.py
+++ b/examples/play_model.py
@@ -1,4 +1,4 @@
-from flatland.envs.rail_env import RailEnv, random_rail_generator
+from flatland.envs.rail_env import RailEnv, random_rail_generator, complex_rail_generator
 # from flatland.core.env_observation_builder import TreeObsForRailEnv
 from flatland.utils.rendertools import RenderTool
 from flatland.baselines.dueling_double_dqn import Agent
@@ -6,7 +6,7 @@ from collections import deque
 import torch
 import random
 import numpy as np
-import matplotlib.pyplot as plt
+#import matplotlib.pyplot as plt
 import time
 
 
@@ -94,24 +94,23 @@ def main(render=True, delay=0.0):
 
     # Example generate a rail given a manual specification,
     # a map of tuples (cell_type, rotation)
-    transition_probability = [0.5,  # empty cell - Case 0
-                            1.0,  # Case 1 - straight
-                            1.0,  # Case 2 - simple switch
-                            0.3,  # Case 3 - diamond drossing
-                            0.5,  # Case 4 - single slip
-                            0.5,  # Case 5 - double slip
-                            0.2,  # Case 6 - symmetrical
-                            0.0]  # Case 7 - dead end
+    #transition_probability = [0.5,  # empty cell - Case 0
+    #                        1.0,  # Case 1 - straight
+    #                        1.0,  # Case 2 - simple switch
+    #                        0.3,  # Case 3 - diamond crossing
+    #                        0.5,  # Case 4 - single slip
+    #                        0.5,  # Case 5 - double slip
+    #                        0.2,  # Case 6 - symmetrical
+    #                        0.0]  # Case 7 - dead end
 
     # Example generate a random rail
-    env = RailEnv(width=15,
-                height=15,
-                rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
-                number_of_agents=5)
+    env = RailEnv(width=15, height=15,
+                  rail_generator=complex_rail_generator(nr_start_goal=20, min_dist=5),
+                  number_of_agents=1)
 
     if render:
         env_renderer = RenderTool(env, gl="QT")
-    plt.figure(figsize=(5,5))
+    # plt.figure(figsize=(5,5))
     # fRedis = redis.Redis()
 
     handle = env.get_agent_handles()
diff --git a/flatland/core/transitions.py b/flatland/core/transitions.py
index a8cb8d6f49157bafbb65551b53c6612c45565c88..ec9586ed44986bf3d24f352c336d9eb64a074a99 100644
--- a/flatland/core/transitions.py
+++ b/flatland/core/transitions.py
@@ -531,9 +531,47 @@ class RailEnvTransitions(Grid4Transitions):
                        int('1001011000100001', 2),  # Case 4 - single slip
                        int('1100110000110011', 2),  # Case 5 - double slip
                        int('0101001000000010', 2),  # Case 6 - symmetrical
-                       int('0010000000000000', 2)]  # Case 7 - dead end
+                       int('0010000000000000', 2),  # Case 7 - dead end
+                       int('0100000000000010', 2),  # Case 1b - simple turn right
+                       int('0001001000000000', 2)]  # Case 1c - simple turn left
 
     def __init__(self):
         super(RailEnvTransitions, self).__init__(
             transitions=self.transition_list
         )
+
+    def print(self, cell_transition):
+        print("  NESW")
+        print("N", format(cell_transition>>(3*4) & 0xF, '04b'))
+        print("E", format(cell_transition>>(2*4) & 0xF, '04b'))
+        print("S", format(cell_transition>>(1*4) & 0xF, '04b'))
+        print("W", format(cell_transition>>(0*4) & 0xF, '04b'))
+
+    def is_valid(self, cell_transition):
+        """
+        Checks if a cell transition is a valid cell setup.
+
+        Parameters
+        ----------
+        cell_transition : int
+            64 bits used to encode the valid transitions for a cell.
+
+        Returns
+        -------
+        Boolean
+            True or False
+        """
+        # i = 0
+        for trans in self.transitions:
+            # print(">", i)
+            # i += 1
+            # self.print(trans)
+            if cell_transition == trans:
+                return True
+            for _ in range(3):
+                trans = self.rotate_transition(trans, rotation=90)
+                # self.print(trans)
+                if cell_transition == trans:
+                    return True
+        return False
+
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index ce05ce02ccaec6ca5a8add3adef24fdcead02924..3672caf43d69843bdd88dfa5baed4b2d206a1be8 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -13,6 +13,333 @@ from flatland.core.transitions import Grid8Transitions, RailEnvTransitions
 from flatland.core.transition_map import GridTransitionMap
 
 
+class AStarNode():
+    """A node class for A* Pathfinding"""
+
+    def __init__(self, parent=None, pos=None):
+        self.parent = parent
+        self.pos = pos
+        self.g = 0
+        self.h = 0
+        self.f = 0
+
+    def __eq__(self, other):
+        return self.pos == other.pos
+
+    def update_if_better(self, other):
+        if other.g < self.g:
+            self.parent = other.parent
+            self.g = other.g
+            self.h = other.h
+            self.f = other.f
+
+
+def get_direction(pos1, pos2):
+    """
+    Assumes pos1 and pos2 are adjacent location on grid.
+    Returns direction (int) that can be used with transitions.
+    """
+    diff_0 = pos2[0] - pos1[0]
+    diff_1 = pos2[1] - pos1[1]
+    if diff_0 < 0:
+        return 0
+    if diff_0 > 0:
+        return 2
+    if diff_1 > 0:
+        return 1
+    if diff_1 < 0:
+        return 3
+    return 0
+
+
+def mirror(dir):
+    return (dir + 2) % 4
+
+
+def validate_new_transition(rail_trans, rail_array, prev_pos, current_pos, new_pos, end_pos):
+    # start by getting direction used to get to current node
+    # and direction from current node to possible child node
+    new_dir = get_direction(current_pos, new_pos)
+    if prev_pos is not None:
+        current_dir = get_direction(prev_pos, current_pos)
+    else:
+        current_dir = new_dir
+    # create new transition that would go to child
+    new_trans = rail_array[current_pos]
+    if prev_pos is None:
+        # need to flip direction because of how end points are defined
+        new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1)
+    else:
+        # set the forward path
+        new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
+        # set the backwards path
+        new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1)
+    if new_pos == end_pos:
+        # need to validate end pos setup as well
+        new_trans_e = rail_array[end_pos]
+        new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, mirror(new_dir), 1)
+        # print("========> end trans")
+        # rail_trans.print(new_trans_e)
+        if not rail_trans.is_valid(new_trans_e):
+            return False
+    # is transition is valid?
+    # print("=======> trans")
+    # rail_trans.print(new_trans)
+    return rail_trans.is_valid(new_trans)
+
+
+def a_star(rail_trans, rail_array, start, end):
+    """
+    Returns a list of tuples as a path from the given start to end.
+    If no path is found, returns path to closest point to end.
+    """
+    rail_shape = rail_array.shape
+    start_node = AStarNode(None, start)
+    end_node = AStarNode(None, end)
+    open_list = []
+    closed_list = []
+
+    open_list.append(start_node)
+
+    # this could be optimized
+    def is_node_in_list(node, the_list):
+        for o_node in the_list:
+            if node == o_node:
+                return o_node
+        return None
+
+    while len(open_list) > 0:
+        # get node with current shortest est. path (lowest f)
+        current_node = open_list[0]
+        current_index = 0
+        for index, item in enumerate(open_list):
+            if item.f < current_node.f:
+                current_node = item
+                current_index = index
+
+        # pop current off open list, add to closed list
+        open_list.pop(current_index)
+        closed_list.append(current_node)
+
+        # print("a*:", current_node.pos)
+        # for cn in closed_list:
+        #    print("closed:", cn.pos)
+
+        # found the goal
+        if current_node == end_node:
+            path = []
+            current = current_node
+            while current is not None:
+                path.append(current.pos)
+                current = current.parent
+            # return reversed path
+            return path[::-1]
+
+        # generate children
+        children = []
+        if current_node.parent is not None:
+            prev_pos = current_node.parent.pos
+        else:
+            prev_pos = None
+        for new_pos in [(0, -1), (0, 1), (-1, 0), (1, 0)]:
+            node_pos = (current_node.pos[0] + new_pos[0], current_node.pos[1] + new_pos[1])
+            if node_pos[0] >= rail_shape[0] or \
+               node_pos[0] < 0 or \
+               node_pos[1] >= rail_shape[1] or \
+               node_pos[1] < 0:
+                continue
+
+            # validate positions
+            # debug: avoid all current rails
+            # if rail_array.item(node_pos) != 0:
+            #    continue
+
+            # validate positions
+            if not validate_new_transition(rail_trans, rail_array, prev_pos, current_node.pos, node_pos, end_node.pos):
+                # print("A*: transition invalid")
+                continue
+
+            # create new node
+            new_node = AStarNode(current_node, node_pos)
+            children.append(new_node)
+
+        # loop through children
+        for child in children:
+            # already in closed list?
+            closed_node = is_node_in_list(child, closed_list)
+            if closed_node is not None:
+                continue
+
+            # create the f, g, and h values
+            child.g = current_node.g + 1
+            # this heuristic favors diagonal paths
+            # child.h = ((child.pos[0] - end_node.pos[0]) ** 2) + \
+            #           ((child.pos[1] - end_node.pos[1]) ** 2)
+            # this heuristic avoids diagonal paths
+            child.h = abs(child.pos[0] - end_node.pos[0]) + abs(child.pos[1] - end_node.pos[1])
+            child.f = child.g + child.h
+
+            # already in the open list?
+            open_node = is_node_in_list(child, open_list)
+            if open_node is not None:
+                open_node.update_if_better(child)
+                continue
+
+            # add the child to the open list
+            open_list.append(child)
+
+        # no full path found, return partial path
+        if len(open_list) == 0:
+            path = []
+            current = current_node
+            while current is not None:
+                path.append(current.pos)
+                current = current.parent
+            # return reversed path
+            return path[::-1]
+
+
+def connect_rail(rail_trans, rail_array, start, end):
+    """
+    Creates a new path [start,end] in rail_array, based on rail_trans.
+    """
+    # in the worst case we will need to do a A* search, so we might as well set that up
+    path = a_star(rail_trans, rail_array, start, end)
+    # print("connecting path", path)
+    if len(path) < 2:
+        return
+    current_dir = get_direction(path[0], path[1])
+    end_pos = path[-1]
+    for index in range(len(path) - 1):
+        current_pos = path[index]
+        new_pos = path[index+1]
+        new_dir = get_direction(current_pos, new_pos)
+
+        new_trans = rail_array[current_pos]
+        if index == 0:
+            # need to flip direction because of how end points are defined
+            new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1)
+        else:
+            # set the forward path
+            new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
+            # set the backwards path
+            new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1)
+        rail_array[current_pos] = new_trans
+
+        if new_pos == end_pos:
+            # need to validate end pos setup as well
+            new_trans_e = rail_array[end_pos]
+            new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, mirror(new_dir), 1)
+            rail_array[end_pos] = new_trans_e
+
+        current_dir = new_dir
+
+
+def distance_on_rail(pos1, pos2):
+    return abs(pos1[0] - pos2[0]) + abs(pos1[1] - pos2[1])
+
+
+def complex_rail_generator(nr_start_goal=1, min_dist=2, max_dist=99999, seed=0):
+    """
+    Parameters
+    -------
+    width : int
+        The width (number of cells) of the grid to generate.
+    height : int
+        The height (number of cells) of the grid to generate.
+
+    Returns
+    -------
+    numpy.ndarray of type numpy.uint16
+        The matrix with the correct 16-bit bitmaps for each cell.
+    """
+
+    def generator(width, height, num_resets=0):
+        rail_trans = RailEnvTransitions()
+        rail_array = np.zeros(shape=(width, height), dtype=np.uint16)
+
+        np.random.seed(seed + num_resets)
+
+        # generate rail array
+        # step 1:
+        # - generate a list of start and goal positions
+        # - use a min/max distance allowed as input for this
+        # - validate that start/goals are not placed too close to other start/goals
+        #
+        # step 2: (optional)
+        # - place random elements on rails array
+        #   - for instance "train station", etc.
+        #
+        # step 3:
+        # - iterate over all [start, goal] pairs:
+        #   - [first X pairs]
+        #     - draw a rail from [start,goal]
+        #     - draw either vertical or horizontal part first (randomly)
+        #     - if rail crosses existing rail then validate new connection
+        #       - if new connection is invalid turn 90 degrees to left/right
+        #       - possibility that this fails to create a path to goal
+        #         - on failure goto step1 and retry with seed+1
+        #     - [avoid crossing other start,goal positions] (optional)
+        #
+        #   - [after X pairs] 
+        #     - find closest rail from start (Pa)
+        #       - iterating outwards in a "circle" from start until an existing rail cell is hit
+        #     - connect [start, Pa]
+        #       - validate crossing rails
+        #     - Do A* from Pa to find closest point on rail (Pb) to goal point
+        #       - Basically normal A* but find point on rail which is closest to goal
+        #       - since full path to goal is unlikely
+        #     - connect [Pb, goal]
+        #       - validate crossing rails
+        #
+        # step 4: (optional)
+        # - add more rails to map randomly
+        #
+        # step 5:
+        # - return transition map + list of [start, goal] points
+        #
+
+        # step 1:
+        start_goal = []
+        for _ in range(nr_start_goal):
+            sanity_max = 9000
+            for _ in range(sanity_max):
+                start = (np.random.randint(0, width), np.random.randint(0, height))
+                goal = (np.random.randint(0, height), np.random.randint(0, height))
+                # check min/max distance
+                dist_sg = distance_on_rail(start, goal)
+                if dist_sg < min_dist:
+                    continue
+                if dist_sg > max_dist:
+                    continue
+                # check distance to existing points
+                sg_new = [start, goal]
+                def check_all_dist(sg_new):
+                    for sg in start_goal:
+                        for i in range(2):
+                            for j in range(2):
+                                dist = distance_on_rail(sg_new[i], sg[j])
+                                if dist < 2:
+                                    # print("too close:", dist, sg_new[i], sg[j])
+                                    return False
+                    return True
+                if check_all_dist(sg_new):
+                    break
+            start_goal.append([start, goal])
+        print("Created #", len(start_goal), "pairs")
+
+        # step 3:
+        for sg in start_goal:
+            connect_rail(rail_trans, rail_array, sg[0], sg[1])
+
+        return_rail = GridTransitionMap(width=width, height=height, transitions=rail_trans)
+        return_rail.grid = rail_array
+        # TODO: return start_goal
+        return return_rail
+
+    return generator
+
+
 def rail_from_manual_specifications_generator(rail_spec):
     """
     Utility to convert a rail given by manual specification as a map of tuples
@@ -148,7 +475,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8):
 
         transitions_templates_ = []
         transition_probabilities = []
-        for i in range(len(t_utils.transitions) - 1):  # don't include dead-ends
+        for i in range(len(t_utils.transitions) - 3):  # don't include dead-ends
             all_transitions = 0
             for dir_ in range(4):
                 trans = t_utils.get_transitions(t_utils.transitions[i], dir_)
diff --git a/tests/test_transitions.py b/tests/test_transitions.py
index 0f56e886071fd1d217be03b9a7e875c20d1a0e8a..1d6ea966a318b5dda7a6e61c87343fc0710e72ea 100644
--- a/tests/test_transitions.py
+++ b/tests/test_transitions.py
@@ -3,6 +3,56 @@
 
 """Tests for `flatland` package."""
 from flatland.core.transitions import RailEnvTransitions, Grid8Transitions
+from flatland.envs.rail_env import validate_new_transition
+import numpy as np
+
+
+def test_is_valid_railenv_transitions():
+    rail_env_trans = RailEnvTransitions()
+    transition_list = rail_env_trans.transitions
+
+    for t in transition_list:
+        assert(rail_env_trans.is_valid(t) == True)
+        for i in range(3):
+            rot_trans = rail_env_trans.rotate_transition(t, 90 * i)
+            assert(rail_env_trans.is_valid(rot_trans) == True)
+
+    assert(rail_env_trans.is_valid(int('1111111111110010', 2)) == False)
+    assert(rail_env_trans.is_valid(int('1001111111110010', 2)) == False)
+    assert(rail_env_trans.is_valid(int('1001111001110110', 2)) == False)
+
+
+def test_adding_new_valid_transition():
+    rail_trans = RailEnvTransitions()
+    rail_array = np.zeros(shape=(15, 15), dtype=np.uint16)
+
+    # adding straight
+    assert(validate_new_transition(rail_trans, rail_array, (4,5), (5,5), (6,5), (10,10)) == True)
+
+    # adding valid right turn
+    assert(validate_new_transition(rail_trans, rail_array, (5,4), (5,5), (5,6), (10,10)) == True)
+    # adding valid left turn
+    assert(validate_new_transition(rail_trans, rail_array, (5,6), (5,5), (5,6), (10,10)) == True)
+
+    # adding invalid turn
+    rail_array[(5,5)] = rail_trans.transitions[2]
+    assert(validate_new_transition(rail_trans, rail_array, (4,5), (5,5), (5,6), (10,10)) == False)
+
+    # should create #4 -> valid
+    rail_array[(5,5)] = rail_trans.transitions[3]
+    assert(validate_new_transition(rail_trans, rail_array, (4,5), (5,5), (5,6), (10,10)) == True)
+
+    # adding invalid turn
+    rail_array[(5,5)] = rail_trans.transitions[7]
+    assert(validate_new_transition(rail_trans, rail_array, (4,5), (5,5), (5,6), (10,10)) == False)
+
+    # test path start condition
+    rail_array[(5,5)] = rail_trans.transitions[0]
+    assert(validate_new_transition(rail_trans, rail_array, None, (5,5), (5,6), (10,10)) == True)
+
+    # test path end condition
+    rail_array[(5,5)] = rail_trans.transitions[0]
+    assert(validate_new_transition(rail_trans, rail_array, (5,4), (5,5), (6,5), (6,5)) == True)
 
 
 def test_valid_railenv_transitions():