diff --git a/examples/temporary_example.py b/examples/temporary_example.py
index a5c97660e4d22ebaffeef27b320e45b21f29dfe5..6bed439cb21f611353763da36a890739c77e5866 100644
--- a/examples/temporary_example.py
+++ b/examples/temporary_example.py
@@ -3,6 +3,7 @@ import numpy as np
 import matplotlib.pyplot as plt
 
 from flatland.envs.rail_env import *
+from flatland.envs.generators import *
 from flatland.core.env_observation_builder import TreeObsForRailEnv
 from flatland.utils.rendertools import *
 
@@ -62,6 +63,11 @@ specs = [[(0, 0), (0, 0), (0, 0), (0, 0), (7, 0), (0, 0)],
          [(0, 0), (0, 0), (0, 0), (2, 180), (2, 90), (7, 90)],
          [(0, 0), (0, 0), (0, 0), (7, 180), (0, 0), (0, 0)]]
 
+# CURVED RAIL + DEAD-ENDS TEST
+specs = [[(0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0)],
+         [(7, 270), (1, 90), (1, 90), (8, 90), (0, 0), (0, 0)],
+         [(0, 0),   (7, 270),(1, 90), (8, 180), (0, 00), (0, 0)]]
+
 env = RailEnv(width=6,
               height=4,
               rail_generator=rail_from_manual_specifications_generator(specs),
diff --git a/flatland/core/transition_map.py b/flatland/core/transition_map.py
index fd42787a68c9ee007368fb3ced1ff5162c9926eb..2e34b4236943b36b167268828e3010e727108760 100644
--- a/flatland/core/transition_map.py
+++ b/flatland/core/transition_map.py
@@ -314,14 +314,13 @@ class GridTransitionMap(TransitionMap):
         Checks that:
         - surrounding cells have inbound transitions for all the
             outbound transitions of this cell.
-        
+
         These are NOT checked - see transition.is_valid:
         - all transitions have the mirror transitions (N->E <=> W->S)
         - Reverse transitions (N -> S) only exist for a dead-end
         - a cell contains either no dead-ends or exactly one
 
         Returns: True (valid) or False (invalid)
-        
         """
         cell_transition = self.grid[tuple(rcPos)]
 
diff --git a/flatland/core/transitions.py b/flatland/core/transitions.py
index a862c6a81dbd71a0aa95f2305158f940e964c0ad..622d900598bba6bbc48750d6bb48923975af9b5e 100644
--- a/flatland/core/transitions.py
+++ b/flatland/core/transitions.py
@@ -551,7 +551,7 @@ class RailEnvTransitions(Grid4Transitions):
         super(RailEnvTransitions, self).__init__(
             transitions=self.transition_list
         )
-        
+
         # These bits represent all the possible dead ends
         self.maskDeadEnds = 0b0010000110000100
 
@@ -569,10 +569,10 @@ class RailEnvTransitions(Grid4Transitions):
 
     def print(self, cell_transition):
         print("  NESW")
-        print("N", format(cell_transition >> (3*4) & 0xF, '04b'))
-        print("E", format(cell_transition >> (2*4) & 0xF, '04b'))
-        print("S", format(cell_transition >> (1*4) & 0xF, '04b'))
-        print("W", format(cell_transition >> (0*4) & 0xF, '04b'))
+        print("N", format(cell_transition >> (3 * 4) & 0xF, '04b'))
+        print("E", format(cell_transition >> (2 * 4) & 0xF, '04b'))
+        print("S", format(cell_transition >> (1 * 4) & 0xF, '04b'))
+        print("W", format(cell_transition >> (0 * 4) & 0xF, '04b'))
 
     def repr(self, cell_transition, version=0):
         """
@@ -585,25 +585,23 @@ class RailEnvTransitions(Grid4Transitions):
         sbinTrans = format(cell_transition, "#018b")[2:]
         if version == 0:
             sRepr = " ".join([
-                "{}:{}".format(sDir, sbinTrans[i:i+4])
+                "{}:{}".format(sDir, sbinTrans[i:(i + 4)])
                 for i, sDir in
                     zip(
                         range(0, len(sbinTrans), 4),
-                        self.lsDirs  # NESW
-                    )])
+                        self.lsDirs)])  # NESW
             return sRepr
 
         if version == 1:
             lsRepr = []
             for iDirIn in range(0, 4):
