diff --git a/examples/training_navigation.py b/examples/training_navigation.py
index 554782e3b03803e450a4085098f96bfa4192d84f..b1032511ec00cbf23a4cfe7b8bca4bca370f5180 100644
--- a/examples/training_navigation.py
+++ b/examples/training_navigation.py
@@ -40,7 +40,7 @@ scores = []
 dones_list = []
 action_prob = [0]*4
 agent = Agent(state_size, action_size, "FC", 0)
-#agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint8000.pth'))
+agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint9900.pth'))
 def max_lt(seq, val):
     """
     Return greatest item in seq for which item < val applies.
@@ -70,11 +70,11 @@ for trials in range(1, n_trials + 1):
     # Run episode
     for step in range(50):
         #if trials > 114:
-        #env_renderer.renderEnv(show=True)
+        env_renderer.renderEnv(show=True)
         #print(step)
         # Action
         for a in range(env.number_of_agents):
-            action = agent.act(np.array(obs[a]), eps=eps)
+            action = agent.act(np.array(obs[a]), eps=0)
             action_prob[action] += 1
             action_dict.update({a: action})
 
diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py
index 86485ec2c068ea410d5d27997f6f037d3aab6c23..8737862b60d9c330cb95e7679b3e39c2ad897da6 100644
--- a/flatland/core/env_observation_builder.py
+++ b/flatland/core/env_observation_builder.py
@@ -114,7 +114,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                     nodes_queue.append(n)
 
                 if len(valid_neighbors) > 0:
-                    max_distance = max(max_distance, node[3]+1)
+                    max_distance = max(max_distance, node[3] + 1)
 
         return max_distance
 
@@ -129,7 +129,7 @@ class TreeObsForRailEnv(ObservationBuilder):
         if enforce_target_direction >= 0:
             # The agent must land into the current cell with orientation `enforce_target_direction'.
             # This is only possible if the agent has arrived from the cell in the opposite direction!
-            possible_directions = [(enforce_target_direction+2) % 4]
+            possible_directions = [(enforce_target_direction + 2) % 4]
 
         for neigh_direction in possible_directions:
             new_cell = self._new_position(position, neigh_direction)
@@ -137,7 +137,7 @@ class TreeObsForRailEnv(ObservationBuilder):
             if new_cell[0] >= 0 and new_cell[0] < self.env.height and \
                new_cell[1] >= 0 and new_cell[1] < self.env.width:
 
-                desired_movement_from_new_cell = (neigh_direction+2) % 4
+                desired_movement_from_new_cell = (neigh_direction + 2) % 4
 
                 """
                 # Is the next cell a dead-end?
@@ -166,7 +166,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                             movement = (desired_movement_from_new_cell+2) % 4
                         """
                         new_distance = min(self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation],
-                                           current_distance+1)
+                                           current_distance + 1)
                         neighbors.append((new_cell[0], new_cell[1], agent_orientation, new_distance))
                         self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation] = new_distance
 
@@ -177,11 +177,11 @@ class TreeObsForRailEnv(ObservationBuilder):
         Utility function that converts a compass movement over a 2D grid to new positions (r, c).
         """
         if movement == 0:    # NORTH
-            return (position[0]-1, position[1])
+            return (position[0] - 1, position[1])
         elif movement == 1:  # EAST
             return (position[0], position[1] + 1)
         elif movement == 2:  # SOUTH
-            return (position[0]+1, position[1])
+            return (position[0] + 1, position[1])
         elif movement == 3:  # WEST
             return (position[0], position[1] - 1)
 
@@ -241,7 +241,7 @@ class TreeObsForRailEnv(ObservationBuilder):
 
         # Start from the current orientation, and see which transitions are available;
         # organize them as [left, forward, right, back], relative to the current orientation
-        for branch_direction in [(orientation+4+i) % 4 for i in range(-1, 3)]:
+        for branch_direction in [(orientation + 4 + i) % 4 for i in range(-1, 3)]:
             if self.env.rail.get_transition((position[0], position[1], orientation), branch_direction):
                 new_cell = self._new_position(position, branch_direction)
 
@@ -253,7 +253,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                 for i in range(self.max_depth):
                     num_cells_to_fill_in += pow4
                     pow4 *= 4
-                observation = observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf]*num_cells_to_fill_in
+                observation = observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf] * num_cells_to_fill_in
 
         return observation
 
