diff --git a/examples/custom_railmap_example.py b/examples/custom_railmap_example.py
index 04da66904fda1a58847a4acc510d7fc4e4e86887..6162b918734eb311752675e75e203d90e5558c1c 100644
--- a/examples/custom_railmap_example.py
+++ b/examples/custom_railmap_example.py
@@ -43,6 +43,7 @@ def custom_schedule_generator() -> ScheduleGenerator:
 env = RailEnv(width=6,
               height=4,
               rail_generator=custom_rail_generator(),
+              schedule_generator=custom_schedule_generator(),
               number_of_agents=1)
 
 env.reset()
diff --git a/flatland/core/transition_map.py b/flatland/core/transition_map.py
index 7a673bcf9ba46c574db0983e3c52257e4a07358e..bb954998688772a7ce69e5228cff3e16d037f2af 100644
--- a/flatland/core/transition_map.py
+++ b/flatland/core/transition_map.py
@@ -7,6 +7,7 @@ from importlib_resources import path
 from numpy import array
 
 from flatland.core.grid.grid4 import Grid4Transitions
+from flatland.core.grid.grid4_utils import get_new_position
 from flatland.core.transitions import Transitions
 
 
@@ -298,6 +299,76 @@ class GridTransitionMap(TransitionMap):
         self.height = new_height
         self.grid = new_grid
 
+    def is_dead_end(self, rcPos):
+        """
+        Check if the cell is a dead-end.
+
+        Parameters
+        ----------
+        rcPos: Tuple[int,int]
+            tuple(row, column) with grid coordinate
+        Returns
+        -------
+        boolean
+            True if and only if the cell is a dead-end.
+        """
+        nbits = 0
+        tmp = self.get_full_transitions(rcPos[0], rcPos[1])
+        while tmp > 0:
+            nbits += (tmp & 1)
+            tmp = tmp >> 1
+        return nbits == 1
+
+    def is_simple_turn(self, rcPos):
+        """
+        Check if the cell is a left/right simple turn
+
+        Parameters
+        ----------
+            rcPos: Tuple[int,int]
+                tuple(row, column) with grid coordinate
+        Returns
+        -------
+            boolean
+                True if and only if the cell is a left/right simple turn.
+        """
+        tmp = self.get_full_transitions(rcPos[0], rcPos[1])
+
+        def is_simple_turn(trans):
+            all_simple_turns = set()
+            for trans in [int('0100000000000010', 2),  # Case 1b (8)  - simple turn right
+                          int('0001001000000000', 2)  # Case 1c (9)  - simple turn left]:
+                          ]:
+                for _ in range(3):
+                    trans = self.transitions.rotate_transition(trans, rotation=90)
+                    all_simple_turns.add(trans)
+            return trans in all_simple_turns
+
+        return is_simple_turn(tmp)
+
+    def check_path_exists(self, start, direction, end):
+        # print("_path_exists({},{},{}".format(start, direction, end))
+        # BFS - Check if a path exists between the 2 nodes
+
+        visited = set()
+        stack = [(start, direction)]
+        while stack:
+            node = stack.pop()
+            node_position = node[0]
+            node_direction = node[1]
+            if node_position[0] == end[0] and node_position[1] == end[1]:
+                return True
+            if node not in visited:
+                visited.add(node)
+
+                moves = self.get_transitions(node_position[0], node_position[1], node_direction)
+                for move_index in range(4):
+                    if moves[move_index]:
+                        stack.append((get_new_position(node_position, move_index),
+                                      move_index))
+
+        return False
+
     def cell_neighbours_valid(self, rcPos, check_this_cell=False):
         """
         Check validity of cell at rcPos = tuple(row, column)
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 4158675cae63394ed768bfc36aaef9cd5f44da7e..c4fed2e07a97b00b786a7db8eb06af247d3ede8a 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -383,8 +383,9 @@ class TreeObsForRailEnv(ObservationBuilder):
                     elif int_position in np.delete(self.predicted_pos[pre_step], handle, 0):
                         conflicting_agent = np.where(self.predicted_pos[pre_step] == int_position)
                         for ca in conflicting_agent[0]:
-                            if direction != self.predicted_dir[pre_step][ca] and cell_transitions[self._reverse_dir(
-                                self.predicted_dir[pre_step][ca])] == 1 and tot_dist < potential_conflict:
+                            if direction != self.predicted_dir[pre_step][ca] \
+                                and cell_transitions[self._reverse_dir(self.predicted_dir[pre_step][ca])] == 1 \
+                                and tot_dist < potential_conflict:  # noqa: E125
                                 potential_conflict = tot_dist
                             if self.env.dones[ca] and tot_dist < potential_conflict:
                                 potential_conflict = tot_dist
@@ -394,7 +395,8 @@ class TreeObsForRailEnv(ObservationBuilder):
                         conflicting_agent = np.where(self.predicted_pos[post_step] == int_position)
                         for ca in conflicting_agent[0]:
                             if direction != self.predicted_dir[post_step][ca] and cell_transitions[self._reverse_dir(
-                                self.predicted_dir[post_step][ca])] == 1 and tot_dist < potential_conflict:
+                                self.predicted_dir[post_step][ca])] == 1 \
+                                and tot_dist < potential_conflict:  # noqa: E125
                                 potential_conflict = tot_dist
                             if self.env.dones[ca] and tot_dist < potential_conflict:
                                 potential_conflict = tot_dist
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index a61ef02207174d04489b5311dc042b7c06db1412..62efbdc5bc0781c4c7482412dafd98710ed9d14e 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -271,8 +271,7 @@ class RailEnv(Environment):
         agent.malfunction_data['next_malfunction'] -= 1
 
         # Only agents that have a positive rate for malfunctions and are not currently broken are considered
-        if agent.malfunction_data['malfunction_rate'] > 0 >= agent.malfunction_data[
-            'malfunction']:
+        if agent.malfunction_data['malfunction_rate'] > 0 >= agent.malfunction_data['malfunction']:
 
             # If counter has come to zero --> Agent has malfunction
             # set next malfunction time and duration of current malfunction
diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py
index 8f2c5231232c3cc17f50de4b89b88f1c9fdc5d60..40ec2e0df89a3de48bf0a2a4430de3e30fa556e5 100644
--- a/flatland/envs/rail_generators.py
+++ b/flatland/envs/rail_generators.py
@@ -683,8 +683,10 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
                                     0,
                                     width - 1)
                 tries = 0
-                while (station_x, station_y) in train_stations or (station_x, station_y) == node_positions[
-                    trainstation_node] or rail_array[(station_x, station_y)] != 0:
+                while (station_x, station_y) in train_stations \
+                    or (station_x, station_y) == node_positions[trainstation_node] \
+                    or rail_array[(station_x, station_y)] != 0:  # noqa: E125
+
                     station_x = np.clip(
                         node_positions[trainstation_node][0] + np.random.randint(-node_radius, node_radius),
                         0,
diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py
index 2ef6dab85fc6145bbfb57b25994903bbd861f65f..4843e0040d80b79de54e8ed57674a37884ef6809 100644
--- a/flatland/envs/schedule_generators.py
+++ b/flatland/envs/schedule_generators.py
@@ -131,34 +131,6 @@ def random_schedule_generator(speed_ratio_map: Mapping[float, float] = None) ->
     """
 
     def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None) -> ScheduleGeneratorProduct:
