From 12138cdbd9524aa7571ed4ff2a01b472530a75ed Mon Sep 17 00:00:00 2001
From: spiglerg <spiglerg@gmail.com>
Date: Thu, 16 May 2019 14:27:11 +0200
Subject: [PATCH] refactored env observations ISSUE #24 + fixed most pylint
 errors

---
 examples/play_model.py                   |   1 -
 examples/temporary_example.py            |   2 +-
 examples/training_navigation.py          |   2 +-
 flatland/core/env_observation_builder.py | 512 ----------------------
 flatland/envs/agent_utils.py             |   5 +-
 flatland/envs/env_utils.py               |   2 +-
 flatland/envs/generators.py              |   2 +-
 flatland/envs/observations.py            | 515 +++++++++++++++++++++++
 flatland/envs/rail_env.py                |  26 +-
 flatland/utils/editor.py                 |  28 +-
 flatland/utils/graphics_layer.py         |   4 +-
 flatland/utils/graphics_pil.py           |   2 +-
 flatland/utils/render_qt.py              |  12 +-
 flatland/utils/rendertools.py            |  30 +-
 flatland/utils/svg.py                    |  23 +-
 tests/test_environments.py               |   2 +-
 tests/test_rendertools.py                |   7 +-
 17 files changed, 582 insertions(+), 593 deletions(-)
 create mode 100644 flatland/envs/observations.py

diff --git a/examples/play_model.py b/examples/play_model.py
index f80d9a1..62726c2 100644
--- a/examples/play_model.py
+++ b/examples/play_model.py
@@ -1,6 +1,5 @@
 from flatland.envs.rail_env import RailEnv
 from flatland.envs.generators import complex_rail_generator
-# from flatland.core.env_observation_builder import TreeObsForRailEnv
 from flatland.utils.rendertools import RenderTool
 from flatland.baselines.dueling_double_dqn import Agent
 from collections import deque
diff --git a/examples/temporary_example.py b/examples/temporary_example.py
index 0ed2f62..1f3504f 100644
--- a/examples/temporary_example.py
+++ b/examples/temporary_example.py
@@ -4,7 +4,7 @@ import matplotlib.pyplot as plt
 
 from flatland.envs.rail_env import *
 from flatland.envs.generators import *
-from flatland.core.env_observation_builder import TreeObsForRailEnv
+from flatland.envs.observations import TreeObsForRailEnv
 from flatland.utils.rendertools import *
 
 random.seed(0)
diff --git a/examples/training_navigation.py b/examples/training_navigation.py
index 9d45cd1..20d093e 100644
--- a/examples/training_navigation.py
+++ b/examples/training_navigation.py
@@ -1,6 +1,6 @@
 from flatland.envs.rail_env import *
 from flatland.envs.generators import *
-from flatland.core.env_observation_builder import TreeObsForRailEnv
+from flatland.envs.observations import TreeObsForRailEnv
 from flatland.utils.rendertools import *
 from flatland.baselines.dueling_double_dqn import Agent
 from collections import deque
diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py
index 8e7f2ae..09a624e 100644
--- a/flatland/core/env_observation_builder.py
+++ b/flatland/core/env_observation_builder.py
@@ -8,10 +8,6 @@ The ObservationBuilder-derived custom classes implement 2 functions, reset() and
 case of multi-agent environments.
 """
 
-import numpy as np
-
-from collections import deque
-
 
 class ObservationBuilder:
     """