-                sDirTrans = sbinTrans[iDirIn*4:iDirIn*4+4]
+                sDirTrans = sbinTrans[(iDirIn * 4):(iDirIn * 4 + 4)]
                 if sDirTrans == "0000":
                     continue
                 sDirsOut = [
                     self.lsDirs[iDirOut]
                     for iDirOut in range(0, 4)
-                    if sDirTrans[iDirOut] == "1"
-                    ]
+                    if sDirTrans[iDirOut] == "1"]
                 lsRepr.append(self.lsDirs[iDirIn] + ":" + "".join(sDirsOut))
 
             return ", ".join(lsRepr)
diff --git a/flatland/envs/env_utils.py b/flatland/envs/env_utils.py
index 79f5eb0d23b8b8a50bc1afd9ba1d34bba76c4ffd..382bdd7d9c483aa491f9c11b06b62b13b45490a5 100644
--- a/flatland/envs/env_utils.py
+++ b/flatland/envs/env_utils.py
@@ -271,4 +271,3 @@ def connect_rail(rail_trans, rail_array, start, end):
 
 def distance_on_rail(pos1, pos2):
     return abs(pos1[0] - pos2[0]) + abs(pos1[1] - pos2[1])
-
diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py
index baba6baac4a4efe76844147d7c49435a88ad79af..021c63f31a958790c04e17f49c7a1da57777cb9c 100644
--- a/flatland/envs/generators.py
+++ b/flatland/envs/generators.py
@@ -250,7 +250,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8):
 
         transitions_templates_ = []
         transition_probabilities = []
-        for i in range(len(t_utils.transitions)-4):  # don't include dead-ends
+        for i in range(len(t_utils.transitions) - 4):  # don't include dead-ends
             all_transitions = 0
             for dir_ in range(4):
                 trans = t_utils.get_transitions(t_utils.transitions[i], dir_)
@@ -475,4 +475,3 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8):
         return return_rail
 
     return generator
-
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 9a48915bce8386f0ffe0237c6bf9512b0609d82a..cbb11c2db70dece00199bfae6bb051504f64a9f2 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -268,7 +268,7 @@ class RailEnv(Environment):
                         break
                     else:
                         self.agents_direction[i] = direction
-                
+
         # Reset the state of the observation builder with the new environment
         self.obs_builder.reset()
 
@@ -314,60 +314,33 @@ class RailEnv(Environment):
                 # compute number of possible transitions in the current
                 # cell used to check for invalid actions
 
-                nbits = 0
-                tmp = self.rail.get_transitions((pos[0], pos[1]))
                 possible_transitions = self.rail.get_transitions((pos[0], pos[1], direction))
-                # print(np.sum(self.rail.get_transitions((pos[0], pos[1],direction))),
-                # self.rail.get_transitions((pos[0], pos[1],direction)),
-                # self.rail.get_transitions((pos[0], pos[1])),
-                # (pos[0], pos[1],direction))
+                num_transitions = np.count_nonzero(possible_transitions)
 
-                while tmp > 0:
-                    nbits += (tmp & 1)
-                    tmp = tmp >> 1
                 movement = direction
                 # print(nbits,np.sum(possible_transitions))
                 if action == 1:
                     movement = direction - 1
-                    if nbits <= 2 or np.sum(possible_transitions) <= 1:
+                    if num_transitions <= 1:
                         transition_isValid = False
 
                 elif action == 3:
                     movement = direction + 1
-                    if nbits <= 2 or np.sum(possible_transitions) <= 1:
+                    if num_transitions <= 1:
                         transition_isValid = False
+
                 if movement < 0:
                     movement += 4
                 if movement >= 4:
                     movement -= 4
 
-                is_deadend = False
                 if action == 2:
-                    if nbits == 1:
-                        # dead-end;  assuming the rail network is consistent,
-                        # this should match the direction the agent has come
-                        # from. But it's better to check in any case.
-                        reverse_direction = 0
-                        if direction == 0:
-                            reverse_direction = 2
-                        elif direction == 1:
-                            reverse_direction = 3
-                        elif direction == 2:
-                            reverse_direction = 0
-                        elif direction == 3:
-                            reverse_direction = 1
-
-                        valid_transition = self.rail.get_transition(
-                            (pos[0], pos[1], direction),
-                            reverse_direction)
-                        if valid_transition:
-                            direction = reverse_direction
-                            movement = reverse_direction
-                            is_deadend = True
-
-                    if np.sum(possible_transitions) == 1:
-                        # Take only available transition
+                    if num_transitions == 1:
+                        # - dead-end, straight line or curved line;
+                        # movement will be the only valid transition
+                        # - take only available transition
                         movement = np.argmax(possible_transitions)
+                        transition_isValid = True
 
                 new_position = self._new_position(pos, movement)
                 # Is it a legal move?  1) transition allows the movement in the
@@ -388,7 +361,7 @@ class RailEnv(Environment):
                 if transition_isValid is None:
                     transition_isValid = self.rail.get_transition(
                         (pos[0], pos[1], direction),
-                        movement) or is_deadend
+                        movement)
 
                 cell_isFree = True
                 for j in range(self.number_of_agents):
diff --git a/flatland/utils/editor.py b/flatland/utils/editor.py
index 21c1e64b8ce23b70d21dae032eb75f410dbddcca..543b793c5f7a0e0a6a9e07c46902735ec21c7158 100644
--- a/flatland/utils/editor.py
+++ b/flatland/utils/editor.py
@@ -31,10 +31,10 @@ class EditorMVC(object):
 
         if env is None:
             env = RailEnv(width=10,
-              height=10,
-              rail_generator=random_rail_generator(cell_type_relative_proportion=[1, 1] + [0.5] * 6),
-              number_of_agents=0,
-              obs_builder_object=TreeObsForRailEnv(max_depth=2))
+                          height=10,
+                          rail_generator=random_rail_generator(cell_type_relative_proportion=[1, 1] + [0.5] * 6),
+                          number_of_agents=0,
+                          obs_builder_object=TreeObsForRailEnv(max_depth=2))
 
         env.reset()
 
@@ -113,20 +113,19 @@ class View(object):
             wButton = ipywidgets.Button(description=dButton["name"])
             wButton.on_click(dButton["method"])
             self.lwButtons.append(wButton)
-            
+
         self.wVbox_controls = VBox([
             self.wFilename,  # self.wDrawMode,
             *self.lwButtons,
             self.wSize,
             self.wDebug, self.wDebug_move,
-            self.wProg_steps
-            ])
-        
+            self.wProg_steps])
+
         self.wMain = HBox([self.wImage, self.wVbox_controls])
 
     def drawStroke(self):
         pass
-    
+
     def new_env(self):
         self.oRT = rt.RenderTool(self.editor.env)
 
@@ -139,11 +138,11 @@ class View(object):
             img = self.oRT.getImage()
             plt.clf()
             plt.close()
-        
+
             self.wImage.data = img
             self.writableData = np.copy(self.wImage.data)
             return img
-    
+
     def redisplayImage(self):
         if self.writableData is not None:
             # This updates the image in the browser to be the new edited version
@@ -152,7 +151,7 @@ class View(object):
     def drag_path_element(self, x, y):
         # Draw a black square on the in-memory copy of the image
         if x > 10 and x < self.yxSize[1] and y > 10 and y < self.yxSize[0]:
-            self.writableData[y-2:y+2, x-2:x+2, :] = 0
+            self.writableData[(y - 2):(y + 2), (x - 2):(x + 2), :] = 0
 
     def xy_to_rc(self, x, y):
         rcCell = ((array([y, x]) - self.yxBase) / self.nPixCell).astype(int)
@@ -227,7 +226,7 @@ class Controller(object):
         # The intention was to avoid too many redraws.
         if event["buttons"] > 0:
             qEvents.append((time.time(), x, y))
-        
+
         # Process the events in our queue:
         # Draw a black square to indicate a trail
         # TODO: infer a vector of moves between these squares to avoid gaps
@@ -238,13 +237,13 @@ class Controller(object):
             if tNow - qEvents[0][0] > 0.1:   # wait before trying to draw
                 # height, width = wid.data.shape[:2]
                 # writableData = np.copy(self.wid_img.data)  # writable copy of image - wid_img.data is somehow readonly