-        def _path_exists(rail, start, direction, end):
-            # BFS - Check if a path exists between the 2 nodes
-
-            visited = set()
-            stack = [(start, direction)]
-            while stack:
-                node = stack.pop()
-                if node[0][0] == end[0] and node[0][1] == end[1]:
-                    return 1
-                if node not in visited:
-                    visited.add(node)
-                    moves = rail.get_transitions(node[0][0], node[0][1], node[1])
-                    for move_index in range(4):
-                        if moves[move_index]:
-                            stack.append((get_new_position(node[0], move_index),
-                                          move_index))
-
-                    # If cell is a dead-end, append previous node with reversed
-                    # orientation!
-                    nbits = 0
-                    tmp = rail.get_full_transitions(node[0][0], node[0][1])
-                    while tmp > 0:
-                        nbits += (tmp & 1)
-                        tmp = tmp >> 1
-                    if nbits == 1:
-                        stack.append((node[0], (node[1] + 2) % 4))
-
-            return 0
 
         valid_positions = []
         for r in range(rail.height):
@@ -167,14 +139,35 @@ def random_schedule_generator(speed_ratio_map: Mapping[float, float] = None) ->
                     valid_positions.append((r, c))
         if len(valid_positions) == 0:
             return [], [], [], []
+
+        if len(valid_positions) < num_agents:
+            warnings.warn("schedule_generators: len(valid_positions) < num_agents")
+            return [], [], [], []
+
+        agents_position_idx = [i for i in np.random.choice(len(valid_positions), num_agents, replace=False)]
+        agents_position = [valid_positions[agents_position_idx[i]] for i in range(num_agents)]
+        agents_target_idx = [i for i in np.random.choice(len(valid_positions), num_agents, replace=False)]
+        agents_target = [valid_positions[agents_target_idx[i]] for i in range(num_agents)]
+        update_agents = np.zeros(num_agents)
+
         re_generate = True