@@ -46,511 +42,3 @@ class ObservationBuilder:
             An observation structure, specific to the corresponding environment.
         """
         raise NotImplementedError()
-
-
-class TreeObsForRailEnv(ObservationBuilder):
-    """
-    TreeObsForRailEnv object.
-
-    This object returns observation vectors for agents in the RailEnv environment.
-    The information is local to each agent and exploits the tree structure of the rail
-    network to simplify the representation of the state of the environment for each agent.
-    """
-
-    def __init__(self, max_depth):
-        self.max_depth = max_depth
-
-    def reset(self):
-        agents = self.env.agents
-        nAgents = len(agents)
-        self.distance_map = np.inf * np.ones(shape=(nAgents,  # self.env.number_of_agents,
-                                                    self.env.height,
-                                                    self.env.width,
-                                                    4))
-        self.max_dist = np.zeros(nAgents)
-
-        # for i in range(nAgents):
-        #     self.max_dist[i] = self._distance_map_walker(self.env.agents_target[i], i)
-        self.max_dist = [self._distance_map_walker(agent.target, i) for i, agent in enumerate(agents)]
-
-        # Update local lookup table for all agents' target locations
-        self.location_has_target = {}
-        # for loc in self.env.agents_target:
-        #    self.location_has_target[(loc[0], loc[1])] = 1
-        self.location_has_target = {tuple(agent.target): 1 for agent in agents}
-
-    def _distance_map_walker(self, position, target_nr):
-        """
-        Utility function to compute distance maps from each cell in the rail network (and each possible
-        orientation within it) to each agent's target cell.
-        """
-        # Returns max distance to target, from the farthest away node, while filling in distance_map
-
-        self.distance_map[target_nr, position[0], position[1], :] = 0
-
-        # Fill in the (up to) 4 neighboring nodes
-        # nodes_queue = []  # list of tuples (row, col, direction, distance);
-        # direction is the direction of movement, meaning that at least a possible orientation of an agent
-        # in cell (row,col) allows a movement in direction `direction'
-        nodes_queue = deque(self._get_and_update_neighbors(position, target_nr, 0, enforce_target_direction=-1))
-
-        # BFS from target `position' to all the reachable nodes in the grid
-        # Stop the search if the target position is re-visited, in any direction
-        visited = set([(position[0], position[1], 0),
-                       (position[0], position[1], 1),
-                       (position[0], position[1], 2),
-                       (position[0], position[1], 3)])
-
-        max_distance = 0
-
-        while nodes_queue:
-            node = nodes_queue.popleft()
-
-            node_id = (node[0], node[1], node[2])
-
-            if node_id not in visited:
-                visited.add(node_id)
-
-                # From the list of possible neighbors that have at least a path to the current node, only keep those
-                # whose new orientation in the current cell would allow a transition to direction node[2]
-                valid_neighbors = self._get_and_update_neighbors((node[0], node[1]), target_nr, node[3], node[2])
-
-                for n in valid_neighbors:
-                    nodes_queue.append(n)
-
-                if len(valid_neighbors) > 0:
-                    max_distance = max(max_distance, node[3] + 1)
-
-        return max_distance
-
-    def _get_and_update_neighbors(self, position, target_nr, current_distance, enforce_target_direction=-1):
-        """
-        Utility function used by _distance_map_walker to perform a BFS walk over the rail, filling in the
-        minimum distances from each target cell.
-        """
-        neighbors = []
-
-        possible_directions = [0, 1, 2, 3]
-        if enforce_target_direction >= 0:
-            # The agent must land into the current cell with orientation `enforce_target_direction'.
-            # This is only possible if the agent has arrived from the cell in the opposite direction!
-            possible_directions = [(enforce_target_direction + 2) % 4]
-
-        for neigh_direction in possible_directions:
-            new_cell = self._new_position(position, neigh_direction)
-
-            if new_cell[0] >= 0 and new_cell[0] < self.env.height and new_cell[1] >= 0 and new_cell[1] < self.env.width:
-
-                desired_movement_from_new_cell = (neigh_direction + 2) % 4
-
-                """
-                # Is the next cell a dead-end?
-                isNextCellDeadEnd = False
-                nbits = 0
-                tmp = self.env.rail.get_transitions((new_cell[0], new_cell[1]))
-                while tmp > 0:
-                    nbits += (tmp & 1)
-                    tmp = tmp >> 1
-                if nbits == 1:
-                    # Dead-end!
-                    isNextCellDeadEnd = True
-                """
-
-                # Check all possible transitions in new_cell
-                for agent_orientation in range(4):
-                    # Is a transition along movement `desired_movement_from_new_cell' to the current cell possible?
-                    isValid = self.env.rail.get_transition((new_cell[0], new_cell[1], agent_orientation),
-                                                           desired_movement_from_new_cell)
-
-                    if isValid:
-                        """
-                        # TODO: check that it works with deadends! -- still bugged!
-                        movement = desired_movement_from_new_cell
-                        if isNextCellDeadEnd:
-                            movement = (desired_movement_from_new_cell+2) % 4
-                        """
-                        new_distance = min(self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation],
-                                           current_distance + 1)
-                        neighbors.append((new_cell[0], new_cell[1], agent_orientation, new_distance))
-                        self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation] = new_distance
-
-        return neighbors
-
-    def _new_position(self, position, movement):
-        """
-        Utility function that converts a compass movement over a 2D grid to new positions (r, c).
-        """
-        if movement == 0:  # NORTH
-            return (position[0] - 1, position[1])
-        elif movement == 1:  # EAST
-            return (position[0], position[1] + 1)
-        elif movement == 2:  # SOUTH
-            return (position[0] + 1, position[1])
-        elif movement == 3:  # WEST
-            return (position[0], position[1] - 1)
-
-    def get(self, handle):
-        """
-        Computes the current observation for agent `handle' in env
-
-        The observation vector is composed of 4 sequential parts, corresponding to data from the up to 4 possible
-        movements in a RailEnv (up to because only a subset of possible transitions are allowed in RailEnv).
-        The possible movements are sorted relative to the current orientation of the agent, rather than NESW as for
-        the transitions. The order is:
-            [data from 'left'] + [data from 'forward'] + [data from 'right'] + [data from 'back']
-
-
-
-
-
-        Each branch data is organized as:
-            [root node information] +
-            [recursive branch data from 'left'] +
-            [... from 'forward'] +
-            [... from 'right] +
-            [... from 'back']
-
-        Finally, each node information is composed of 5 floating point values:
-
-        #1:
-
-        #2: 1 if a target of another agent is detected between the previous node and the current one.
-
-        #3: 1 if another agent is detected between the previous node and the current one.
-
-        #4: distance of agent to the current branch node
-
-        #5: minimum distance from node to the agent's target (when landing to the node following the corresponding
-            branch.
-
-        Missing/padding nodes are filled in with -inf (truncated).
-        Missing values in present node are filled in with +inf (truncated).
-
-
-        In case of the root node, the values are [0, 0, 0, 0, distance from agent to target].
-        In case the target node is reached, the values are [0, 0, 0, 0, 0].
-        """
-
-        # Update local lookup table for all agents' positions
-        # self.location_has_agent = {}
-        # for loc in self.env.agents_position:
-        #    self.location_has_agent[(loc[0], loc[1])] = 1
-        self.location_has_agent = {tuple(agent.position): 1 for agent in self.env.agents}
-
-        agent = self.env.agents[handle]  # TODO: handle being treated as index
-        # position = self.env.agents_position[handle]
-        # orientation = self.env.agents_direction[handle]
-        possible_transitions = self.env.rail.get_transitions((*agent.position, agent.direction))
-        num_transitions = np.count_nonzero(possible_transitions)
-        # Root node - current position
-        # observation = [0, 0, 0, 0, self.distance_map[handle, position[0], position[1], orientation]]
-        observation = [0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)]]
-        root_observation = observation[:]
-
-        # Start from the current orientation, and see which transitions are available;
-        # organize them as [left, forward, right, back], relative to the current orientation
-        # If only one transition is possible, the tree is oriented with this transition as the forward branch.
-        # TODO: Test if this works as desired!
-        orientation = agent.direction
-        if num_transitions == 1:
-            orientation == np.argmax(possible_transitions)
-
-        # for branch_direction in [(orientation + 4 + i) % 4 for i in range(-1, 3)]:
-        for branch_direction in [(orientation + i) % 4 for i in range(-1, 3)]:
-            if possible_transitions[branch_direction]:
-                new_cell = self._new_position(agent.position, branch_direction)
-
-                branch_observation = self._explore_branch(handle, new_cell, branch_direction, root_observation, 1)
-                observation = observation + branch_observation
-            else:
-                num_cells_to_fill_in = 0
-                pow4 = 1
-                for i in range(self.max_depth):
-                    num_cells_to_fill_in += pow4
-                    pow4 *= 4
-                observation = observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf] * num_cells_to_fill_in
-        return observation
-
-    def _explore_branch(self, handle, position, direction, root_observation, depth):
-        """
-        Utility function to compute tree-based observations.
-        """
-        # [Recursive branch opened]
-        if depth >= self.max_depth + 1:
-            return []
-
-        # Continue along direction until next switch or
-        # until no transitions are possible along the current direction (i.e., dead-ends)
-        # We treat dead-ends as nodes, instead of going back, to avoid loops
-        exploring = True
-        last_isSwitch = False
-        last_isDeadEnd = False
-        last_isTerminal = False  # wrong cell OR cycle;  either way, we don't want the agent to land here
-        last_isTarget = False
-
-        visited = set()
-
-        # other_agent_encountered = False
-        # other_target_encountered = False
-        other_agent_encountered = np.inf
-        other_target_encountered = np.inf
-
-        num_steps = 1
-        while exploring:
-            # #############################
-            # #############################
-            # Modify here to compute any useful data required to build the end node's features. This code is called
-            # for each cell visited between the previous branching node and the next switch / target / dead-end.
-            if position in self.location_has_agent:
-                # other_agent_encountered = True
-                if num_steps < other_agent_encountered:
-                    other_agent_encountered = num_steps
-
-            if position in self.location_has_target:
-                # other_target_encountered = True
-                if num_steps < other_target_encountered:
-                    other_target_encountered = num_steps
-            # #############################
-            # #############################
-
-            if (position[0], position[1], direction) in visited:
-                last_isTerminal = True
-                break
-            visited.add((position[0], position[1], direction))
-
-            # If the target node is encountered, pick that as node. Also, no further branching is possible.
-            # if position[0] == self.env.agents_target[handle][0] and position[1] == self.env.agents_target[handle][1]:
-            if np.array_equal(position, self.env.agents[handle].target):
-                last_isTarget = True
-                break
-
-            cell_transitions = self.env.rail.get_transitions((*position, direction))
-            num_transitions = np.count_nonzero(cell_transitions)
-            exploring = False
-            if num_transitions == 1:
-                # Check if dead-end, or if we can go forward along direction
-                nbits = 0
-                tmp = self.env.rail.get_transitions(tuple(position))
-                while tmp > 0:
-                    nbits += (tmp & 1)
-                    tmp = tmp >> 1
-                if nbits == 1:
-                    # Dead-end!
-                    last_isDeadEnd = True
-
-                if not last_isDeadEnd:
-                    # Keep walking through the tree along `direction'
-                    exploring = True
-                    direction = np.argmax(cell_transitions)
-                    position = self._new_position(position, direction)
-                    num_steps += 1
-
-            elif num_transitions > 0:
-                # Switch detected
-                last_isSwitch = True
-                break
-
-            elif num_transitions == 0:
-                # Wrong cell type, but let's cover it and treat it as a dead-end, just in case
-                print("WRONG CELL TYPE detected in tree-search (0 transitions possible) at cell", position[0],
-                      position[1], direction)
-                last_isTerminal = True
-                break
-
-        # `position' is either a terminal node or a switch
-
-        observation = []
-
-        # #############################
-        # #############################
-        # Modify here to append new / different features for each visited cell!
-        """
-        if last_isTarget:
-            observation = [0,
-                           1 if other_target_encountered else 0,
-                           1 if other_agent_encountered else 0,
-                           root_observation[3] + num_steps,
-                           0]
-
-        elif last_isTerminal:
-            observation = [0,
-                           1 if other_target_encountered else 0,
-                           1 if other_agent_encountered else 0,
-                           np.inf,
-                           np.inf]
-        else:
-            observation = [0,
-                           1 if other_target_encountered else 0,
-                           1 if other_agent_encountered else 0,
-                           root_observation[3] + num_steps,
-                           self.distance_map[handle, position[0], position[1], direction]]
-        """
-        if last_isTarget:
-            observation = [0,
-                           other_target_encountered,
-                           other_agent_encountered,
-                           root_observation[3] + num_steps,
-                           0]
-
-        elif last_isTerminal:
-            observation = [0,
-                           other_target_encountered,
-                           other_agent_encountered,
-                           np.inf,
-                           np.inf]
-        else:
-            observation = [0,
-                           other_target_encountered,
-                           other_agent_encountered,
-                           root_observation[3] + num_steps,
-                           self.distance_map[handle, position[0], position[1], direction]]
-        # #############################
-        # #############################
-
-        new_root_observation = observation[:]
-        # Start from the current orientation, and see which transitions are available;
-        # organize them as [left, forward, right, back], relative to the current orientation
-        # Get the possible transitions
-        possible_transitions = self.env.rail.get_transitions((*position, direction))
-        for branch_direction in [(direction + 4 + i) % 4 for i in range(-1, 3)]:
-            if last_isDeadEnd and self.env.rail.get_transition((*position, direction),
-                                                               (branch_direction + 2) % 4):
-                # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes
-                # it back
-                new_cell = self._new_position(position, (branch_direction + 2) % 4)
-                branch_observation = self._explore_branch(handle,
-                                                          new_cell,
-                                                          (branch_direction + 2) % 4,
-                                                          new_root_observation,
-                                                          depth + 1)
-                observation = observation + branch_observation
-
-            elif last_isSwitch and possible_transitions[branch_direction]:
-                new_cell = self._new_position(position, branch_direction)
-                branch_observation = self._explore_branch(handle,
-                                                          new_cell,
-                                                          branch_direction,
-                                                          new_root_observation,
-                                                          depth + 1)
-                observation = observation + branch_observation
-
-            else:
-                num_cells_to_fill_in = 0
-                pow4 = 1
-                for i in range(self.max_depth - depth):
-                    num_cells_to_fill_in += pow4
-                    pow4 *= 4
-                observation = observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf] * num_cells_to_fill_in
-
-        return observation
-
-    def util_print_obs_subtree(self, tree, num_features_per_node=5, prompt='', current_depth=0):
-        """
-        Utility function to pretty-print tree observations returned by this object.
-        """
-        if len(tree) < num_features_per_node:
-            return
-
-        depth = 0
-        tmp = len(tree) / num_features_per_node - 1
-        pow4 = 4
-        while tmp > 0:
-            tmp -= pow4
-            depth += 1
-            pow4 *= 4
-
-        prompt_ = ['L:', 'F:', 'R:', 'B:']
-
-        print("  " * current_depth + prompt, tree[0:num_features_per_node])
-        child_size = (len(tree) - num_features_per_node) // 4
-        for children in range(4):
-            child_tree = tree[(num_features_per_node + children * child_size):
-                              (num_features_per_node + (children + 1) * child_size)]
-            self.util_print_obs_subtree(child_tree,
-                                        num_features_per_node,
-                                        prompt=prompt_[children],
-                                        current_depth=current_depth + 1)
-
-    def split_tree(self, tree, num_features_per_node=5, current_depth=0):
-        """
-
-        :param tree:
-        :param num_features_per_node:
-        :param prompt:
-        :param current_depth:
-        :return:
-        """
-
-        if len(tree) < num_features_per_node:
-            return [], []
-
-        depth = 0
-        tmp = len(tree) / num_features_per_node - 1
-        pow4 = 4
-        while tmp > 0:
-            tmp -= pow4
-            depth += 1
-            pow4 *= 4
-        child_size = (len(tree) - num_features_per_node) // 4
-        tree_data = tree[0:num_features_per_node - 1].tolist()
-        distance_data = [tree[num_features_per_node - 1]]
-        for children in range(4):
-            child_tree = tree[(num_features_per_node + children * child_size):
-                              (num_features_per_node + (children + 1) * child_size)]
-            tmp_tree_data, tmp_distance_data = self.split_tree(child_tree,
-                                                               num_features_per_node,
-                                                               current_depth=current_depth + 1)
-            if len(tmp_tree_data) > 0:
-                tree_data.extend(tmp_tree_data)
-                distance_data.extend(tmp_distance_data)
-        return tree_data, distance_data
-
-
-class GlobalObsForRailEnv(ObservationBuilder):
-    """
-    Gives a global observation of the entire rail environment.
-    The observation is composed of the following elements:
-
-        - transition map array with dimensions (env.height, env.width, 16),
-          assuming 16 bits encoding of transitions.
-
-        - Four 2D arrays containing respectively the position of the given agent,
-          the position of its target, the positions of the other agents and of
-          their target.
-
-        - A 4 elements array with one of encoding of the direction.
-    """
-
-    def __init__(self):
-        super(GlobalObsForRailEnv, self).__init__()
-
-    def reset(self):
-        self.rail_obs = np.zeros((self.env.height, self.env.width, 16))
-        for i in range(self.rail_obs.shape[0]):
-            for j in range(self.rail_obs.shape[1]):
-                self.rail_obs[i, j] = np.array(
-                    list(f'{self.env.rail.get_transitions((i, j)):016b}')).astype(int)
-
-        # self.targets = np.zeros(self.env.height, self.env.width)
-        # for target_pos in self.env.agents_target:
-        #     self.targets[target_pos] += 1
-
-    def get(self, handle):
-        obs = np.zeros((4, self.env.height, self.env.width))
-        agents = self.env.agents
-        agent = agents[handle]
-
-        agent_pos = agents[handle].position
-        obs[0][agent_pos] += 1
-        obs[1][agent.target] += 1
-
-        for i in range(len(agents)):
-            if i != handle:  # TODO: handle used as index...?
-                agent2 = agents[i]
-                obs[3][agent2.position] += 1
-                obs[2][agent2.target] += 1
-
-        direction = np.zeros(4)
-        direction[agent.direction] = 1
-
-        return self.rail_obs, obs, direction
diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py
index 0d42ab7..1f1bc1d 100644
--- a/flatland/envs/agent_utils.py
+++ b/flatland/envs/agent_utils.py
@@ -53,7 +53,7 @@ class EnvAgent(EnvAgentStatic):
     def __init__(self, position, direction, target, handle, old_direction):
         super(EnvAgent, self).__init__(position, direction, target)
         self.handle = handle