-                
+
                 # with self.wid_img.hold_sync():
-                
+
                 while len(qEvents) > 0:
                     t, x, y = qEvents.popleft()  # get events from our queue
                     self.view.drag_path_element(x, y)
-                    
+
                     # Translate and scale from x,y to integer row,col (note order change)
                     # rcCell = ((array([y, x]) - self.yxBase) / self.nPixCell).astype(int)
                     rcCell = self.view.xy_to_rc(x, y)
@@ -266,13 +265,13 @@ class Controller(object):
     def refresh(self, event):
         self.debug("refresh")
         self.view.redraw()
-    
+
     def clear(self, event):
         self.model.clear()
 
     def regenerate(self, event):
         self.model.regenerate()
-    
+
     def setRegenSize(self, event):
         self.model.setRegenSize(event["new"])
 
@@ -342,14 +341,14 @@ class EditorModel(object):
         """Mouse motion event handler for drawing.
         """
         lrcStroke = self.lrcStroke
-                        
+
         # Store the row,col location of the click, if we have entered a new cell
         if len(lrcStroke) > 0:
             rcLast = lrcStroke[-1]
             if not np.array_equal(rcLast, rcCell):  # only save at transition
                 lrcStroke.append(rcCell)
                 self.debug("lrcStroke ", len(lrcStroke), rcCell)
-                
+
         else:
             # This is the first cell in a mouse stroke
             lrcStroke.append(rcCell)
@@ -427,14 +426,16 @@ class EditorModel(object):
             # Set the transition
             # If this transition spans 3 cells, it is not a deadend, so remove any deadends.
             # The user will need to resolve any conflicts.
-            self.env.rail.set_transition((*rcMiddle, liTrans[0]), liTrans[1], bAddRemove,
-                remove_deadends=not bDeadend)
+            self.env.rail.set_transition((*rcMiddle, liTrans[0]),
+                                         liTrans[1],
+                                         bAddRemove,
+                                         remove_deadends=not bDeadend)
 
             # Also set the reverse transition
             # use the reversed outbound transition for inbound
             # and the reversed inbound transition for outbound
             self.env.rail.set_transition((*rcMiddle, mirror(liTrans[1])),
-                mirror(liTrans[0]), bAddRemove, remove_deadends=not bDeadend)
+                                         mirror(liTrans[0]), bAddRemove, remove_deadends=not bDeadend)
 
             # bValid = self.env.rail.is_cell_valid(rcMiddle)
             # if not bValid:
@@ -457,7 +458,7 @@ class EditorModel(object):
 
         # get the row, col delta between the 2 cells, eg [-1,0] = North
         rc1Trans = np.diff(rc2Cells, axis=0)
-        
+
         # get the direction index for the transition
         liTrans = []
         for rcTrans in rc1Trans:
@@ -491,7 +492,7 @@ class EditorModel(object):
         self.env.agents_handles = []
         self.env.agents_target = []
         self.player = None
-        
+
         self.redraw()
 
     def setFilename(self, filename):
@@ -507,7 +508,7 @@ class EditorModel(object):
             self.redraw()
         else:
             self.log("File does not exist:", self.env_filename, " Working directory: ", os.getcwd())
-    
+
     def save(self):
         self.log("save to ", self.env_filename, " working dir: ", os.getcwd())
         self.env.rail.save_transition_map(self.env_filename)
@@ -515,17 +516,17 @@ class EditorModel(object):
     def regenerate(self):
         self.log("Regenerate size", self.regen_size)
         self.env = RailEnv(width=self.regen_size,
-              height=self.regen_size,
-              rail_generator=random_rail_generator(cell_type_relative_proportion=[1, 1] + [0.5] * 6),
-              number_of_agents=self.env.number_of_agents,
-              obs_builder_object=TreeObsForRailEnv(max_depth=2))
+                           height=self.regen_size,
+                           rail_generator=random_rail_generator(cell_type_relative_proportion=[1, 1] + [0.5] * 6),
+                           number_of_agents=self.env.number_of_agents,
+                           obs_builder_object=TreeObsForRailEnv(max_depth=2))
         self.env.reset(regen_rail=True)
         self.fix_env()
         self.set_env(self.env)
         self.player = Player(self.env)
         self.view.new_env()
         self.redraw()