@@ -262,7 +262,7 @@ class TreeObsForRailEnv(ObservationBuilder):
         Utility function to compute tree-based observations.
         """
         # [Recursive branch opened]
-        if depth >= self.max_depth+1:
+        if depth >= self.max_depth + 1:
             return []
 
         # Continue along direction until next switch or
@@ -356,7 +356,7 @@ class TreeObsForRailEnv(ObservationBuilder):
             observation = [0,
                            1 if other_target_encountered else 0,
                            1 if other_agent_encountered else 0,
-                           root_observation[3]+num_steps,
+                           root_observation[3] + num_steps,
                            0]
 
         elif last_isTerminal:
@@ -369,7 +369,7 @@ class TreeObsForRailEnv(ObservationBuilder):
             observation = [0,
                            1 if other_target_encountered else 0,
                            1 if other_agent_encountered else 0,
-                           root_observation[3]+num_steps,
+                           root_observation[3] + num_steps,
                            self.distance_map[handle, position[0], position[1], direction]]
 
         # #############################
@@ -379,18 +379,18 @@ class TreeObsForRailEnv(ObservationBuilder):
 
         # Start from the current orientation, and see which transitions are available;
         # organize them as [left, forward, right, back], relative to the current orientation
-        for branch_direction in [(direction+4+i) % 4 for i in range(-1, 3)]:
+        for branch_direction in [(direction + 4 + i) % 4 for i in range(-1, 3)]:
             if last_isDeadEnd and self.env.rail.get_transition((position[0], position[1], direction),
-                                                               (branch_direction+2) % 4):
+                                                               (branch_direction + 2) % 4):
                 # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes
                 # it back
-                new_cell = self._new_position(position, (branch_direction+2) % 4)
+                new_cell = self._new_position(position, (branch_direction + 2) % 4)
 
                 branch_observation = self._explore_branch(handle,
                                                           new_cell,
-                                                          (branch_direction+2) % 4,
+                                                          (branch_direction + 2) % 4,
                                                           new_root_observation,
-                                                          depth+1)
+                                                          depth + 1)
                 observation = observation + branch_observation
 
             elif last_isSwitch and self.env.rail.get_transition((position[0], position[1], direction),
@@ -401,16 +401,16 @@ class TreeObsForRailEnv(ObservationBuilder):
                                                           new_cell,
                                                           branch_direction,
                                                           new_root_observation,
-                                                          depth+1)
+                                                          depth + 1)
                 observation = observation + branch_observation
 
             else:
                 num_cells_to_fill_in = 0
                 pow4 = 1
-                for i in range(self.max_depth-depth):
+                for i in range(self.max_depth - depth):
                     num_cells_to_fill_in += pow4
                     pow4 *= 4
-                observation = observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf]*num_cells_to_fill_in
+                observation = observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf] * num_cells_to_fill_in
 
         return observation
 
@@ -422,7 +422,7 @@ class TreeObsForRailEnv(ObservationBuilder):
             return
 
         depth = 0
-        tmp = len(tree)/num_features_per_node-1
+        tmp = len(tree) / num_features_per_node - 1
         pow4 = 4
         while tmp > 0:
             tmp -= pow4
@@ -431,15 +431,15 @@ class TreeObsForRailEnv(ObservationBuilder):
 
         prompt_ = ['L:', 'F:', 'R:', 'B:']
 
-        print("  "*current_depth + prompt, tree[0:num_features_per_node])
-        child_size = (len(tree)-num_features_per_node)//4
+        print("  " * current_depth + prompt, tree[0:num_features_per_node])
+        child_size = (len(tree) - num_features_per_node) // 4
         for children in range(4):
-            child_tree = tree[(num_features_per_node+children*child_size):
-                              (num_features_per_node+(children+1)*child_size)]
+            child_tree = tree[(num_features_per_node + children * child_size):
+                              (num_features_per_node + (children + 1) * child_size)]
             self.util_print_obs_subtree(child_tree,
                                         num_features_per_node,
                                         prompt=prompt_[children],
-                                        current_depth=current_depth+1)
+                                        current_depth=current_depth + 1)
 
 
 class GlobalObsForRailEnv(ObservationBuilder):
diff --git a/flatland/core/transitions.py b/flatland/core/transitions.py
index eb4cb8e394c2effd6e5ba1bfcedf381281ea9388..a8cb8d6f49157bafbb65551b53c6612c45565c88 100644
--- a/flatland/core/transitions.py
+++ b/flatland/core/transitions.py
@@ -180,7 +180,7 @@ class Grid4Transitions(Transitions):
             List of the validity of transitions in the cell.
 
         """
-        bits = (cell_transition >> ((3-orientation)*4))
+        bits = (cell_transition >> ((3 - orientation) * 4))
         return ((bits >> 3) & 1, (bits >> 2) & 1, (bits >> 1) & 1, (bits) & 1)
 
     def set_transitions(self, cell_transition, orientation, new_transitions):
@@ -208,7 +208,7 @@ class Grid4Transitions(Transitions):
             `orientation'.
 
         """
-        mask = (1 << ((4-orientation)*4)) - (1 << ((3-orientation)*4))
+        mask = (1 << ((4 - orientation) * 4)) - (1 << ((3 - orientation) * 4))
         negmask = ~mask
 
         new_transitions = \
@@ -217,9 +217,7 @@ class Grid4Transitions(Transitions):
             (new_transitions[2] & 1) << 1 | \
             (new_transitions[3] & 1)
 
-        cell_transition = \
-            (cell_transition & negmask) | \
-            (new_transitions << ((3-orientation)*4))
+        cell_transition = (cell_transition & negmask) | (new_transitions << ((3 - orientation) * 4))
 
         return cell_transition
 
@@ -245,8 +243,7 @@ class Grid4Transitions(Transitions):
             Validity of the requested transition: 0/1 allowed/not allowed.
 
         """
-        return ((cell_transition >> ((4-1-orientation) * 4)) >>
-                (4-1-direction)) & 1
+        return ((cell_transition >> ((4 - 1 - orientation) * 4)) >> (4 - 1 - direction)) & 1
 
     def set_transition(self, cell_transition, orientation, direction,
                        new_transition):
@@ -276,12 +273,9 @@ class Grid4Transitions(Transitions):
 
         """
         if new_transition:
-            cell_transition |= (1 << ((4-1-orientation) * 4 +
-                                (4-1-direction)))
+            cell_transition |= (1 << ((4 - 1 - orientation) * 4 + (4 - 1 - direction)))
         else:
-            cell_transition &= \
-                ~(1 << ((4-1-orientation) * 4 +
-                        (4-1-direction)))
+            cell_transition &= ~(1 << ((4 - 1 - orientation) * 4 + (4 - 1 - direction)))
 
         return cell_transition
 
@@ -310,13 +304,11 @@ class Grid4Transitions(Transitions):
         rotation = rotation // 90
         for i in range(4):
             block_tuple = self.get_transitions(value, i)
-            block_tuple = block_tuple[(
-                4-rotation):] + block_tuple[:(4-rotation)]
+            block_tuple = block_tuple[(4 - rotation):] + block_tuple[:(4 - rotation)]
             value = self.set_transitions(value, i, block_tuple)
 
         # Rotate the 4-bits blocks
-        value = ((value & (2**(rotation*4)-1)) <<
-                 ((4-rotation)*4)) | (value >> (rotation*4))
+        value = ((value & (2**(rotation * 4) - 1)) << ((4 - rotation) * 4)) | (value >> (rotation * 4))
 
         cell_transition = value
         return cell_transition
@@ -355,7 +347,7 @@ class Grid8Transitions(Transitions):
             List of the validity of transitions in the cell.
 
         """
-        bits = (cell_transition >> ((7-orientation)*8))
+        bits = (cell_transition >> ((7 - orientation) * 8))
         cell_transition = (
             (bits >> 7) & 1,
             (bits >> 6) & 1,
@@ -389,7 +381,7 @@ class Grid8Transitions(Transitions):
             `orientation'.
 
         """
-        mask = (1 << ((8-orientation)*8)) - (1 << ((7-orientation)*8))
+        mask = (1 << ((8 - orientation) * 8)) - (1 << ((7 - orientation) * 8))
         negmask = ~mask
 
         new_transitions = \
@@ -402,8 +394,7 @@ class Grid8Transitions(Transitions):
             (new_transitions[6] & 1) << 1 | \
             (new_transitions[7] & 1)
 
-        cell_transition = (cell_transition & negmask) | (
-            new_transitions << ((7-orientation)*8))
+        cell_transition = (cell_transition & negmask) | (new_transitions << ((7 - orientation) * 8))
 
         return cell_transition
 
@@ -429,8 +420,7 @@ class Grid8Transitions(Transitions):
             Validity of the requested transition: 0/1 allowed/not allowed.
 
         """