-        self.old_direction = old_direction 
+        self.old_direction = old_direction
 
     def to_list(self):
         return [self.position, self.direction, self.target, self.handle, self.old_direction]
@@ -72,7 +72,6 @@ class EnvAgent(EnvAgentStatic):
         """
         if handles is None:
             handles = range(len(lEnvAgentStatic))
-            
+
         return [EnvAgent(**oEAS.__dict__, handle=handle)
                 for handle, oEAS in zip(handles, lEnvAgentStatic)]
-
diff --git a/flatland/envs/env_utils.py b/flatland/envs/env_utils.py
index a1e46db..b58604c 100644
--- a/flatland/envs/env_utils.py
+++ b/flatland/envs/env_utils.py
@@ -8,7 +8,7 @@ a GridTransitionMap object.
 import numpy as np
 
 # from flatland.core.env import Environment
-# from flatland.core.env_observation_builder import TreeObsForRailEnv
+# from flatland.envs.observations import TreeObsForRailEnv
 
 # from flatland.core.transitions import Grid8Transitions, RailEnvTransitions
 # from flatland.core.transition_map import GridTransitionMap
diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py
index e9e2d3d..04e9a8f 100644
--- a/flatland/envs/generators.py
+++ b/flatland/envs/generators.py
@@ -1,7 +1,7 @@
 import numpy as np
 
 # from flatland.core.env import Environment
-# from flatland.core.env_observation_builder import TreeObsForRailEnv
+# from flatland.envs.observations import TreeObsForRailEnv
 
 from flatland.core.transitions import Grid8Transitions, RailEnvTransitions
 from flatland.core.transition_map import GridTransitionMap
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
new file mode 100644
index 0000000..8e4be0b
--- /dev/null
+++ b/flatland/envs/observations.py
@@ -0,0 +1,515 @@
+"""
+Collection of environment-specific ObservationBuilder implementations.
+"""
+import numpy as np
+from collections import deque
+
+from flatland.core.env_observation_builder import ObservationBuilder
+
+
+class TreeObsForRailEnv(ObservationBuilder):
+    """
+    TreeObsForRailEnv object.
+
+    This object returns observation vectors for agents in the RailEnv environment.
+    The information is local to each agent and exploits the tree structure of the rail
+    network to simplify the representation of the state of the environment for each agent.
+    """
+
+    def __init__(self, max_depth):  # max_depth: number of branching levels expanded below the root node
+        self.max_depth = max_depth
+
+    def reset(self):
+        """
+        Rebuild the cached data that depends on the rail layout and on the agents' targets.
+
+        Fills `self.distance_map` (one BFS per agent, started from its target), stores the
+        value returned by each walk in `self.max_dist`, and refreshes the lookup table of
+        cells that are the target of some agent.
+        """
+        agents = self.env.agents
+
+        # distance_map[agent, row, col, orientation]; np.inf marks states no walk has reached yet.
+        self.distance_map = np.full((len(agents), self.env.height, self.env.width, 4), np.inf)
+
+        # The walker writes into self.distance_map, so the map must be allocated before this runs.
+        self.max_dist = [self._distance_map_walker(agent.target, i) for i, agent in enumerate(agents)]
+
+        # Lookup table of every cell that is the target of at least one agent.
+        self.location_has_target = {tuple(agent.target): 1 for agent in agents}
+
+    def _distance_map_walker(self, position, target_nr):
+        """
+        Breadth-first walk from the target cell `position` that fills `self.distance_map[target_nr]`.
+
+        The map holds, for each cell of the rail network and each possible orientation within it,
+        the distance to the target cell of agent `target_nr`.
+
+        :param position: (row, col) of the target cell the walk starts from.
+        :param target_nr: first index into `self.distance_map`, i.e. which agent's map to fill.
+        :return: the largest `distance + 1` over the expanded states that had a valid neighbor (0 if none).
+        """
+        row, col = position[0], position[1]
+
+        # The target itself is at distance 0 whatever the orientation.
+        self.distance_map[target_nr, row, col, :] = 0
+
+        # Seed the frontier with the (up to) 4 neighbors of the target. Entries are tuples
+        # (row, col, orientation, distance), as returned by _get_and_update_neighbors.
+        frontier = deque(self._get_and_update_neighbors(position, target_nr, 0, enforce_target_direction=-1))
+
+        # Mark the target as visited in all 4 orientations, so the search stops if it is re-visited.
+        seen = {(row, col, orientation) for orientation in range(4)}
+
+        farthest = 0
+
+        while frontier:
+            cell_row, cell_col, orientation, distance = frontier.popleft()
+
+            state = (cell_row, cell_col, orientation)
+            if state in seen:
+                continue
+            seen.add(state)
+
+            # From the list of possible neighbors that have at least a path to the current node, only keep
+            # those whose new orientation in the current cell would allow a transition to `orientation`.
+            reachable = self._get_and_update_neighbors((cell_row, cell_col), target_nr, distance, orientation)
+            frontier.extend(reachable)
+
+            # Only count this step if the walk could be extended beyond the current state.
+            if len(reachable) > 0:
+                farthest = max(farthest, distance + 1)
+
+        return farthest
+
+    def _get_and_update_neighbors(self, position, target_nr, current_distance, enforce_target_direction=-1):
+        """
+        Collect the states from which an agent can step into `position`, updating their distances.
+
+        Helper of _distance_map_walker. For every in-grid cell adjacent to `position`, and for every
+        agent orientation in that cell that permits the move towards `position`, the entry of
+        `self.distance_map[target_nr]` is lowered to `current_distance + 1` when that is smaller
+        than the value already stored.
+
+        :param position: (row, col) of the cell being expanded.
+        :param target_nr: first index into `self.distance_map`.
+        :param current_distance: distance already assigned to `position`.
+        :param enforce_target_direction: when >= 0, the orientation the agent must have on arrival in
+            `position`; only the neighbor lying in the opposite direction is then considered.
+            Any negative value (the default) considers all 4 neighbors.
+        :return: list of tuples (row, col, orientation, distance), one per valid state found.
+
+        TODO: dead-end cells are not given any special treatment here; an earlier attempt at
+        handling them was reported as still bugged, so check that this works with dead-ends.
+        """
+        if enforce_target_direction >= 0:
+            # The agent must land into the current cell with orientation `enforce_target_direction'.
+            # This is only possible if the agent has arrived from the cell in the opposite direction!
+            candidate_directions = [(enforce_target_direction + 2) % 4]
+        else:
+            candidate_directions = [0, 1, 2, 3]
+
+        height, width = self.env.height, self.env.width
+        found = []
+
+        for neigh_direction in candidate_directions:
+            n_row, n_col = self._new_position(position, neigh_direction)
+            if not (0 <= n_row < height and 0 <= n_col < width):
+                # Neighbor lies outside the grid.
+                continue
+
+            # Movement that takes an agent from the neighboring cell into `position`.
+            move_back = (neigh_direction + 2) % 4
+
+            # Check all possible agent orientations in the neighboring cell.
+            for agent_orientation in range(4):
+                # Is a transition along movement `move_back' to the current cell possible?
+                if not self.env.rail.get_transition((n_row, n_col, agent_orientation), move_back):
+                    continue
+
+                # Keep the smaller of the distance already stored and the one going through `position`.
+                new_distance = min(self.distance_map[target_nr, n_row, n_col, agent_orientation],
+                                   current_distance + 1)
+                found.append((n_row, n_col, agent_orientation, new_distance))
+                self.distance_map[target_nr, n_row, n_col, agent_orientation] = new_distance
+
+        return found
+
+    def _new_position(self, position, movement):
+        """
+        Return the (row, col) reached by moving one cell from `position` along compass direction `movement`.
+
+        0 = NORTH (row - 1), 1 = EAST (col + 1), 2 = SOUTH (row + 1), 3 = WEST (col - 1).
+        Any other value of `movement` yields None.
+        """
+        offsets = {0: (-1, 0), 1: (0, 1), 2: (1, 0), 3: (0, -1)}
+        if movement not in offsets:
+            return None
+        d_row, d_col = offsets[movement]
+        return (position[0] + d_row, position[1] + d_col)
+
+    def get(self, handle):
+        """
+        Computes the current observation for agent `handle' in env
+
+        The observation vector is composed of 4 sequential parts, corresponding to data from the up to 4 possible
+        movements in a RailEnv (up to because only a subset of possible transitions are allowed in RailEnv).
+        The possible movements are sorted relative to the current orientation of the agent, rather than NESW as for
+        the transitions. The order is:
+            [data from 'left'] + [data from 'forward'] + [data from 'right'] + [data from 'back']
+
+        If only one transition is possible from the agent's current state, the tree is oriented with that
+        transition as the 'forward' branch.
+
+        Each branch data is organized as:
+            [root node information] +
+            [recursive branch data from 'left'] +
+            [... from 'forward'] +
+            [... from 'right'] +
+            [... from 'back']
+
+        Finally, each node information is composed of 5 floating point values:
+
+        #1: currently unused, always 0.
+
+        #2: number of steps from the previous node to the first cell holding the target of any agent,
+            +inf if no such cell lies between the previous node and the current one.
+
+        #3: number of steps from the previous node to the first cell occupied by any agent,
+            +inf if no such cell lies between the previous node and the current one.
+
+        #4: distance of agent to the current branch node
+
+        #5: minimum distance from node to the agent's target (when landing to the node following the corresponding
+            branch).
+
+        Missing/padding nodes are filled in with -inf (truncated).
+        Missing values in present node are filled in with +inf (truncated).
+
+        In case of the root node, the values are [0, 0, 0, 0, distance from agent to target].
+        In case the target node is reached, value #5 is 0.
+
+        :param handle: index of the agent in `self.env.agents`.
+        :return: flat list with 5 values per node, for the root followed by its 4 branches.
+        """
+
+        # Update local lookup table for all agents' positions
+        self.location_has_agent = {tuple(agent.position): 1 for agent in self.env.agents}
+
+        agent = self.env.agents[handle]  # TODO: handle being treated as index
+        possible_transitions = self.env.rail.get_transitions((*agent.position, agent.direction))
+        num_transitions = np.count_nonzero(possible_transitions)
+
+        # Root node - current position
+        observation = [0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)]]
+        root_observation = observation[:]
+
+        # Start from the current orientation, and see which transitions are available;
+        # organize them as [left, forward, right, back], relative to the current orientation
+        # If only one transition is possible, the tree is oriented with this transition as the forward branch.
+        # TODO: Test if this works as desired!
+        orientation = agent.direction
+        if num_transitions == 1:
+            # This has to be an assignment: a bare comparison would be evaluated and discarded,
+            # leaving the tree oriented along the agent's direction.
+            orientation = np.argmax(possible_transitions)
+
+        # A missing branch is padded with one -inf node for every node of a full subtree, i.e.
+        # 1 + 4 + ... + 4 ** (max_depth - 1) nodes of 5 values each.
+        num_cells_to_fill_in = sum(4 ** level for level in range(self.max_depth))
+        padding = [-np.inf] * (5 * num_cells_to_fill_in)
+
+        for branch_direction in [(orientation + i) % 4 for i in range(-1, 3)]:
+            if possible_transitions[branch_direction]:
+                new_cell = self._new_position(agent.position, branch_direction)
+
+                branch_observation = self._explore_branch(handle, new_cell, branch_direction, root_observation, 1)
+                observation = observation + branch_observation
+            else:
+                observation = observation + padding
+
+        return observation
+
+    def _explore_branch(self, handle, position, direction, root_observation, depth):
+        """
+        Walk from `position` along `direction` to the next node and return the flat observation of its subtree.
+        """
+        # Recursion ends here: levels deeper than max_depth contribute no values (empty list).
+        if depth >= self.max_depth + 1:
+            return []
+
+        # Continue along direction until next switch or
+        # until no transitions are possible along the current direction (i.e., dead-ends)
+        # We treat dead-ends as nodes, instead of going back, to avoid loops
+        exploring = True
+        last_isSwitch = False
+        last_isDeadEnd = False
+        last_isTerminal = False  # wrong cell OR cycle;  either way, we don't want the agent to land here
+        last_isTarget = False
+
+        visited = set()
+
+        # Step count at which the first agent / the first target is met on this stretch of rail;
+        # np.inf means that none has been encountered.
+        other_agent_encountered = np.inf
+        other_target_encountered = np.inf
+
+        num_steps = 1
+        while exploring:
+            # #############################
+            # #############################
+            # Modify here to compute any useful data required to build the end node's features. This code is called
+            # for each cell visited between the previous branching node and the next switch / target / dead-end.
+            if position in self.location_has_agent:
+                # Remember only the earliest step at which an agent was seen.
+                if num_steps < other_agent_encountered:
+                    other_agent_encountered = num_steps
+
+            if position in self.location_has_target:
+                # Remember only the earliest step at which a target was seen.
+                if num_steps < other_target_encountered:
+                    other_target_encountered = num_steps
+            # #############################
+            # #############################
+
+            if (position[0], position[1], direction) in visited:
+                last_isTerminal = True
+                break
+            visited.add((position[0], position[1], direction))
+
+            # If the target node is encountered, pick that as node. Also, no further branching is possible.
+            # The comparison is against the (row, col) target of agent `handle' only.
+            if np.array_equal(position, self.env.agents[handle].target):
+                last_isTarget = True
+                break
+
+            cell_transitions = self.env.rail.get_transitions((*position, direction))
+            num_transitions = np.count_nonzero(cell_transitions)
+            exploring = False
+            if num_transitions == 1:
+                # Check if dead-end, or if we can go forward along direction
+                nbits = 0
+                tmp = self.env.rail.get_transitions(tuple(position))
+                while tmp > 0:
+                    nbits += (tmp & 1)
+                    tmp = tmp >> 1
+                if nbits == 1:
+                    # Dead-end!
+                    last_isDeadEnd = True
+
+                if not last_isDeadEnd:
+                    # Keep walking through the tree along `direction'
+                    exploring = True
+                    direction = np.argmax(cell_transitions)
+                    position = self._new_position(position, direction)
+                    num_steps += 1
+
+            elif num_transitions > 0:
+                # Switch detected
+                last_isSwitch = True
+                break
+
+            elif num_transitions == 0:
+                # Wrong cell type, but let's cover it and treat it as a dead-end, just in case
+                print("WRONG CELL TYPE detected in tree-search (0 transitions possible) at cell", position[0],
+                      position[1], direction)
+                last_isTerminal = True
+                break
+
+        # `position' is either a terminal node or a switch
+
+        observation = []
+
+        # NOTE: the triple-quoted block right below is a bare string statement holding an earlier,
+        # inactive variant of the features (0/1 flags); only the code after it has any effect.
+        # Modify here to append new / different features for each visited cell!
+        """
+        if last_isTarget:
+            observation = [0,
+                           1 if other_target_encountered else 0,
+                           1 if other_agent_encountered else 0,
+                           root_observation[3] + num_steps,
+                           0]
+
+        elif last_isTerminal:
+            observation = [0,
+                           1 if other_target_encountered else 0,
+                           1 if other_agent_encountered else 0,
+                           np.inf,
+                           np.inf]
+        else:
+            observation = [0,
+                           1 if other_target_encountered else 0,
+                           1 if other_agent_encountered else 0,
+                           root_observation[3] + num_steps,
+                           self.distance_map[handle, position[0], position[1], direction]]
+        """
+        if last_isTarget:
+            observation = [0,
+                           other_target_encountered,
+                           other_agent_encountered,
+                           root_observation[3] + num_steps,
+                           0]
+
+        elif last_isTerminal:
+            observation = [0,
+                           other_target_encountered,
+                           other_agent_encountered,
+                           np.inf,
+                           np.inf]
+        else:
+            observation = [0,
+                           other_target_encountered,
+                           other_agent_encountered,
+                           root_observation[3] + num_steps,
+                           self.distance_map[handle, position[0], position[1], direction]]
+        # #############################
+        # #############################
+
+        new_root_observation = observation[:]
+        # Start from the current orientation, and see which transitions are available;
+        # organize them as [left, forward, right, back], relative to the current orientation
+        # Get the possible transitions
+        possible_transitions = self.env.rail.get_transitions((*position, direction))
+        for branch_direction in [(direction + 4 + i) % 4 for i in range(-1, 3)]:
+            if last_isDeadEnd and self.env.rail.get_transition((*position, direction),
+                                                               (branch_direction + 2) % 4):
+                # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes
+                # it back
+                new_cell = self._new_position(position, (branch_direction + 2) % 4)
+                branch_observation = self._explore_branch(handle,
+                                                          new_cell,
+                                                          (branch_direction + 2) % 4,
+                                                          new_root_observation,
+                                                          depth + 1)
+                observation = observation + branch_observation
+
+            elif last_isSwitch and possible_transitions[branch_direction]:
+                new_cell = self._new_position(position, branch_direction)
+                branch_observation = self._explore_branch(handle,
+                                                          new_cell,
+                                                          branch_direction,
+                                                          new_root_observation,
+                                                          depth + 1)
+                observation = observation + branch_observation
+
+            else:
+                num_cells_to_fill_in = 0
+                pow4 = 1
+                for i in range(self.max_depth - depth):
+                    num_cells_to_fill_in += pow4
+                    pow4 *= 4
+                observation = observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf] * num_cells_to_fill_in
+
+        return observation
+
+    def util_print_obs_subtree(self, tree, num_features_per_node=5, prompt='', current_depth=0):
+        """
+        Utility function to pretty-print tree observations returned by this object.
+
+        Prints the first `num_features_per_node` values of `tree` (the root of the subtree), then
+        recurses into the 4 equally sized child slices, labelled 'L:', 'F:', 'R:' and 'B:'.
+
+        :param tree: flat, sliceable observation.
+        :param num_features_per_node: number of values stored for every node; must be positive.
+        :param prompt: label printed in front of the root of this subtree.
+        :param current_depth: nesting level, used for the indentation only.
+        """
+        if num_features_per_node <= 0:
+            raise ValueError("num_features_per_node must be positive")
+        if len(tree) < num_features_per_node:
+            return
+
+        print("  " * current_depth + prompt, tree[0:num_features_per_node])
+        # Whatever follows the root node is split evenly among the 4 children.
+        child_size = (len(tree) - num_features_per_node) // 4
+        for child_index, child_prompt in enumerate(['L:', 'F:', 'R:', 'B:']):
+            child_start = num_features_per_node + child_index * child_size
+            self.util_print_obs_subtree(tree[child_start:child_start + child_size],
+                                        num_features_per_node,
+                                        prompt=child_prompt,
+                                        current_depth=current_depth + 1)
+
+    def split_tree(self, tree, num_features_per_node=5, current_depth=0):
+        """
+        Split a flat tree observation into its per-node features and its per-node distances.
+
+        Nodes are visited depth-first (root, then its 4 children in order). A child whose
+        feature list comes back empty (e.g. a slice shorter than one node) is skipped.
+
+        :param tree: flat observation as a numpy array (the slices taken from it must offer `tolist()`).
+        :param num_features_per_node: number of values stored for every node; must be positive.
+        :param current_depth: depth of the current node; only forwarded to the recursive calls.
+        :return: tuple (tree_data, distance_data): the first `num_features_per_node - 1` values of
+            every node, concatenated, and the last value of every node.
+        """
+        if num_features_per_node <= 0:
+            raise ValueError("num_features_per_node must be positive")
+        if len(tree) < num_features_per_node:
+            return [], []
+
+        # Root node: everything but its last value is feature data, the last value is the distance.
+        tree_data = tree[0:num_features_per_node - 1].tolist()
+        distance_data = [tree[num_features_per_node - 1]]
+
+        # Whatever follows the root node is split evenly among the 4 children.
+        child_size = (len(tree) - num_features_per_node) // 4
+        for child_index in range(4):
+            child_start = num_features_per_node + child_index * child_size
+            child = tree[child_start:child_start + child_size]
+            child_tree_data, child_distance_data = self.split_tree(
+                child, num_features_per_node, current_depth=current_depth + 1)
+            if len(child_tree_data) > 0:
+                tree_data.extend(child_tree_data)
+                distance_data.extend(child_distance_data)
+        return tree_data, distance_data
+
+
+class GlobalObsForRailEnv(ObservationBuilder):
+    """
+    Gives a global observation of the entire rail environment.
+    The observation is composed of the following elements:
+
+        - transition map array with dimensions (env.height, env.width, 16),
+          assuming 16 bits encoding of transitions.
+
+        - Four 2D arrays (stacked as one array of shape (4, env.height, env.width)) holding, in
+          this order: the position of the given agent, the position of its target, the targets
+          of the other agents and the positions of the other agents.
+
+        - A 4 elements array with the one-hot encoding of the direction.
+    """
+
+    def __init__(self):
+        super(GlobalObsForRailEnv, self).__init__()
+
+    def reset(self):
+        """Cache the rail layout as one 16-element 0/1 vector per cell; get() returns it unchanged."""
+        self.rail_obs = np.zeros((self.env.height, self.env.width, 16))
+        for i in range(self.rail_obs.shape[0]):
+            for j in range(self.rail_obs.shape[1]):
+                self.rail_obs[i, j] = np.array(
+                    list(f'{self.env.rail.get_transitions((i, j)):016b}')).astype(int)
+
+    def get(self, handle):
+        """
+        Return the tuple (rail_obs, obs, direction) for agent `handle`, as described in the class docstring.
+
+        Positions and targets are converted to tuples before indexing: should one of them be a list
+        or an array such as [row, col], numpy would select two whole rows instead of a single cell.
+        """
+        obs = np.zeros((4, self.env.height, self.env.width))
+        agents = self.env.agents
+        agent = agents[handle]
+
+        obs[0][tuple(agent.position)] += 1
+        obs[1][tuple(agent.target)] += 1
+        for i, agent2 in enumerate(agents):
+            if i != handle:  # TODO: handle used as index...?
+                obs[3][tuple(agent2.position)] += 1
+                obs[2][tuple(agent2.target)] += 1
+
+        direction = np.zeros(4)
+        direction[agent.direction] = 1
+
+        return self.rail_obs, obs, direction
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 228cc32..a5248a8 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -8,7 +8,7 @@ import numpy as np
 import msgpack
 
 from flatland.core.env import Environment