-        
+
     def setRegenSize(self, size):
         self.regen_size = size
 
@@ -579,5 +580,9 @@ class EditorModel(object):
     def debug_cell(self, rcCell):
         binTrans = self.env.rail.get_transitions(rcCell)
         sbinTrans = format(binTrans, "#018b")[2:]
-        self.debug("cell ", rcCell, "Transitions: ", binTrans, sbinTrans,
-            [sbinTrans[i:i+4] for i in range(0, len(sbinTrans), 4)])
+        self.debug("cell ",
+                   rcCell,
+                   "Transitions: ",
+                   binTrans,
+                   sbinTrans,
+                   [sbinTrans[i:(i + 4)] for i in range(0, len(sbinTrans), 4)])
diff --git a/flatland/utils/render_qt.py b/flatland/utils/render_qt.py
index 05c436dc9f8723abdf2c013b9fad5181e695b58f..0804c9dd27b191b54f796b180165e1d73180ec16 100644
--- a/flatland/utils/render_qt.py
+++ b/flatland/utils/render_qt.py
@@ -61,8 +61,8 @@ class QTGL(GraphicsLayer):
                 if lastx is not None:
                     # print("line", lastx, lasty, x, y)
                     self.qtr.drawLine(
-                        lastx*self.cell_pixels, -lasty*self.cell_pixels,
-                        x*self.cell_pixels, -y*self.cell_pixels)
+                        lastx * self.cell_pixels, -lasty * self.cell_pixels,
+                        x * self.cell_pixels, -y * self.cell_pixels)
                 lastx = x
                 lasty = y
         else:
diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py
index 09ac48997bcb1c420be146ee011042c45b6b597f..e84ff46bb7fa7590bc043a3bc9e3c109670a8879 100644
--- a/flatland/utils/rendertools.py
+++ b/flatland/utils/rendertools.py
@@ -145,7 +145,7 @@ class RenderTool(object):
             self.plotTrans(visit.rc, gTransRCAg, depth=str(visit.iDepth), color=color)
 
     def plotAgents(self, targets=True):
-        cmap = self.gl.get_cmap('hsv', lut=self.env.number_of_agents+1)
+        cmap = self.gl.get_cmap('hsv', lut=self.env.number_of_agents + 1)
         for iAgent in range(self.env.number_of_agents):
             oColor = cmap(iAgent)
 
@@ -208,16 +208,16 @@ class RenderTool(object):
         xyDir = np.matmul(rcDir, rt.grc2xy)          # agent direction in xy
 
         xyPos = np.matmul(rcPos - rcDir / 2, rt.grc2xy) + rt.xyHalf
-        #print("Agent:", rcPos, iDir, rcDir, xyDir, xyPos)
+        # print("Agent:", rcPos, iDir, rcDir, xyDir, xyPos)
         self.gl.scatter(*xyPos, color=color, marker="o", s=100)            # agent location
 
-        xyDirLine = array([xyPos, xyPos + xyDir/2]).T  # line for agent orient.
+        xyDirLine = array([xyPos, xyPos + xyDir / 2]).T  # line for agent orient.
         self.gl.plot(*xyDirLine, color=color, lw=5, ms=0, alpha=0.6)
 
         if target is not None:
             rcTarget = array(target)
             xyTarget = np.matmul(rcTarget, rt.grc2xy) + rt.xyHalf
-            self._draw_square(xyTarget, 1/3, color)
+            self._draw_square(xyTarget, 1 / 3, color)
 
     def plotTrans(self, rcPos, gTransRCAg, color="r", depth=None):
         """
@@ -587,7 +587,7 @@ class RenderTool(object):
 
         # Draw each agent + its orientation + its target
         if agents:
-            cmap = self.gl.get_cmap('hsv', lut=env.number_of_agents+1)
+            cmap = self.gl.get_cmap('hsv', lut=env.number_of_agents + 1)
             self.plotAgents(targets=True)
 
         if False:
@@ -659,4 +659,3 @@ class RenderTool(object):
 
     def getImage(self):
         return self.gl.getImage()
-