-        return ((cell_transition >> ((8-1-orientation) * 8)) >>
-                (8-1-direction)) & 1
+        return ((cell_transition >> ((8 - 1 - orientation) * 8)) >> (8 - 1 - direction)) & 1
 
     def set_transition(self, cell_transition, orientation, direction,
                        new_transition):
@@ -460,11 +450,9 @@ class Grid8Transitions(Transitions):
 
         """
         if new_transition:
-            cell_transition |= (1 << ((8-1-orientation) * 8 +
-                                (8 - 1 - direction)))
+            cell_transition |= (1 << ((8 - 1 - orientation) * 8 + (8 - 1 - direction)))
         else:
-            cell_transition &= ~(1 << ((8-1-orientation) * 8 +
-                                 (8 - 1 - direction)))
+            cell_transition &= ~(1 << ((8 - 1 - orientation) * 8 + (8 - 1 - direction)))
 
         return cell_transition
 
@@ -500,8 +488,7 @@ class Grid8Transitions(Transitions):
             value = self.set_transitions(value, i, block_tuple)
 
         # Rotate the 8bits blocks
-        value = ((value & (2**(rotation*8)-1)) <<
-                 ((8-rotation)*8)) | (value >> (rotation*8))
+        value = ((value & (2**(rotation * 8) - 1)) << ((8 - rotation) * 8)) | (value >> (rotation * 8))
 
         cell_transition = value
 
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index cf206033d7aab65571b18de0c95d729f5d09c65c..3fadf66ccf185329c49d2ea105728e3faedf0ae5 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -45,8 +45,7 @@ def rail_from_manual_specifications_generator(rail_spec):
                 if cell[0] < 0 or cell[0] >= len(t_utils.transitions):
                     print("ERROR - invalid cell type=", cell[0])
                     return []
-                rail.set_transitions((r, c), t_utils.rotate_transition(
-                              t_utils.transitions[cell[0]], cell[1]))
+                rail.set_transitions((r, c), t_utils.rotate_transition(t_utils.transitions[cell[0]], cell[1]))
 
         return rail
 
@@ -110,7 +109,7 @@ def generate_rail_from_list_of_manual_specifications(list_of_specifications)
 """
 
 