-from flatland.core.env_observation_builder import TreeObsForRailEnv
+from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.generators import random_rail_generator
 from flatland.envs.env_utils import get_new_position
 from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent
@@ -157,7 +157,7 @@ class RailEnv(Environment):
         #    self.dones[handle] = False
         self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)
         # perhaps dones should be part of each agent.
-        
+
         # Reset the state of the observation builder with the new environment
         self.obs_builder.reset()
 
@@ -215,7 +215,7 @@ class RailEnv(Environment):
                 # 1) transition allows the new_direction in the cell,
                 # 2) the new cell is not empty (case 0),
                 # 3) the cell is free, i.e., no agent is currently in that cell
-                
+
                 # if (
                 #        new_position[1] >= self.width or
                 #        new_position[0] >= self.height or
@@ -226,11 +226,11 @@ class RailEnv(Environment):
                 #     new_cell_isValid = False
 
                 new_cell_isValid = (
-                        np.array_equal(  # Check the new position is still in the grid
-                            new_position,
-                            np.clip(new_position, [0, 0], [self.height-1, self.width-1]))
-                        and  # check the new position has some transitions (ie is not an empty cell)
-                        self.rail.get_transitions(new_position) > 0)
+                    np.array_equal(  # Check the new position is still in the grid
+                        new_position,
+                        np.clip(new_position, [0, 0], [self.height - 1, self.width - 1]))
+                    and  # check the new position has some transitions (ie is not an empty cell)
+                    self.rail.get_transitions(new_position) > 0)
 
                 # If transition validity hasn't been checked yet.
                 if transition_isValid is None:
@@ -246,7 +246,7 @@ class RailEnv(Environment):
                 # Check the new position is not the same as any of the existing agent positions
                 # (including itself, for simplicity, since it is moving)
                 cell_isFree = not np.any(
-                        np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1))
+                    np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1))
 
                 if all([new_cell_isValid, transition_isValid, cell_isFree]):
                     # move and change direction to face the new_direction that was
@@ -278,7 +278,7 @@ class RailEnv(Environment):
         # if num_agents_in_target_position == self.number_of_agents:
         if np.all([np.array_equal(agent2.position, agent2.target) for agent2 in self.agents]):
             self.dones["__all__"] = True
-            self.rewards_dict = [0*r+global_reward for r in self.rewards_dict]
+            self.rewards_dict = [0 * r + global_reward for r in self.rewards_dict]
 
         # Reset the step actions (in case some agent doesn't 'register_action'
         # on the next step)
@@ -331,15 +331,13 @@ class RailEnv(Environment):
         msg_data = {
             "grid": grid_data,
             "agents_static": agent_static_data,
-            "agents": agent_data
-            }
+            "agents": agent_data}
         return msgpack.packb(msg_data, use_bin_type=True)
 
     def get_agent_state_msg(self):
         agent_data = [agent.to_list() for agent in self.agents]
         msg_data = {
-            "agents": agent_data
-            }
+            "agents": agent_data}
         return msgpack.packb(msg_data, use_bin_type=True)
 
     def set_full_state_msg(self, msg_data):