+        cnt = 0
         while re_generate:
-            agents_position = [
-                valid_positions[i] for i in
-                np.random.choice(len(valid_positions), num_agents)]
-            agents_target = [
-                valid_positions[i] for i in
-                np.random.choice(len(valid_positions), num_agents)]
+            cnt += 1
+            if cnt > 1:
+                print("re_generate cnt={}".format(cnt))
+            if cnt > 1000:
+                raise Exception("After 1000 re_generates still not success, giving up.")
+            # update position
+            for i in range(num_agents):
+                if update_agents[i] == 1:
+                    x = np.setdiff1d(np.arange(len(valid_positions)), agents_position_idx)
+                    agents_position_idx[i] = np.random.choice(x)
+                    agents_position[i] = valid_positions[agents_position_idx[i]]
+                    x = np.setdiff1d(np.arange(len(valid_positions)), agents_target_idx)
+                    agents_target_idx[i] = np.random.choice(x)
+                    agents_target[i] = valid_positions[agents_target_idx[i]]
+            update_agents = np.zeros(num_agents)
 
             # agents_direction must be a direction for which a solution is
             # guaranteed.
@@ -192,12 +185,15 @@ def random_schedule_generator(speed_ratio_map: Mapping[float, float] = None) ->
                 valid_starting_directions = []
                 for m in valid_movements:
                     new_position = get_new_position(agents_position[i], m[1])
-                    if m[0] not in valid_starting_directions and _path_exists(rail, new_position, m[0],
-                                                                              agents_target[i]):
+                    if m[0] not in valid_starting_directions and rail.check_path_exists(new_position, m[1],
+                                                                                        agents_target[i]):
                         valid_starting_directions.append(m[0])
 
                 if len(valid_starting_directions) == 0:
+                    update_agents[i] = 1
+                    warnings.warn("reset position for agent[{}]: {} -> {}".format(i, agents_position[i], agents_target[i]))
                     re_generate = True
+                    break
                 else:
                     agents_direction[i] = valid_starting_directions[
                         np.random.choice(len(valid_starting_directions), 1)[0]]
diff --git a/flatland/utils/simple_rail.py b/flatland/utils/simple_rail.py
index c5fe4860783f242f21c97c55a9119d8918454a96..67bd93dd35c8f53ef3cdef23dbae0f0d785b9a64 100644
--- a/flatland/utils/simple_rail.py
+++ b/flatland/utils/simple_rail.py
@@ -81,6 +81,43 @@ def make_simple_rail2() -> Tuple[GridTransitionMap, np.array]:
     rail.grid = rail_map
     return rail, rail_map
 