-def random_rail_generator(cell_type_relative_proportion=[1.0]*8):
+def random_rail_generator(cell_type_relative_proportion=[1.0] * 8):
     """
     Dummy random level generator:
     - fill in cells at random in [width-2, height-2]
@@ -149,7 +148,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8):
 
         transitions_templates_ = []
         transition_probabilities = []
-        for i in range(len(t_utils.transitions)-1):  # don't include dead-ends
+        for i in range(len(t_utils.transitions) - 1):  # don't include dead-ends
             all_transitions = 0
             for dir_ in range(4):
                 trans = t_utils.get_transitions(t_utils.transitions[i], dir_)
@@ -159,7 +158,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8):
                                    (trans[3])
 
             template = [int(x) for x in bin(all_transitions)[2:]]
-            template = [0]*(4-len(template)) + template
+            template = [0] * (4 - len(template)) + template
 
             # add all rotations
             for rot in [0, 90, 180, 270]:
@@ -168,7 +167,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8):
                                                t_utils.transitions[i],
                                                rot)))
                 transition_probabilities.append(transition_probability[i])
-                template = [template[-1]]+template[:-1]
+                template = [template[-1]] + template[:-1]
 
         def get_matching_templates(template):
             ret = []
@@ -183,7 +182,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8):
                     ret.append((transitions_templates_[i][1], transition_probabilities[i]))
             return ret
 
-        MAX_INSERTIONS = (width-2) * (height-2) * 10
+        MAX_INSERTIONS = (width - 2) * (height - 2) * 10
         MAX_ATTEMPTS_FROM_SCRATCH = 10
 
         attempt_number = 0
@@ -191,10 +190,9 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8):
             cells_to_fill = []
             rail = []
             for r in range(height):
-                rail.append([None]*width)
-                if r > 0 and r < height-1:
-                    cells_to_fill = cells_to_fill \
-                                    + [(r, c) for c in range(1, width-1)]
+                rail.append([None] * width)
+                if r > 0 and r < height - 1:
+                    cells_to_fill = cells_to_fill + [(r, c) for c in range(1, width - 1)]
 
             num_insertions = 0
             while num_insertions < MAX_INSERTIONS and len(cells_to_fill) > 0:
@@ -212,14 +210,13 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8):
                            (1, 3, (0, 1)),
                            (2, 0, (1, 0)),
                            (3, 1, (0, -1))]:  # N, E, S, W
-                    neigh_trans = rail[row+el[2][0]][col+el[2][1]]
+                    neigh_trans = rail[row + el[2][0]][col + el[2][1]]
                     if neigh_trans is not None:
                         # select transition coming from facing direction el[1] and
                         # moving to direction el[1]
                         max_bit = 0
                         for k in range(4):
-                            max_bit |= \
-                             t_utils.get_transition(neigh_trans, k, el[1])
+                            max_bit |= t_utils.get_transition(neigh_trans, k, el[1])
 
                         if max_bit:
                             valid_template[el[0]] = 1
@@ -244,8 +241,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8):
                                 elif k == 3:
                                     rot = 90
 
-                                rail[row][col] = t_utils.rotate_transition(
-                                                  int('0010000000000000', 2), rot)
+                                rail[row][col] = t_utils.rotate_transition(int('0010000000000000', 2), rot)
                                 num_insertions += 1
 
                                 break
@@ -258,8 +254,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8):
                         for k in range(4):
                             tmp_template = valid_template[:]
                             tmp_template[k] = -1
-                            possible_cell_transitions = get_matching_templates(
-                                                         tmp_template)
+                            possible_cell_transitions = get_matching_templates(tmp_template)
                             if len(possible_cell_transitions) > len(besttrans):
                                 besttrans = possible_cell_transitions
                                 bestk = k
@@ -284,7 +279,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8):
                             rail[replace_row][replace_col] = None
 
                             possible_transitions, possible_probabilities = zip(*besttrans)
-                            possible_probabilities = [p/sum(possible_probabilities) for p in possible_probabilities]
+                            possible_probabilities = [p / sum(possible_probabilities) for p in possible_probabilities]
 
                             rail[row][col] = np.random.choice(possible_transitions,
                                                               p=possible_probabilities)
@@ -298,7 +293,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8):
 
                 else:
                     possible_transitions, possible_probabilities = zip(*possible_cell_transitions)
-                    possible_probabilities = [p/sum(possible_probabilities) for p in possible_probabilities]
+                    possible_probabilities = [p / sum(possible_probabilities) for p in possible_probabilities]
 
                     rail[row][col] = np.random.choice(possible_transitions,
                                                       p=possible_probabilities)
@@ -321,12 +316,10 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8):
             neigh_trans = rail[r][1]
             if neigh_trans is not None:
                 for k in range(4):
-                    neigh_trans_from_direction = (neigh_trans >> ((3-k) * 4)) \
-                                                 & (2**4-1)
+                    neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2**4 - 1)
                     max_bit = max_bit | (neigh_trans_from_direction & 1)
             if max_bit:
-                rail[r][0] = t_utils.rotate_transition(
-                               int('0010000000000000', 2), 270)
+                rail[r][0] = t_utils.rotate_transition(int('0010000000000000', 2), 270)
             else:
                 rail[r][0] = int('0000000000000000', 2)
 
@@ -335,8 +328,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8):
             neigh_trans = rail[r][-2]
             if neigh_trans is not None:
                 for k in range(4):
-                    neigh_trans_from_direction = (neigh_trans >> ((3-k) * 4)) \
-                                                 & (2**4-1)
+                    neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2**4 - 1)
                     max_bit = max_bit | (neigh_trans_from_direction & (1 << 2))
             if max_bit:
                 rail[r][-1] = t_utils.rotate_transition(int('0010000000000000', 2),
@@ -350,8 +342,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8):
             neigh_trans = rail[1][c]
             if neigh_trans is not None:
                 for k in range(4):
-                    neigh_trans_from_direction = (neigh_trans >> ((3-k) * 4)) \
-                                                  & (2**4-1)
+                    neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2**4 - 1)
                     max_bit = max_bit | (neigh_trans_from_direction & (1 << 3))
             if max_bit:
                 rail[0][c] = int('0010000000000000', 2)
@@ -363,12 +354,10 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8):
             neigh_trans = rail[-2][c]
             if neigh_trans is not None:
                 for k in range(4):
-                    neigh_trans_from_direction = (neigh_trans >> ((3-k) * 4)) \
-                                                 & (2**4-1)
+                    neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2**4 - 1)
                     max_bit = max_bit | (neigh_trans_from_direction & (1 << 1))
             if max_bit:
-                rail[-1][c] = t_utils.rotate_transition(
-                                int('0010000000000000', 2), 180)
+                rail[-1][c] = t_utils.rotate_transition(int('0010000000000000', 2), 180)
             else:
                 rail[-1][c] = int('0000000000000000', 2)
 
@@ -458,8 +447,8 @@ class RailEnv(Environment):
         self.obs_builder = obs_builder_object
         self.obs_builder._set_env(self)
 
-        self.actions = [0]*self.number_of_agents
-        self.rewards = [0]*self.number_of_agents
+        self.actions = [0] * self.number_of_agents
+        self.rewards = [0] * self.number_of_agents
         self.done = False
 
         self.dones = {"__all__": False}
@@ -507,14 +496,13 @@ class RailEnv(Environment):
 
             # agents_direction must be a direction for which a solution is
             # guaranteed.
-            self.agents_direction = [0]*self.number_of_agents
+            self.agents_direction = [0] * self.number_of_agents
             re_generate = False
             for i in range(self.number_of_agents):
                 valid_movements = []
                 for direction in range(4):
                     position = self.agents_position[i]
-                    moves = self.rail.get_transitions(
-                            (position[0], position[1], direction))
+                    moves = self.rail.get_transitions((position[0], position[1], direction))
                     for move_index in range(4):
                         if moves[move_index]:
                             valid_movements.append((direction, move_index))
@@ -608,8 +596,8 @@ class RailEnv(Environment):
                             reverse_direction = 1
 
                         valid_transition = self.rail.get_transition(
-                                            (pos[0], pos[1], direction),
-                                            reverse_direction)
+                            (pos[0], pos[1], direction),
+                            reverse_direction)
                         if valid_transition:
                             direction = reverse_direction
                             movement = reverse_direction
@@ -629,8 +617,8 @@ class RailEnv(Environment):
                     new_cell_isValid = False
 
                 transition_isValid = self.rail.get_transition(
-                     (pos[0], pos[1], direction),
-                     movement) or is_deadend
+                    (pos[0], pos[1], direction),
+                    movement) or is_deadend
 
                 cell_isFree = True
                 for j in range(self.number_of_agents):
@@ -663,20 +651,20 @@ class RailEnv(Environment):
 
         if num_agents_in_target_position == self.number_of_agents:
             self.dones["__all__"] = True
-            self.rewards_dict = [r+global_reward for r in self.rewards_dict]
+            self.rewards_dict = [r + global_reward for r in self.rewards_dict]
 
         # Reset the step actions (in case some agent doesn't 'register_action'
         # on the next step)
-        self.actions = [0]*self.number_of_agents
+        self.actions = [0] * self.number_of_agents
         return self._get_observations(), self.rewards_dict, self.dones, {}
 
     def _new_position(self, position, movement):
         if movement == 0:    # NORTH
-            return (position[0]-1, position[1])
+            return (position[0] - 1, position[1])
         elif movement == 1:  # EAST
             return (position[0], position[1] + 1)
         elif movement == 2:  # SOUTH
-            return (position[0]+1, position[1])
+            return (position[0] + 1, position[1])
         elif movement == 3:  # WEST
             return (position[0], position[1] - 1)
 
diff --git a/flatland/utils/graphics_layer.py b/flatland/utils/graphics_layer.py
index b199a3b2e4d71261a402254ccb60ec453e30469c..5ace85cf76514676e9fcac7d6a749051cd404226 100644
--- a/flatland/utils/graphics_layer.py
+++ b/flatland/utils/graphics_layer.py
@@ -18,16 +18,15 @@ class GraphicsLayer(object):
 
     def show(self, block=False):
         pass
-    
+
     def pause(self, seconds=0.00001):
         pass
 
     def clf(self):
         pass
-    
+
     def beginFrame(self):
         pass
-    
+
     def endFrame(self):
         pass
-
diff --git a/flatland/utils/graphics_qt.py b/flatland/utils/graphics_qt.py
index a4abb5789b67cd15a58c4eaa426bcd88ef297800..09dc1fa4d7ea50629a4e1f9f91bb79a549b62ac8 100644
--- a/flatland/utils/graphics_qt.py
+++ b/flatland/utils/graphics_qt.py
@@ -214,13 +214,12 @@ class QtRenderer(object):
     def takeSnapshot(self, sDir="./movie"):
         oWidget = self.window.mainWidget
         oPixmap = oWidget.grab()
-        
+
         if not os.path.isdir(sDir):
             os.mkdir(sDir)
-        
+
         nRunIn = 30
         if self.iFrame > nRunIn:
             sfImage = "%s/frame%05d.jpg" % (sDir, self.iFrame - nRunIn)
             oPixmap.save(sfImage, "jpg")
         self.iFrame += 1
-        
diff --git a/flatland/utils/render_qt.py b/flatland/utils/render_qt.py
index e2334d3e1a8a66c31139b848b1b9bbb2f0e0a590..34e198566fdd8df8cff7e1822934e1f04cf3945c 100644
--- a/flatland/utils/render_qt.py
+++ b/flatland/utils/render_qt.py
@@ -80,20 +80,20 @@ class QTGL(GraphicsLayer):
             self.qtr.drawCircle(x, y, r)
 
     def text(self, x, y, sText):
-        self.qtr.drawText(x*self.cell_pixels, -y*self.cell_pixels, sText)
-    
+        self.qtr.drawText(x * self.cell_pixels, -y * self.cell_pixels, sText)
+
     def prettify(self, *args, **kwargs):
         pass
 
     def prettify2(self, width, height, cell_size):
         pass
-    
+
     def show(self, block=False):
         pass
 
     def pause(self, seconds=0.00001):
         pass
-    
+
     def clf(self):
         pass
 
@@ -104,9 +104,7 @@ class QTGL(GraphicsLayer):
         self.qtr.beginFrame()
         self.qtr.push()
         self.qtr.fillRect(0, 0, self.widthPx, self.heightPx, *self.tColBg)
-    
+
     def endFrame(self):
         self.qtr.pop()
         self.qtr.endFrame()
-
-
diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py
index 985f2b6dff207aad75412387e2f604fb3a80a1b2..fd414a7cb320f21160c6ed0126cba18fa75f1108 100644
--- a/flatland/utils/rendertools.py
+++ b/flatland/utils/rendertools.py
@@ -24,11 +24,11 @@ class MPLGL(GraphicsLayer):
 
     def text(self, *args, **kwargs):
         plt.text(*args, **kwargs)
-    
+
     def prettify(self, *args, **kwargs):
         ax = plt.gca()
-        plt.xticks(range(int(ax.get_xlim()[1])+1))
-        plt.yticks(range(int(ax.get_ylim()[1])+1))
+        plt.xticks(range(int(ax.get_xlim()[1]) + 1))
+        plt.yticks(range(int(ax.get_ylim()[1]) + 1))
         plt.grid()
         plt.xlabel("Euclidean distance")
         plt.ylabel("Tree / Transition Depth")
@@ -41,28 +41,28 @@ class MPLGL(GraphicsLayer):
         gLabels = np.arange(0, height)
         plt.xticks(gTicks, gLabels)
 
-        gTicks = np.arange(-height * cell_size, 0) + cell_size/2
-        gLabels = np.arange(height-1, -1, -1)
+        gTicks = np.arange(-height * cell_size, 0) + cell_size / 2
+        gLabels = np.arange(height - 1, -1, -1)
         plt.yticks(gTicks, gLabels)
 
         plt.xlim([0, width * cell_size])
         plt.ylim([-height * cell_size, 0])
-    
+
     def show(self, block=False):
         plt.show(block=block)
 
     def pause(self, seconds=0.00001):
         plt.pause(seconds)
-    
+
     def clf(self):
         plt.clf()
-    
+
     def get_cmap(self, *args, **kwargs):
         return plt.get_cmap(*args, **kwargs)
 
     def beginFrame(self):
         pass
-    
+
     def endFrame(self):
         pass
 
@@ -85,7 +85,7 @@ class RenderTool(object):
     gCentres = xr.DataArray(gGrid,
                             dims=["xy", "p1", "p2"],
                             coords={"xy": ["x", "y"]}) + xyPixHalf
-    gTheta = np.linspace(0, np.pi/2, 10)
+    gTheta = np.linspace(0, np.pi / 2, 10)
     gArc = array([np.cos(gTheta), np.sin(gTheta)]).T  # from [1,0] to [0,1]
 
     def __init__(self, env, gl="MPL"):
@@ -196,6 +196,8 @@ class RenderTool(object):
 
         rcDir = rt.gTransRC[iDir]                    # agent direction in RC
         xyDir = np.matmul(rcDir, rt.grc2xy)          # agent direction in xy
+        xyDirLine = array([xyPos, xyPos + xyDir / 2]).T  # line for agent orient.
+        self.gl.plot(*xyDirLine, color=sColor, lw=5, ms=0, alpha=0.6)
 
         xyPos = np.matmul(rcPos - rcDir / 2, rt.grc2xy) + rt.xyHalf
         self.gl.scatter(*xyPos, color=color, size=40)            # agent location
@@ -224,7 +226,7 @@ class RenderTool(object):
 
         rt = self.__class__
         xyPos = np.matmul(rcPos, rt.grc2xy) + rt.xyHalf
-        gxyTrans = xyPos + np.matmul(gTransRCAg, rt.grc2xy/2.4)
+        gxyTrans = xyPos + np.matmul(gTransRCAg, rt.grc2xy / 2.4)
         self.gl.scatter(*gxyTrans.T, color=color, marker="o", s=50, alpha=0.2)
         if depth is not None:
             for x, y in gxyTrans:
@@ -264,7 +266,7 @@ class RenderTool(object):
                     # print("Trans:", gTransRC2)
                     visitNext = rt.Visit(tuple(visit.rc + gTransRC2),
                                          iTrans,
-                                         visit.iDepth+1,
+                                         visit.iDepth + 1,
                                          visit)
                     # print("node2: ", node2)
                     stack.append(visitNext)
@@ -303,7 +305,7 @@ class RenderTool(object):
             xLoc = rDist + visit.iDir / 4
 
             # point labelled with distance
-            self.gl.scatter(xLoc, visit.iDepth,  color="k", s=2)
+            self.gl.scatter(xLoc, visit.iDepth, color="k", s=2)
             # plt.text(xLoc, visit.iDepth, sDist, color="k", rotation=45)
             self.gl.text(xLoc, visit.iDepth, visit.rc, color="k", rotation=45)
 
@@ -321,8 +323,8 @@ class RenderTool(object):
 
                 # line from prev node
                 self.gl.plot([xLocPrev, xLoc],
-                         [visit.iDepth-1, visit.iDepth],
-                         color="k", alpha=0.5, lw=1)
+                             [visit.iDepth - 1, visit.iDepth],
+                             color="k", alpha=0.5, lw=1)
 
             if rDist < 0.1:
                 visitDest = visit
@@ -335,8 +337,8 @@ class RenderTool(object):
                 rDist = np.linalg.norm(array(visit.rc) - array(xyTarg))
                 xLoc = rDist + visit.iDir / 4
                 if xLocPrev is not None:
-                    self.gl.plot([xLoc, xLocPrev], [visit.iDepth, visit.iDepth+1],
-                             color="r", alpha=0.5, lw=2)
+                    self.gl.plot([xLoc, xLocPrev], [visit.iDepth, visit.iDepth + 1],
+                                 color="r", alpha=0.5, lw=2)
                 xLocPrev = xLoc
                 visit = visit.prev
             # prev = prev.prev
@@ -369,13 +371,12 @@ class RenderTool(object):
 
                     self.gl.plot(*xyLine.T, color="r", alpha=0.5, lw=1)
 
-                    xyMid = np.sum(xyLine * [[1/4], [3/4]], axis=0)
+                    xyMid = np.sum(xyLine * [[1 / 4], [3 / 4]], axis=0)
 
                     xyArrow = array([
-                        xyMid + [-dx-dy, +dx-dy],
+                        xyMid + [-dx - dy, +dx - dy],
                         xyMid,
-                        xyMid + [-dx+dy, -dx-dy]
-                        ])
+                        xyMid + [-dx + dy, -dx - dy]])
                     self.gl.plot(*xyArrow.T, color="r")
 
                 visit = visit.prev
@@ -420,13 +421,12 @@ class RenderTool(object):
                 self.gl.plot(*xyLine2.T, color=sColor)
 
                 if bArrow:
-                    xyMid = np.sum(xyLine2 * [[1/4], [3/4]], axis=0)
+                    xyMid = np.sum(xyLine2 * [[1 / 4], [3 / 4]], axis=0)
 
                     xyArrow = array([
-                        xyMid + [-dx-dy, +dx-dy],
+                        xyMid + [-dx - dy, +dx - dy],
                         xyMid,
-                        xyMid + [-dx+dy, -dx-dy]
-                        ])
+                        xyMid + [-dx + dy, -dx - dy]])
                     self.gl.plot(*xyArrow.T, color=sColor)
 
         else:
@@ -452,10 +452,9 @@ class RenderTool(object):
                 iArc = int(len(rt.gArc) / 2)
                 xyMid = xyCorner + rt.gArc[iArc] * dxy2
                 xyArrow = array([
-                    xyMid + [-dx-dy, +dx-dy],
+                    xyMid + [-dx - dy, +dx - dy],
                     xyMid,
-                    xyMid + [-dx+dy, -dx-dy]
-                    ])
+                    xyMid + [-dx + dy, -dx - dy]])
                 self.gl.plot(*xyArrow.T, color=sColor)
 
     def renderEnv(
@@ -489,14 +488,14 @@ class RenderTool(object):
 
         # Draw cells grid
         grid_color = [0.95, 0.95, 0.95]
-        for r in range(env.height+1):
-            self.gl.plot([0, (env.width+1)*cell_size],
-                     [-r*cell_size, -r*cell_size],
-                     color=grid_color)
-        for c in range(env.width+1):
-            self.gl.plot([c*cell_size, c*cell_size],
-                     [0, -(env.height+1)*cell_size],
-                     color=grid_color)
+        for r in range(env.height + 1):
+            self.gl.plot([0, (env.width + 1) * cell_size],
+                         [-r * cell_size, -r * cell_size],
+                         color=grid_color)
+        for c in range(env.width + 1):
+            self.gl.plot([c * cell_size, c * cell_size],
+                         [0, -(env.height + 1) * cell_size],
+                         color=grid_color)
 
         # Draw each cell independently
         for r in range(env.height):
@@ -504,16 +503,16 @@ class RenderTool(object):
 
                 # bounding box of the grid cell
                 x0 = cell_size * c       # left
-                x1 = cell_size * (c+1)   # right
+                x1 = cell_size * (c + 1)   # right
                 y0 = cell_size * -r      # top
-                y1 = cell_size * -(r+1)  # bottom
+                y1 = cell_size * -(r + 1)  # bottom
 
                 # centres of cell edges
                 coords = [
-                    ((x0+x1)/2.0, y0),  # N middle top
-                    (x1, (y0+y1)/2.0),  # E middle right
-                    ((x0+x1)/2.0, y1),  # S middle bottom
-                    (x0, (y0+y1)/2.0)   # W middle left
+                    ((x0 + x1) / 2.0, y0),  # N middle top
+                    (x1, (y0 + y1) / 2.0),  # E middle right
+                    ((x0 + x1) / 2.0, y1),  # S middle bottom
+                    (x0, (y0 + y1) / 2.0)   # W middle left
                 ]
 
                 # cell centre
@@ -585,18 +584,18 @@ class RenderTool(object):
         if False:
             for i in range(env.number_of_agents):
                 self._draw_square((
-                                env.agents_position[i][1] *
-                                cell_size+cell_size/2,
-                                -env.agents_position[i][0] *
-                                cell_size-cell_size/2),
-                                cell_size/8, cmap(i))
+                    env.agents_position[i][1] *
+                    cell_size + cell_size / 2,
+                    -env.agents_position[i][0] *
+                    cell_size - cell_size / 2),
+                    cell_size / 8, cmap(i))
             for i in range(env.number_of_agents):
                 self._draw_square((
-                                env.agents_target[i][1] *
-                                cell_size+cell_size/2,
-                                -env.agents_target[i][0] *
-                                cell_size-cell_size/2),
-                                cell_size/3, [c for c in cmap(i)])
+                    env.agents_target[i][1] *
+                    cell_size + cell_size / 2,
+                    -env.agents_target[i][0] *
+                    cell_size - cell_size / 2),
+                    cell_size / 3, [c for c in cmap(i)])
 
                 # orientation is a line connecting the center of the cell to the
                 # side of the square of the agent
@@ -606,8 +605,8 @@ class RenderTool(object):
                     (new_position[1] + env.agents_position[i][1]) / 2 * cell_size)
 
                 self.gl.plot(
-                    [env.agents_position[i][1] * cell_size+cell_size/2, new_position[1]+cell_size/2],
-                    [-env.agents_position[i][0] * cell_size-cell_size/2, -new_position[0]-cell_size/2],
+                    [env.agents_position[i][1] * cell_size + cell_size / 2, new_position[1] + cell_size / 2],
+                    [-env.agents_position[i][0] * cell_size - cell_size / 2, -new_position[0] - cell_size / 2],
                     color=cmap(i),
                     linewidth=2.0)
 
@@ -616,7 +615,7 @@ class RenderTool(object):
         if frames:
             self.gl.text(0.1, yText[2], "Frame:{:}".format(self.iFrame))
         self.iFrame += 1
-        
+
         if iEpisode is not None:
             self.gl.text(0.1, yText[1], "Ep:{}".format(iEpisode))
 
@@ -643,8 +642,8 @@ class RenderTool(object):
         return
 
     def _draw_square(self, center, size, color):
-        x0 = center[0]-size/2
-        x1 = center[0]+size/2
-        y0 = center[1]-size/2
-        y1 = center[1]+size/2
+        x0 = center[0] - size / 2
+        x1 = center[0] + size / 2
+        y0 = center[1] - size / 2
+        y1 = center[1] + size / 2
         self.gl.plot([x0, x1, x1, x0, x0], [y0, y0, y1, y1, y0], color=color)
diff --git a/tests/test_env_observation_builder.py b/tests/test_env_observation_builder.py
index db264c2975b75f32c2612aa19c0511076460ec6b..55c229e88e73c311ae8f8f4aeee01218cf1dd4cf 100644
--- a/tests/test_env_observation_builder.py
+++ b/tests/test_env_observation_builder.py
@@ -46,14 +46,14 @@ def test_global_obs():
         double_switch_south_horizontal_straight, 180)
 
     rail_map = np.array(
-               [[empty] * 3 + [dead_end_from_south] + [empty] * 6] +
-               [[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 +
-               [[dead_end_from_east] + [horizontal_straight] * 2 +
-                [double_switch_north_horizontal_straight] +
-                [horizontal_straight] * 2 + [double_switch_south_horizontal_straight] +
-                [horizontal_straight] * 2 + [dead_end_from_west]] +
-               [[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 +
-               [[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16)
+        [[empty] * 3 + [dead_end_from_south] + [empty] * 6] +
+        [[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 +
+        [[dead_end_from_east] + [horizontal_straight] * 2 +
+         [double_switch_north_horizontal_straight] +
+         [horizontal_straight] * 2 + [double_switch_south_horizontal_straight] +
+         [horizontal_straight] * 2 + [dead_end_from_west]] +
+        [[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 +
+        [[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16)
 
     rail = GridTransitionMap(width=rail_map.shape[1],
                              height=rail_map.shape[0], transitions=transitions)
diff --git a/tests/test_environments.py b/tests/test_environments.py
index ea8748b8aa4b50a1371a013be98f3b42d0d01228..210f1c76c8fd9978141a48189d5bcf2e31e68611 100644
--- a/tests/test_environments.py
+++ b/tests/test_environments.py
@@ -30,8 +30,7 @@ def test_rail_environment_single_agent():
     transitions = Grid4Transitions([])
     vertical_line = cells[1]
     south_symmetrical_switch = cells[6]
-    north_symmetrical_switch = transitions.rotate_transition(
-                                south_symmetrical_switch, 180)
+    north_symmetrical_switch = transitions.rotate_transition(south_symmetrical_switch, 180)
     # Simple turn not in the base transitions ?
     south_east_turn = int('0100000000000010', 2)
     south_west_turn = transitions.rotate_transition(south_east_turn, 90)
diff --git a/tests/test_rendertools.py b/tests/test_rendertools.py
index 528cc59bd8e66b9d383d093c7dd6363e9dc45f71..1f5c317965a3101c6232709ebb311959d5a566ed 100644
--- a/tests/test_rendertools.py
+++ b/tests/test_rendertools.py
@@ -11,7 +11,7 @@ import os
 import matplotlib.pyplot as plt
 
 import flatland.utils.rendertools as rt
-from flatland.core.env_observation_builder import GlobalObsForRailEnv, TreeObsForRailEnv
+from flatland.core.env_observation_builder import TreeObsForRailEnv
 
 
 def checkFrozenImage(sFileImage):
@@ -31,7 +31,7 @@ def checkFrozenImage(sFileImage):
             bytesFrozenImage = bytesImage
         else:
             assert(bytesFrozenImage.shape == bytesImage.shape)
-            assert((np.sum(np.square(bytesFrozenImage-bytesImage)) / bytesFrozenImage.size) < 1e-3)
+            assert((np.sum(np.square(bytesFrozenImage - bytesImage)) / bytesFrozenImage.size) < 1e-3)
 
 
 def test_render_env():
diff --git a/tests/test_transitions.py b/tests/test_transitions.py
index f68b836e58b87078ef8fcf799ca089df0d09d292..0f56e886071fd1d217be03b9a7e875c20d1a0e8a 100644
--- a/tests/test_transitions.py
+++ b/tests/test_transitions.py
@@ -15,65 +15,60 @@ def test_valid_railenv_transitions():
 
     for i in range(2):
         assert(rail_env_trans.get_transitions(
-                    int('1100110000110011', 2), i) == (1, 1, 0, 0))
+               int('1100110000110011', 2), i) == (1, 1, 0, 0))
         assert(rail_env_trans.get_transitions(
-                    int('1100110000110011', 2), 2+i) == (0, 0, 1, 1))
+               int('1100110000110011', 2), 2 + i) == (0, 0, 1, 1))
 
     no_transition_cell = int('0000000000000000', 2)
 
     for i in range(4):
         assert(rail_env_trans.get_transitions(
-                    no_transition_cell, i) == (0, 0, 0, 0))
+               no_transition_cell, i) == (0, 0, 0, 0))
 
     # Facing south, going south
-    north_south_transition = rail_env_trans.set_transitions(
-                    no_transition_cell, 2, (0, 0, 1, 0))
+    north_south_transition = rail_env_trans.set_transitions(no_transition_cell, 2, (0, 0, 1, 0))
     assert(rail_env_trans.set_transition(
-                    north_south_transition, 2, 2, 0) == no_transition_cell)
+           north_south_transition, 2, 2, 0) == no_transition_cell)
     assert(rail_env_trans.get_transition(
-                    north_south_transition, 2, 2))
+           north_south_transition, 2, 2))
 
     # Facing north, going east
     south_east_transition = \
-        rail_env_trans.set_transition(
-         no_transition_cell, 0, 1, 1)
+        rail_env_trans.set_transition(no_transition_cell, 0, 1, 1)
     assert(rail_env_trans.get_transition(
-            south_east_transition, 0, 1))
+           south_east_transition, 0, 1))
 
     # The opposite transitions are not feasible
     assert(not rail_env_trans.get_transition(
-            north_south_transition, 2, 0))
+           north_south_transition, 2, 0))
     assert(not rail_env_trans.get_transition(
-            south_east_transition, 2, 1))
+           south_east_transition, 2, 1))
 
-    east_west_transition = rail_env_trans.rotate_transition(
-            north_south_transition, 90)
-    north_west_transition = rail_env_trans.rotate_transition(
-            south_east_transition, 180)
+    east_west_transition = rail_env_trans.rotate_transition(north_south_transition, 90)
+    north_west_transition = rail_env_trans.rotate_transition(south_east_transition, 180)
 
     # Facing west, going west
     assert(rail_env_trans.get_transition(
-            east_west_transition, 3, 3))
+           east_west_transition, 3, 3))
     # Facing south, going west
     assert(rail_env_trans.get_transition(
-            north_west_transition, 2, 3))
+           north_west_transition, 2, 3))
 
     assert(south_east_transition == rail_env_trans.rotate_transition(
-            south_east_transition, 360))
+           south_east_transition, 360))
 
 
 def test_diagonal_transitions():
     diagonal_trans_env = Grid8Transitions([])
 
     # Facing north, going north-east
-    south_northeast_transition = int('01000000' + '0'*8*7, 2)
+    south_northeast_transition = int('01000000' + '0' * 8 * 7, 2)
     assert(diagonal_trans_env.get_transitions(
-            south_northeast_transition, 0) == (0, 1, 0, 0, 0, 0, 0, 0))
+           south_northeast_transition, 0) == (0, 1, 0, 0, 0, 0, 0, 0))
 
     # Allowing transition from north to southwest: Facing south, going SW
     north_southwest_transition = \
-        diagonal_trans_env.set_transitions(
-         int('0' * 64, 2), 4, (0, 0, 0, 0, 0, 1, 0, 0))
+        diagonal_trans_env.set_transitions(int('0' * 64, 2), 4, (0, 0, 0, 0, 0, 1, 0, 0))
 
     assert(diagonal_trans_env.rotate_transition(
-            south_northeast_transition, 180) == north_southwest_transition)
+           south_northeast_transition, 180) == north_southwest_transition)