diff --git a/flatland/utils/editor.py b/flatland/utils/editor.py
index 3c7f470..a33a011 100644
--- a/flatland/utils/editor.py
+++ b/flatland/utils/editor.py
@@ -17,7 +17,7 @@ import os
 from flatland.envs.rail_env import RailEnv, random_rail_generator
 from flatland.envs.generators import complex_rail_generator
 # from flatland.core.transitions import RailEnvTransitions
-from flatland.core.env_observation_builder import TreeObsForRailEnv
+from flatland.envs.observations import TreeObsForRailEnv
 import flatland.utils.rendertools as rt
 from examples.play_model import Player
 from flatland.envs.env_utils import mirror
@@ -110,8 +110,7 @@ class View(object):
             self.wTab.set_title(i, title)
         self.wTab.children = [
             VBox([self.wDebug, self.wDebug_move]),
-            VBox([self.wRegenMethod, self.wReplaceAgents])
-            ]
+            VBox([self.wRegenMethod, self.wReplaceAgents])]
 
         # Progress bar intended for stepping in the background (not yet working)
         self.wProg_steps = ipywidgets.IntProgress(value=0, min=0, max=20, step=1, description="Step")
@@ -121,21 +120,20 @@ class View(object):
             dict(name="Refresh", method=self.controller.refresh, tip="Redraw only"),
             dict(name="Clear", method=self.controller.clear, tip="Clear rails and agents"),
             dict(name="Reset", method=self.controller.reset,
-                tip="Standard env reset, including regen rail + agents"),
+                 tip="Standard env reset, including regen rail + agents"),
             dict(name="Restart Agents", method=self.controller.restartAgents,
-                tip="Move agents back to start positions"),
+                 tip="Move agents back to start positions"),
             dict(name="Regenerate", method=self.controller.regenerate,
-                tip="Regenerate the rails using the method selected below"),
+                 tip="Regenerate the rails using the method selected below"),
             dict(name="Load", method=self.controller.load),
             dict(name="Save", method=self.controller.save),
             dict(name="Step", method=self.controller.step),