+def make_simple_rail_unconnected() -> Tuple[GridTransitionMap, np.array]:
+    # We instantiate a very simple rail network on a 7x10 grid:
+    # Note that that cells have invalid RailEnvTransitions!
+    #        |
+    #        |
+    #        |
+    # _ _ _  _ _ _  _  _ _ _
+    #                /
+    #                |
+    #                |
+    #                |
+    transitions = RailEnvTransitions()
+    cells = transitions.transition_list
+    empty = cells[0]
+    dead_end_from_south = cells[7]
+    dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90)
+    dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180)
+    dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270)
+    vertical_straight = cells[1]
+    horizontal_straight = transitions.rotate_transition(vertical_straight, 90)
+    simple_switch_north_left = cells[2]
+    simple_switch_north_right = cells[10]
+    simple_switch_east_west_north = transitions.rotate_transition(simple_switch_north_right, 270)
+    simple_switch_east_west_south = transitions.rotate_transition(simple_switch_north_left, 270)
+    rail_map = np.array(
+        [[empty] * 3 + [dead_end_from_south] + [empty] * 6] +
+        [[empty] * 3 + [vertical_straight] + [empty] * 6]  +
+        [[empty] * 3 + [dead_end_from_north] + [empty] * 6]  +
+        [[dead_end_from_east] + [horizontal_straight]  * 5 + [simple_switch_east_west_south] +
+         [horizontal_straight] * 2 + [dead_end_from_west]] +
+        [[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 +
+        [[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16)
+    rail = GridTransitionMap(width=rail_map.shape[1],
+                             height=rail_map.shape[0], transitions=transitions)
+    rail.grid = rail_map
+    return rail, rail_map
+
 
 def make_invalid_simple_rail() -> Tuple[GridTransitionMap, np.array]:
     # We instantiate a very simple rail network on a 7x10 grid:
diff --git a/tests/test_flatland_core_transition_map.py b/tests/test_flatland_core_transition_map.py
index a414231619cfa924c2d33776f9f140cf88280517..8812c847e61d81f6614f37d26489b4c17ea7fd14 100644
--- a/tests/test_flatland_core_transition_map.py
+++ b/tests/test_flatland_core_transition_map.py
@@ -1,6 +1,13 @@
 from flatland.core.grid.grid4 import Grid4Transitions, Grid4TransitionsEnum
 from flatland.core.grid.grid8 import Grid8Transitions, Grid8TransitionsEnum
 from flatland.core.transition_map import GridTransitionMap
+from flatland.envs.observations import TreeObsForRailEnv
+from flatland.envs.predictions import ShortestPathPredictorForRailEnv
+from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_generators import rail_from_grid_transition_map
+from flatland.envs.schedule_generators import random_schedule_generator
+from flatland.utils.rendertools import RenderTool
+from flatland.utils.simple_rail import make_simple_rail, make_simple_rail_unconnected
 
 
 def test_grid4_get_transitions():
@@ -43,4 +50,111 @@ def test_grid8_set_transitions():
     grid8_map.set_transition((0, 0, Grid8TransitionsEnum.NORTH), Grid8TransitionsEnum.NORTH, 0)
     assert grid8_map.get_transitions(0, 0, Grid8TransitionsEnum.NORTH) == (0, 0, 0, 0, 0, 0, 0, 0)
 
-# TODO GridTransitionMap
+
+def check_path(env, rail, position, direction, target, expected, rendering=False):
+    agent = env.agents_static[0]
+    agent.position = position  # south dead-end
+    agent.direction = direction  # north
+    agent.target = target  # east dead-end
+    agent.moving = True
+    # reset to set agents from agents_static
+    # env.reset(False, False)
+    if rendering:
+        renderer = RenderTool(env, gl="PILSVG")
+        renderer.render_env(show=True, show_observations=False)
+        input("Continue?")
+    assert rail.check_path_exists(agent.position, agent.direction, agent.target) == expected
+
+
+def test_path_exists(rendering=False):
+    rail, rail_map = make_simple_rail()
+    env = RailEnv(width=rail_map.shape[1],
+                  height=rail_map.shape[0],
+                  rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(),
+                  number_of_agents=1,
+                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
+                  )
+
+    # reset to initialize agents_static
+    env.reset()
+
+    check_path(
+        env,
+        rail,
+        (5, 6),  # north of south dead-end
+        0,  # north
+        (3, 9),  # east dead-end
+        True
+    )
+
+    check_path(
+        env,
+        rail,
+        (6, 6),  # south dead-end
+        2,  # south
+        (3, 9),  # east dead-end
+        True
+    )
+
+    check_path(
+        env,
+        rail,
+        (3, 0),  # east dead-end
+        3,  # west
+        (0, 3),  # north dead-end
+        True
+    )
+    check_path(
+        env,
+        rail,
+        (5, 6),  # east dead-end
+        0,  # west
+        (1, 3),  # north dead-end
+        True)
+
+    check_path(
+        env,
+        rail,
+        (1,3),  # east dead-end
+        2,  # south
+        (3,3),  # north dead-end
+        True
+    )
+
+    check_path(
+        env,
+        rail,
+        (1,3),  # east dead-end
+        0,  # north
+        (3,3),  # north dead-end
+        True
+    )
+
+
+def test_path_not_exists(rendering=False):
+    rail, rail_map = make_simple_rail_unconnected()
+    env = RailEnv(width=rail_map.shape[1],
+                  height=rail_map.shape[0],
+                  rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(),
+                  number_of_agents=1,
+                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
+                  )
+
+    # reset to initialize agents_static
+    env.reset()
+
+    check_path(
+        env,
+        rail,
+        (5, 6),  # south dead-end
+        0,  # north
+        (0, 3),  # north dead-end
+        False
+    )
+
+    if rendering:
+        renderer = RenderTool(env, gl="PILSVG")
+        renderer.render_env(show=True, show_observations=False)
+        input("Continue?")