-            dict(name="Run Steps", method=self.controller.start_run),
-        ]
+            dict(name="Run Steps", method=self.controller.start_run)]
 
         self.lwButtons = []
         for dButton in ldButtons:
             wButton = ipywidgets.Button(description=dButton["name"],
-                tooltip=dButton["tip"] if "tip" in dButton else dButton["name"])
+                                        tooltip=dButton["tip"] if "tip" in dButton else dButton["name"])
             wButton.on_click(dButton["method"])
             self.lwButtons.append(wButton)
 
@@ -145,8 +143,7 @@ class View(object):
             self.wSize,
             self.wNAgents,
             self.wProg_steps,
-            self.wTab
-            ])
+            self.wTab])
 
         self.wMain = HBox([self.wImage, self.wVbox_controls])
 
@@ -164,7 +161,7 @@ class View(object):
         with self.wOutput:
             # plt.figure(figsize=(10, 10))
             self.oRT.renderEnv(spacing=False, arrows=False, sRailColor="gray",
-                show=False, iSelectedAgent=self.model.iSelectedAgent)
+                               show=False, iSelectedAgent=self.model.iSelectedAgent)
             img = self.oRT.getImage()
             # plt.clf()
             # plt.close()
@@ -607,7 +604,7 @@ class EditorModel(object):
 
         # Has the user clicked on an existing agent?
         iAgent = self.find_agent_at(rcCell)
-        
+
         if iAgent is None:
             # No
             if self.iSelectedAgent is None:
@@ -656,7 +653,7 @@ class EditorModel(object):
                 # self.log("step ", i)
                 self.step()
                 time.sleep(0.2)
-                wProg_steps.value = i+1   # indicate progress on bar
+                wProg_steps.value = i + 1   # indicate progress on bar
         finally:
             self.thread = None
 
@@ -683,6 +680,3 @@ class EditorModel(object):
                    binTrans,
                    sbinTrans,
                    [sbinTrans[i:(i + 4)] for i in range(0, len(sbinTrans), 4)])
-
-
-    
\ No newline at end of file
diff --git a/flatland/utils/graphics_layer.py b/flatland/utils/graphics_layer.py
index c0d390c..4cfcc64 100644
--- a/flatland/utils/graphics_layer.py
+++ b/flatland/utils/graphics_layer.py
@@ -54,7 +54,7 @@ class GraphicsLayer(object):
                 color = tuple((gcolor[:3] * 255).astype(int))
         else:
             color = self.tColGrid
-        
+
         if lighten:
             color = tuple([int(255 - (255 - iRGB) / 3) for iRGB in color])
 
@@ -65,6 +65,6 @@ class GraphicsLayer(object):
 
     def setRailAt(self, row, col, binTrans):
         pass
-    
+
     def setAgentAt(self, iAgent, row, col, iDirIn, iDirOut):
         pass
diff --git a/flatland/utils/graphics_pil.py b/flatland/utils/graphics_pil.py
index 01cc5f0..41516fd 100644
--- a/flatland/utils/graphics_pil.py
+++ b/flatland/utils/graphics_pil.py
@@ -39,7 +39,7 @@ class PILGL(GraphicsLayer):
         r = np.sqrt(s)
         gPoints = np.stack([np.atleast_1d(gX), -np.atleast_1d(gY)]).T * self.nPixCell
         for x, y in gPoints:
-            self.draw.rectangle([(x-r, y-r), (x+r, y+r)], fill=color, outline=color)
+            self.draw.rectangle([(x - r, y - r), (x + r, y + r)], fill=color, outline=color)
 
     def text(self, *args, **kwargs):
         pass
diff --git a/flatland/utils/render_qt.py b/flatland/utils/render_qt.py
index af40c05..ea96139 100644
--- a/flatland/utils/render_qt.py
+++ b/flatland/utils/render_qt.py
@@ -111,7 +111,7 @@ class QTSVG(GraphicsLayer):
         self.wWinMain.resize(1000, 1000)
         self.wWinMain.show()
         self.wWinMain.setFocus()
-        
+
         self.track = self.track = Track()
         self.lwTrack = []
         self.zug = Zug()
@@ -152,21 +152,21 @@ class QTSVG(GraphicsLayer):
     def processEvents(self):
         self.app.processEvents()
         time.sleep(0.001)
-    
+
     def clear_rails(self):
         # print("Clear rails: ", len(self.lwTrack))
         for wRail in self.lwTrack:
             self.layout.removeWidget(wRail)
         self.lwTrack = []
         self.clear_agents()
-    
+
     def clear_agents(self):
         # print("Clear Agents: ", len(self.lwAgents))
         for wAgent in self.lwAgents:
             self.layout.removeWidget(wAgent)
         self.lwAgents = []
         self.agents_prev = []
-        
+
     def setRailAt(self, row, col, binTrans):
         if binTrans in self.track.dSvg:
             sSVG = self.track.dSvg[binTrans].to_string()
@@ -204,7 +204,7 @@ class QTSVG(GraphicsLayer):
                     return
 
         # Ensure we have adequate slots in the list lwAgents
-        for i in range(len(self.lwAgents), iAgent+1):
+        for i in range(len(self.lwAgents), iAgent + 1):
             self.lwAgents.append(None)
             self.agents_prev.append(None)
 
@@ -226,7 +226,7 @@ def main2():
     gl = QTGL(10, 10)
     for i in range(10):
         gl.beginFrame()
-        gl.plot([3+i, 4], [-4-i, -5], color="r")
+        gl.plot([3 + i, 4], [-4 - i, -5], color="r")
         gl.endFrame()
         time.sleep(1)
 
diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py
index 4921def..82d8475 100644
--- a/flatland/utils/rendertools.py
+++ b/flatland/utils/rendertools.py
@@ -130,7 +130,7 @@ class RenderTool(object):
             self.gl = PILGL(env.width, env.height)
         elif gl == "QTSVG":
             self.gl = QTSVG(env.width, env.height)
-        
+
         self.new_rail = True
 
     def set_new_rail(self):
@@ -153,14 +153,14 @@ class RenderTool(object):
 
     def plotAgents(self, targets=True, iSelectedAgent=None):
         cmap = self.gl.get_cmap('hsv',
-            lut=max(len(self.env.agents), len(self.env.agents_static) + 1))
+                                lut=max(len(self.env.agents), len(self.env.agents_static) + 1))
 
         for iAgent, agent in enumerate(self.env.agents_static):
             if agent is None:
                 continue
             oColor = cmap(iAgent)
             self.plotAgent(agent.position, agent.direction, oColor, target=agent.target if targets else None,
-                static=True, selected=iAgent == iSelectedAgent)
+                           static=True, selected=iAgent == iSelectedAgent)
 
         for iAgent, agent in enumerate(self.env.agents):
             if agent is None:
@@ -488,9 +488,9 @@ class RenderTool(object):
 
         if not self.gl.is_raster():
             self.renderEnv2(show, curves, spacing,
-            arrows, agents, sRailColor,
-            frames, iEpisode, iStep,
-            iSelectedAgent, action_dict)
+                            arrows, agents, sRailColor,
+                            frames, iEpisode, iStep,
+                            iSelectedAgent, action_dict)
             return
 
         # cell_size is a bit pointless with matplotlib - it does not relate to pixels,
@@ -667,16 +667,16 @@ class RenderTool(object):
         for i in range(nDepth):
             nDepthNodes = nBranchFactor**i
             # rScale = nBranchFactor ** (nDepth - i)
-            rShrinkDepth = 1/(i+1)
+            rShrinkDepth = 1 / (i + 1)
             # gX1 = np.linspace(-nDepthNodes / 2, nDepthNodes / 2, nDepthNodes) * rShrinkDepth
-            
-            gX1 = np.linspace(-(nDepthNodes-1), (nDepthNodes-1), nDepthNodes) * rShrinkDepth
+
+            gX1 = np.linspace(-(nDepthNodes - 1), (nDepthNodes - 1), nDepthNodes) * rShrinkDepth
             gY1 = np.ones((nDepthNodes)) * i
             gZ1 = np.zeros((nDepthNodes))
-            
+
             gP1 = array([gX1, gY1, gZ1])
             gP01 = np.append(gP0, gP1, axis=1)
-            
+
             if nDepthNodes > 1:
                 nDepthNodesPrev = nDepthNodes / nBranchFactor
                 giP0 = np.repeat(np.arange(nDepthNodesPrev), nBranchFactor)
@@ -687,7 +687,7 @@ class RenderTool(object):
                 self.gl.plot(gP01[0], -gP01[1], lines=giLinePoints, color="gray")
 
             gP0 = array([gX1, gY1, gZ1])
-    
+
     def renderEnv2(
             self, show=False, curves=True, spacing=False,
             arrows=False, agents=True, sRailColor="gray",
@@ -715,7 +715,7 @@ class RenderTool(object):
                     self.gl.setRailAt(r, c, binTrans)
 
         cmap = self.gl.get_cmap('hsv',
-            lut=max(len(self.env.agents), len(self.env.agents_static) + 1))
+                                lut=max(len(self.env.agents), len(self.env.agents_static) + 1))
 
         for iAgent, agent in enumerate(self.env.agents):
             if agent is None:
@@ -729,14 +729,14 @@ class RenderTool(object):
             if iAgent in action_dict:
                 iAction = action_dict[iAgent]
                 new_direction, action_isValid = self.env.check_action(agent, iAction)
-            
+
             if action_isValid:
                 self.gl.setAgentAt(iAgent, *agent.position, agent.direction, new_direction, color=oColor)
             else:
                 pass
                 # print("invalid action - agent ", iAgent, " bend ", agent.direction, new_direction)
                 # self.gl.setAgentAt(iAgent, *agent.position, agent.direction, new_direction)
-                
+
         self.gl.show()
         for i in range(3):
             self.gl.processEvents()
diff --git a/flatland/utils/svg.py b/flatland/utils/svg.py
index b219560..32d5631 100644
--- a/flatland/utils/svg.py
+++ b/flatland/utils/svg.py
@@ -30,10 +30,10 @@ class SVG(object):
 
     def merge(self, svg2):
         svg3 = svg2.copy()
-        
+
         svg3.renumber_styles(offset=10)
         svg3.eStyle.text = self.eStyle.text + "\n" + svg3.eStyle.text
-        
+
         for child in self.svg.root:
             if not child.tag.endswith("style"):
                 svg3.svg.root.append(child)
@@ -48,12 +48,12 @@ class SVG(object):
             lEl = self.svg.root.xpath("//*[@class='{}']".format(sClass))
             for el in lEl:
                 el.attrib["class"] = "st{}".format(iStyle + offset)
-        
-            sStyle2 = str(iStyle+offset)
+
+            sStyle2 = str(iStyle + offset)
 
             sNewStyle = "\t.st" + sStyle2 + "{" + self.dStyles[sStyle] + "}\n"
             sNewStyles += sNewStyle
-        
+
         self.eStyle.text = sNewStyles
 
     def set_style_color(self, style_name, color):
@@ -63,7 +63,7 @@ class SVG(object):
                 sValue = "fill:#" + "".join([format(col, "#04x")[2:] for col in color]) + ";"
             sNewStyle = "\t.st" + sKey + "{" + sValue + "}\n"
             sNewStyles += sNewStyle
-        
+
         self.eStyle.text = sNewStyles
 
     def set_rotate(self, angle):
@@ -87,7 +87,7 @@ class Zug(object):
         if delta_dir in (0, 2):
             svg = self.svg_straight.copy()
             svg.set_rotate(iDirIn * 90)
-        
+
         if delta_dir == 1:  # bend to right, eg N->E, E->S
             svg = self.svg_curve1.copy()
             svg.set_rotate((iDirIn - 1) * 90)
@@ -124,8 +124,7 @@ class Track(object):
             "NN SS EN SW": "Weiche_vertikal_oben_links.svg",
             "NN SS SE WN": "Weiche_vertikal_oben_rechts.svg",
             "NN SS NW ES": "Weiche_vertikal_unten_links.svg",
-            "NN SS NE WS": "Weiche_vertikal_unten_rechts.svg",
-            }
+            "NN SS NE WS": "Weiche_vertikal_unten_rechts.svg"}
 
         self.dSvg = {}
 
@@ -136,8 +135,7 @@ class Track(object):
         svgBG = SVG("./svg/Background_#91D1DD.svg")
 
         for sTrans, sFile in dFiles.items():
-            
-            svg = SVG("./svg/"+sFile)
+            svg = SVG("./svg/" + sFile)
 
             lTrans16 = ["0"] * 16
             for sTran in sTrans.split(" "):
@@ -165,7 +163,7 @@ class Track(object):
 def main():
     # svg1 = SVG("./svg/Gleis_vertikal.svg")
     # svg2 = SVG("./svg/Zug_1_Weiche_#0091ea.svg")
-    
+
     # svg3 = svg2.merge(svg1)
     # svg3.set_rotate(90)
 
@@ -188,4 +186,3 @@ def main():
 
 if __name__ == "__main__":
     main()
-
diff --git a/tests/test_environments.py b/tests/test_environments.py
index dae7c13..4c55eac 100644
--- a/tests/test_environments.py
+++ b/tests/test_environments.py
@@ -230,4 +230,4 @@ def test_dead_end():
 
 if __name__ == "__main__":
     test_rail_environment_single_agent()
-    test_dead_end()
\ No newline at end of file
+    test_dead_end()
diff --git a/tests/test_rendertools.py b/tests/test_rendertools.py
index ea4ad3c..3259ed3 100644
--- a/tests/test_rendertools.py
+++ b/tests/test_rendertools.py
@@ -6,13 +6,12 @@ Tests for `flatland` package.
 
 from flatland.envs.rail_env import RailEnv, random_rail_generator
 import numpy as np
-import os
 import sys
 
 import matplotlib.pyplot as plt
 
 import flatland.utils.rendertools as rt
-from flatland.core.env_observation_builder import TreeObsForRailEnv
+from flatland.envs.observations import TreeObsForRailEnv
 
 
 def checkFrozenImage(oRT, sFileImage, resave=False):
@@ -49,7 +48,7 @@ def test_render_env(save_new_images=False):
     oEnv.rail.load_transition_map(sfTestEnv)
     oRT = rt.RenderTool(oEnv)
     oRT.renderEnv()
-    
+
     checkFrozenImage(oRT, "basic-env.npz", resave=save_new_images)
 
     oRT = rt.RenderTool(oEnv, gl="PIL")
@@ -86,4 +85,4 @@ def main():
 
 
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
-- 
GitLab