diff --git a/Makefile b/Makefile
index 43d86866d18f6270b5a88be7492b7ab24885b193..e23fd6b899d9396945aabec2fdbf6ad6269e9f51 100644
--- a/Makefile
+++ b/Makefile
@@ -71,7 +71,8 @@ docs: ## generate Sphinx HTML documentation, including API docs
 	sphinx-apidoc -o docs/ flatland
 	$(MAKE) -C docs clean
 	$(MAKE) -C docs html
-	export HOME=$(pwd)
+	# N.B. HOME variable required by pydeps!
+	export HOME=${PWD}
 	python3 -m pydeps flatland -o docs/_build/html/flatland.svg
 	$(BROWSER) docs/_build/html/index.html
 
diff --git a/examples/play_model.py b/examples/play_model.py
index f80d9a14c54c1fa411a8736e4891ba5b95c188d8..62726c24c96be0e5dae2f4840e18da452163b7ac 100644
--- a/examples/play_model.py
+++ b/examples/play_model.py
@@ -1,6 +1,5 @@
 from flatland.envs.rail_env import RailEnv
 from flatland.envs.generators import complex_rail_generator
-# from flatland.core.env_observation_builder import TreeObsForRailEnv
 from flatland.utils.rendertools import RenderTool
 from flatland.baselines.dueling_double_dqn import Agent
 from collections import deque
diff --git a/examples/temporary_example.py b/examples/temporary_example.py
index 0ed2f6207683b0983a6c8a9783c6677834437bd1..1f3504f221d59d0205974bb135c2237364a22e07 100644
--- a/examples/temporary_example.py
+++ b/examples/temporary_example.py
@@ -4,7 +4,7 @@ import matplotlib.pyplot as plt
 
 from flatland.envs.rail_env import *
 from flatland.envs.generators import *
-from flatland.core.env_observation_builder import TreeObsForRailEnv
+from flatland.envs.observations import TreeObsForRailEnv
 from flatland.utils.rendertools import *
 
 random.seed(0)
diff --git a/examples/training_navigation.py b/examples/training_navigation.py
index 9d45cd175c7d324804a2a34e278428aebad69e28..23970a9059f15ce89b59b1cbcb9b862d248d1cd5 100644
--- a/examples/training_navigation.py
+++ b/examples/training_navigation.py
@@ -1,6 +1,6 @@
 from flatland.envs.rail_env import *
 from flatland.envs.generators import *
-from flatland.core.env_observation_builder import TreeObsForRailEnv
+from flatland.envs.observations import TreeObsForRailEnv
 from flatland.utils.rendertools import *
 from flatland.baselines.dueling_double_dqn import Agent
 from collections import deque
@@ -32,8 +32,8 @@ env = RailEnv(width=10,
 """
 env = RailEnv(width=15,
               height=15,
-              rail_generator=complex_rail_generator(nr_start_goal=5, min_dist=5, max_dist=99999, seed=0),
-              number_of_agents=5)
+              rail_generator=complex_rail_generator(nr_start_goal=10, min_dist=5, max_dist=99999, seed=0),
+              number_of_agents=3)
 """
 env = RailEnv(width=20,
               height=20,
@@ -50,7 +50,7 @@ action_size = 4
 n_trials = 15000
 eps = 1.
 eps_end = 0.005
-eps_decay = 0.998
+eps_decay = 0.9995
 action_dict = dict()
 final_action_dict = dict()
 scores_window = deque(maxlen=100)
@@ -62,9 +62,9 @@ action_prob = [0] * 4
 agent_obs = [None] * env.get_num_agents()
 agent_next_obs = [None] * env.get_num_agents()
 agent = Agent(state_size, action_size, "FC", 0)
-# agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint15000.pth'))
+agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint15000.pth'))
 
-demo = False
+demo = True
 
 
 def max_lt(seq, val):
diff --git a/flatland/baselines/Nets/avoid_checkpoint15000.pth b/flatland/baselines/Nets/avoid_checkpoint15000.pth
index adcfe61576553bbf0e2b4ba00d9fffafbfd9d7da..14882a37a86085b137f4422b6bba75f387a2d3b5 100644
Binary files a/flatland/baselines/Nets/avoid_checkpoint15000.pth and b/flatland/baselines/Nets/avoid_checkpoint15000.pth differ
diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py
index 8e7f2ae57f5e27b586365e870e2acbc8666b912e..09a624e872200e29ede834d272f7de506d6de076 100644
--- a/flatland/core/env_observation_builder.py
+++ b/flatland/core/env_observation_builder.py
@@ -8,10 +8,6 @@ The ObservationBuilder-derived custom classes implement 2 functions, reset() and
 case of multi-agent environments.
 """
 
-import numpy as np
-
-from collections import deque
-
 
 class ObservationBuilder:
     """
@@ -46,511 +42,3 @@ class ObservationBuilder:
             An observation structure, specific to the corresponding environment.
         """
         raise NotImplementedError()
-
-
-class TreeObsForRailEnv(ObservationBuilder):
-    """
-    TreeObsForRailEnv object.
-
-    This object returns observation vectors for agents in the RailEnv environment.
-    The information is local to each agent and exploits the tree structure of the rail
-    network to simplify the representation of the state of the environment for each agent.
-    """
-
-    def __init__(self, max_depth):
-        self.max_depth = max_depth
-
-    def reset(self):
-        agents = self.env.agents
-        nAgents = len(agents)
-        self.distance_map = np.inf * np.ones(shape=(nAgents,  # self.env.number_of_agents,
-                                                    self.env.height,
-                                                    self.env.width,
-                                                    4))
-        self.max_dist = np.zeros(nAgents)
-
-        # for i in range(nAgents):
-        #     self.max_dist[i] = self._distance_map_walker(self.env.agents_target[i], i)
-        self.max_dist = [self._distance_map_walker(agent.target, i) for i, agent in enumerate(agents)]
-
-        # Update local lookup table for all agents' target locations
-        self.location_has_target = {}
-        # for loc in self.env.agents_target:
-        #    self.location_has_target[(loc[0], loc[1])] = 1
-        self.location_has_target = {tuple(agent.target): 1 for agent in agents}
-
-    def _distance_map_walker(self, position, target_nr):
-        """
-        Utility function to compute distance maps from each cell in the rail network (and each possible
-        orientation within it) to each agent's target cell.
-        """
-        # Returns max distance to target, from the farthest away node, while filling in distance_map
-
-        self.distance_map[target_nr, position[0], position[1], :] = 0
-
-        # Fill in the (up to) 4 neighboring nodes
-        # nodes_queue = []  # list of tuples (row, col, direction, distance);
-        # direction is the direction of movement, meaning that at least a possible orientation of an agent
-        # in cell (row,col) allows a movement in direction `direction'
-        nodes_queue = deque(self._get_and_update_neighbors(position, target_nr, 0, enforce_target_direction=-1))
-
-        # BFS from target `position' to all the reachable nodes in the grid
-        # Stop the search if the target position is re-visited, in any direction
-        visited = set([(position[0], position[1], 0),
-                       (position[0], position[1], 1),
-                       (position[0], position[1], 2),
-                       (position[0], position[1], 3)])
-
-        max_distance = 0
-
-        while nodes_queue:
-            node = nodes_queue.popleft()
-
-            node_id = (node[0], node[1], node[2])
-
-            if node_id not in visited:
-                visited.add(node_id)
-
-                # From the list of possible neighbors that have at least a path to the current node, only keep those
-                # whose new orientation in the current cell would allow a transition to direction node[2]
-                valid_neighbors = self._get_and_update_neighbors((node[0], node[1]), target_nr, node[3], node[2])
-
-                for n in valid_neighbors:
-                    nodes_queue.append(n)
-
-                if len(valid_neighbors) > 0:
-                    max_distance = max(max_distance, node[3] + 1)
-
-        return max_distance
-
-    def _get_and_update_neighbors(self, position, target_nr, current_distance, enforce_target_direction=-1):
-        """
-        Utility function used by _distance_map_walker to perform a BFS walk over the rail, filling in the
-        minimum distances from each target cell.
-        """
-        neighbors = []
-
-        possible_directions = [0, 1, 2, 3]
-        if enforce_target_direction >= 0:
-            # The agent must land into the current cell with orientation `enforce_target_direction'.
-            # This is only possible if the agent has arrived from the cell in the opposite direction!
-            possible_directions = [(enforce_target_direction + 2) % 4]
-
-        for neigh_direction in possible_directions:
-            new_cell = self._new_position(position, neigh_direction)
-
-            if new_cell[0] >= 0 and new_cell[0] < self.env.height and new_cell[1] >= 0 and new_cell[1] < self.env.width:
-
-                desired_movement_from_new_cell = (neigh_direction + 2) % 4
-
-                """
-                # Is the next cell a dead-end?
-                isNextCellDeadEnd = False
-                nbits = 0
-                tmp = self.env.rail.get_transitions((new_cell[0], new_cell[1]))
-                while tmp > 0:
-                    nbits += (tmp & 1)
-                    tmp = tmp >> 1
-                if nbits == 1:
-                    # Dead-end!
-                    isNextCellDeadEnd = True
-                """
-
-                # Check all possible transitions in new_cell
-                for agent_orientation in range(4):
-                    # Is a transition along movement `desired_movement_from_new_cell' to the current cell possible?
-                    isValid = self.env.rail.get_transition((new_cell[0], new_cell[1], agent_orientation),
-                                                           desired_movement_from_new_cell)
-
-                    if isValid:
-                        """
-                        # TODO: check that it works with deadends! -- still bugged!
-                        movement = desired_movement_from_new_cell
-                        if isNextCellDeadEnd:
-                            movement = (desired_movement_from_new_cell+2) % 4
-                        """
-                        new_distance = min(self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation],
-                                           current_distance + 1)
-                        neighbors.append((new_cell[0], new_cell[1], agent_orientation, new_distance))
-                        self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation] = new_distance
-
-        return neighbors
-
-    def _new_position(self, position, movement):
-        """
-        Utility function that converts a compass movement over a 2D grid to new positions (r, c).
-        """
-        if movement == 0:  # NORTH
-            return (position[0] - 1, position[1])
-        elif movement == 1:  # EAST
-            return (position[0], position[1] + 1)
-        elif movement == 2:  # SOUTH
-            return (position[0] + 1, position[1])
-        elif movement == 3:  # WEST
-            return (position[0], position[1] - 1)
-
-    def get(self, handle):
-        """
-        Computes the current observation for agent `handle' in env
-
-        The observation vector is composed of 4 sequential parts, corresponding to data from the up to 4 possible
-        movements in a RailEnv (up to because only a subset of possible transitions are allowed in RailEnv).
-        The possible movements are sorted relative to the current orientation of the agent, rather than NESW as for
-        the transitions. The order is:
-            [data from 'left'] + [data from 'forward'] + [data from 'right'] + [data from 'back']
-
-
-
-
-
-        Each branch data is organized as:
-            [root node information] +
-            [recursive branch data from 'left'] +
-            [... from 'forward'] +
-            [... from 'right] +
-            [... from 'back']
-
-        Finally, each node information is composed of 5 floating point values:
-
-        #1:
-
-        #2: 1 if a target of another agent is detected between the previous node and the current one.
-
-        #3: 1 if another agent is detected between the previous node and the current one.
-
-        #4: distance of agent to the current branch node
-
-        #5: minimum distance from node to the agent's target (when landing to the node following the corresponding
-            branch.
-
-        Missing/padding nodes are filled in with -inf (truncated).
-        Missing values in present node are filled in with +inf (truncated).
-
-
-        In case of the root node, the values are [0, 0, 0, 0, distance from agent to target].
-        In case the target node is reached, the values are [0, 0, 0, 0, 0].
-        """
-
-        # Update local lookup table for all agents' positions
-        # self.location_has_agent = {}
-        # for loc in self.env.agents_position:
-        #    self.location_has_agent[(loc[0], loc[1])] = 1
-        self.location_has_agent = {tuple(agent.position): 1 for agent in self.env.agents}
-
-        agent = self.env.agents[handle]  # TODO: handle being treated as index
-        # position = self.env.agents_position[handle]
-        # orientation = self.env.agents_direction[handle]
-        possible_transitions = self.env.rail.get_transitions((*agent.position, agent.direction))
-        num_transitions = np.count_nonzero(possible_transitions)
-        # Root node - current position
-        # observation = [0, 0, 0, 0, self.distance_map[handle, position[0], position[1], orientation]]
-        observation = [0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)]]
-        root_observation = observation[:]
-
-        # Start from the current orientation, and see which transitions are available;
-        # organize them as [left, forward, right, back], relative to the current orientation
-        # If only one transition is possible, the tree is oriented with this transition as the forward branch.
-        # TODO: Test if this works as desired!
-        orientation = agent.direction
-        if num_transitions == 1:
-            orientation == np.argmax(possible_transitions)
-
-        # for branch_direction in [(orientation + 4 + i) % 4 for i in range(-1, 3)]:
-        for branch_direction in [(orientation + i) % 4 for i in range(-1, 3)]:
-            if possible_transitions[branch_direction]:
-                new_cell = self._new_position(agent.position, branch_direction)
-
-                branch_observation = self._explore_branch(handle, new_cell, branch_direction, root_observation, 1)
-                observation = observation + branch_observation
-            else:
-                num_cells_to_fill_in = 0
-                pow4 = 1
-                for i in range(self.max_depth):
-                    num_cells_to_fill_in += pow4
-                    pow4 *= 4
-                observation = observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf] * num_cells_to_fill_in
-        return observation
-
-    def _explore_branch(self, handle, position, direction, root_observation, depth):
-        """
-        Utility function to compute tree-based observations.
-        """
-        # [Recursive branch opened]
-        if depth >= self.max_depth + 1:
-            return []
-
-        # Continue along direction until next switch or
-        # until no transitions are possible along the current direction (i.e., dead-ends)
-        # We treat dead-ends as nodes, instead of going back, to avoid loops
-        exploring = True
-        last_isSwitch = False
-        last_isDeadEnd = False
-        last_isTerminal = False  # wrong cell OR cycle;  either way, we don't want the agent to land here
-        last_isTarget = False
-
-        visited = set()
-
-        # other_agent_encountered = False
-        # other_target_encountered = False
-        other_agent_encountered = np.inf
-        other_target_encountered = np.inf
-
-        num_steps = 1
-        while exploring:
-            # #############################
-            # #############################
-            # Modify here to compute any useful data required to build the end node's features. This code is called
-            # for each cell visited between the previous branching node and the next switch / target / dead-end.
-            if position in self.location_has_agent:
-                # other_agent_encountered = True
-                if num_steps < other_agent_encountered:
-                    other_agent_encountered = num_steps
-
-            if position in self.location_has_target:
-                # other_target_encountered = True
-                if num_steps < other_target_encountered:
-                    other_target_encountered = num_steps
-            # #############################
-            # #############################
-
-            if (position[0], position[1], direction) in visited:
-                last_isTerminal = True
-                break
-            visited.add((position[0], position[1], direction))
-
-            # If the target node is encountered, pick that as node. Also, no further branching is possible.
-            # if position[0] == self.env.agents_target[handle][0] and position[1] == self.env.agents_target[handle][1]:
-            if np.array_equal(position, self.env.agents[handle].target):
-                last_isTarget = True
-                break
-
-            cell_transitions = self.env.rail.get_transitions((*position, direction))
-            num_transitions = np.count_nonzero(cell_transitions)
-            exploring = False
-            if num_transitions == 1:
-                # Check if dead-end, or if we can go forward along direction
-                nbits = 0
-                tmp = self.env.rail.get_transitions(tuple(position))
-                while tmp > 0:
-                    nbits += (tmp & 1)
-                    tmp = tmp >> 1
-                if nbits == 1:
-                    # Dead-end!
-                    last_isDeadEnd = True
-
-                if not last_isDeadEnd:
-                    # Keep walking through the tree along `direction'
-                    exploring = True
-                    direction = np.argmax(cell_transitions)
-                    position = self._new_position(position, direction)
-                    num_steps += 1
-
-            elif num_transitions > 0:
-                # Switch detected
-                last_isSwitch = True
-                break
-
-            elif num_transitions == 0:
-                # Wrong cell type, but let's cover it and treat it as a dead-end, just in case
-                print("WRONG CELL TYPE detected in tree-search (0 transitions possible) at cell", position[0],
-                      position[1], direction)
-                last_isTerminal = True
-                break
-
-        # `position' is either a terminal node or a switch
-
-        observation = []
-
-        # #############################
-        # #############################
-        # Modify here to append new / different features for each visited cell!
-        """
-        if last_isTarget:
-            observation = [0,
-                           1 if other_target_encountered else 0,
-                           1 if other_agent_encountered else 0,
-                           root_observation[3] + num_steps,
-                           0]
-
-        elif last_isTerminal:
-            observation = [0,
-                           1 if other_target_encountered else 0,
-                           1 if other_agent_encountered else 0,
-                           np.inf,
-                           np.inf]
-        else:
-            observation = [0,
-                           1 if other_target_encountered else 0,
-                           1 if other_agent_encountered else 0,
-                           root_observation[3] + num_steps,
-                           self.distance_map[handle, position[0], position[1], direction]]
-        """
-        if last_isTarget:
-            observation = [0,
-                           other_target_encountered,
-                           other_agent_encountered,
-                           root_observation[3] + num_steps,
-                           0]
-
-        elif last_isTerminal:
-            observation = [0,
-                           other_target_encountered,
-                           other_agent_encountered,
-                           np.inf,
-                           np.inf]
-        else:
-            observation = [0,
-                           other_target_encountered,
-                           other_agent_encountered,
-                           root_observation[3] + num_steps,
-                           self.distance_map[handle, position[0], position[1], direction]]
-        # #############################
-        # #############################
-
-        new_root_observation = observation[:]
-        # Start from the current orientation, and see which transitions are available;
-        # organize them as [left, forward, right, back], relative to the current orientation
-        # Get the possible transitions
-        possible_transitions = self.env.rail.get_transitions((*position, direction))
-        for branch_direction in [(direction + 4 + i) % 4 for i in range(-1, 3)]:
-            if last_isDeadEnd and self.env.rail.get_transition((*position, direction),
-                                                               (branch_direction + 2) % 4):
-                # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes
-                # it back
-                new_cell = self._new_position(position, (branch_direction + 2) % 4)
-                branch_observation = self._explore_branch(handle,
-                                                          new_cell,
-                                                          (branch_direction + 2) % 4,
-                                                          new_root_observation,
-                                                          depth + 1)
-                observation = observation + branch_observation
-
-            elif last_isSwitch and possible_transitions[branch_direction]:
-                new_cell = self._new_position(position, branch_direction)
-                branch_observation = self._explore_branch(handle,
-                                                          new_cell,
-                                                          branch_direction,
-                                                          new_root_observation,
-                                                          depth + 1)
-                observation = observation + branch_observation
-
-            else:
-                num_cells_to_fill_in = 0
-                pow4 = 1
-                for i in range(self.max_depth - depth):
-                    num_cells_to_fill_in += pow4
-                    pow4 *= 4
-                observation = observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf] * num_cells_to_fill_in
-
-        return observation
-
-    def util_print_obs_subtree(self, tree, num_features_per_node=5, prompt='', current_depth=0):
-        """
-        Utility function to pretty-print tree observations returned by this object.
-        """
-        if len(tree) < num_features_per_node:
-            return
-
-        depth = 0
-        tmp = len(tree) / num_features_per_node - 1
-        pow4 = 4
-        while tmp > 0:
-            tmp -= pow4
-            depth += 1
-            pow4 *= 4
-
-        prompt_ = ['L:', 'F:', 'R:', 'B:']
-
-        print("  " * current_depth + prompt, tree[0:num_features_per_node])
-        child_size = (len(tree) - num_features_per_node) // 4
-        for children in range(4):
-            child_tree = tree[(num_features_per_node + children * child_size):
-                              (num_features_per_node + (children + 1) * child_size)]
-            self.util_print_obs_subtree(child_tree,
-                                        num_features_per_node,
-                                        prompt=prompt_[children],
-                                        current_depth=current_depth + 1)
-
-    def split_tree(self, tree, num_features_per_node=5, current_depth=0):
-        """
-
-        :param tree:
-        :param num_features_per_node:
-        :param prompt:
-        :param current_depth:
-        :return:
-        """
-
-        if len(tree) < num_features_per_node:
-            return [], []
-
-        depth = 0
-        tmp = len(tree) / num_features_per_node - 1
-        pow4 = 4
-        while tmp > 0:
-            tmp -= pow4
-            depth += 1
-            pow4 *= 4
-        child_size = (len(tree) - num_features_per_node) // 4
-        tree_data = tree[0:num_features_per_node - 1].tolist()
-        distance_data = [tree[num_features_per_node - 1]]
-        for children in range(4):
-            child_tree = tree[(num_features_per_node + children * child_size):
-                              (num_features_per_node + (children + 1) * child_size)]
-            tmp_tree_data, tmp_distance_data = self.split_tree(child_tree,
-                                                               num_features_per_node,
-                                                               current_depth=current_depth + 1)
-            if len(tmp_tree_data) > 0:
-                tree_data.extend(tmp_tree_data)
-                distance_data.extend(tmp_distance_data)
-        return tree_data, distance_data
-
-
-class GlobalObsForRailEnv(ObservationBuilder):
-    """
-    Gives a global observation of the entire rail environment.
-    The observation is composed of the following elements:
-
-        - transition map array with dimensions (env.height, env.width, 16),
-          assuming 16 bits encoding of transitions.
-
-        - Four 2D arrays containing respectively the position of the given agent,
-          the position of its target, the positions of the other agents and of
-          their target.
-
-        - A 4 elements array with one of encoding of the direction.
-    """
-
-    def __init__(self):
-        super(GlobalObsForRailEnv, self).__init__()
-
-    def reset(self):
-        self.rail_obs = np.zeros((self.env.height, self.env.width, 16))
-        for i in range(self.rail_obs.shape[0]):
-            for j in range(self.rail_obs.shape[1]):
-                self.rail_obs[i, j] = np.array(
-                    list(f'{self.env.rail.get_transitions((i, j)):016b}')).astype(int)
-
-        # self.targets = np.zeros(self.env.height, self.env.width)
-        # for target_pos in self.env.agents_target:
-        #     self.targets[target_pos] += 1
-
-    def get(self, handle):
-        obs = np.zeros((4, self.env.height, self.env.width))
-        agents = self.env.agents
-        agent = agents[handle]
-
-        agent_pos = agents[handle].position
-        obs[0][agent_pos] += 1
-        obs[1][agent.target] += 1
-
-        for i in range(len(agents)):
-            if i != handle:  # TODO: handle used as index...?
-                agent2 = agents[i]
-                obs[3][agent2.position] += 1
-                obs[2][agent2.target] += 1
-
-        direction = np.zeros(4)
-        direction[agent.direction] = 1
-
-        return self.rail_obs, obs, direction
diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py
index 49209460dd191f1628aabac49d3ce5b72df7e817..1f1bc1db065057a5ced1c917c5a24e2decebf2fa 100644
--- a/flatland/envs/agent_utils.py
+++ b/flatland/envs/agent_utils.py
@@ -23,14 +23,21 @@ class EnvAgentStatic(object):
     position = attrib()
     direction = attrib()
     target = attrib()
-    old_direction = attrib(default=None)
+
+    def __init__(self, position, direction, target):
+        self.position = position
+        self.direction = direction
+        self.target = target
 
     @classmethod
     def from_lists(cls, positions, directions, targets):
         """ Create a list of EnvAgentStatics from lists of positions, directions and targets
         """
         return list(starmap(EnvAgentStatic, zip(positions, directions, targets)))
-        
+
+    def to_list(self):
+        return [self.position, self.direction, self.target]
+
 
 @attrs
 class EnvAgent(EnvAgentStatic):
@@ -41,6 +48,15 @@ class EnvAgent(EnvAgentStatic):
         forcing the env to refer to it in the EnvAgentStatic
     """
     handle = attrib(default=None)
+    old_direction = attrib(default=None)
+
+    def __init__(self, position, direction, target, handle, old_direction):
+        super(EnvAgent, self).__init__(position, direction, target)
+        self.handle = handle
+        self.old_direction = old_direction
+
+    def to_list(self):
+        return [self.position, self.direction, self.target, self.handle, self.old_direction]
 
     @classmethod
     def from_static(cls, oStatic):
@@ -56,7 +72,6 @@ class EnvAgent(EnvAgentStatic):
         """
         if handles is None:
             handles = range(len(lEnvAgentStatic))
-            
+
         return [EnvAgent(**oEAS.__dict__, handle=handle)
                 for handle, oEAS in zip(handles, lEnvAgentStatic)]
-
diff --git a/flatland/envs/env_utils.py b/flatland/envs/env_utils.py
index a1e46db6cf05f3093c403ecb772bf8bf4bca5469..b58604c6d7ededa28a33d30e87e13777a3cd54ec 100644
--- a/flatland/envs/env_utils.py
+++ b/flatland/envs/env_utils.py
@@ -8,7 +8,7 @@ a GridTransitionMap object.
 import numpy as np
 
 # from flatland.core.env import Environment
-# from flatland.core.env_observation_builder import TreeObsForRailEnv
+# from flatland.envs.observations import TreeObsForRailEnv
 
 # from flatland.core.transitions import Grid8Transitions, RailEnvTransitions
 # from flatland.core.transition_map import GridTransitionMap
diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py
index e9e2d3da95e6bf56948b6610fa25544efd3fe90c..04e9a8fefac4ada79527f00200dc6fbfa3d7b924 100644
--- a/flatland/envs/generators.py
+++ b/flatland/envs/generators.py
@@ -1,7 +1,7 @@
 import numpy as np
 
 # from flatland.core.env import Environment
-# from flatland.core.env_observation_builder import TreeObsForRailEnv
+# from flatland.envs.observations import TreeObsForRailEnv
 
 from flatland.core.transitions import Grid8Transitions, RailEnvTransitions
 from flatland.core.transition_map import GridTransitionMap
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e4be0ba49d8405aa420d2bc8e4854f6300e3837
--- /dev/null
+++ b/flatland/envs/observations.py
@@ -0,0 +1,515 @@
+"""
+Collection of environment-specific ObservationBuilder.
+"""
+import numpy as np
+from collections import deque
+
+from flatland.core.env_observation_builder import ObservationBuilder
+
+
+class TreeObsForRailEnv(ObservationBuilder):
+    """
+    TreeObsForRailEnv object.
+
+    This object returns observation vectors for agents in the RailEnv environment.
+    The information is local to each agent and exploits the tree structure of the rail
+    network to simplify the representation of the state of the environment for each agent.
+    """
+
+    def __init__(self, max_depth):
+        self.max_depth = max_depth
+
+    def reset(self):
+        agents = self.env.agents
+        nAgents = len(agents)
+        self.distance_map = np.inf * np.ones(shape=(nAgents,  # self.env.number_of_agents,
+                                                    self.env.height,
+                                                    self.env.width,
+                                                    4))
+        self.max_dist = np.zeros(nAgents)
+
+        # for i in range(nAgents):
+        #     self.max_dist[i] = self._distance_map_walker(self.env.agents_target[i], i)
+        self.max_dist = [self._distance_map_walker(agent.target, i) for i, agent in enumerate(agents)]
+
+        # Update local lookup table for all agents' target locations
+        self.location_has_target = {}
+        # for loc in self.env.agents_target:
+        #    self.location_has_target[(loc[0], loc[1])] = 1
+        self.location_has_target = {tuple(agent.target): 1 for agent in agents}
+
+    def _distance_map_walker(self, position, target_nr):
+        """
+        Utility function to compute distance maps from each cell in the rail network (and each possible
+        orientation within it) to each agent's target cell.
+        """
+        # Returns max distance to target, from the farthest away node, while filling in distance_map
+
+        self.distance_map[target_nr, position[0], position[1], :] = 0
+
+        # Fill in the (up to) 4 neighboring nodes
+        # nodes_queue = []  # list of tuples (row, col, direction, distance);
+        # direction is the direction of movement, meaning that at least a possible orientation of an agent
+        # in cell (row,col) allows a movement in direction `direction'
+        nodes_queue = deque(self._get_and_update_neighbors(position, target_nr, 0, enforce_target_direction=-1))
+
+        # BFS from target `position' to all the reachable nodes in the grid
+        # Stop the search if the target position is re-visited, in any direction
+        visited = set([(position[0], position[1], 0),
+                       (position[0], position[1], 1),
+                       (position[0], position[1], 2),
+                       (position[0], position[1], 3)])
+
+        max_distance = 0
+
+        while nodes_queue:
+            node = nodes_queue.popleft()
+
+            node_id = (node[0], node[1], node[2])
+
+            if node_id not in visited:
+                visited.add(node_id)
+
+                # From the list of possible neighbors that have at least a path to the current node, only keep those
+                # whose new orientation in the current cell would allow a transition to direction node[2]
+                valid_neighbors = self._get_and_update_neighbors((node[0], node[1]), target_nr, node[3], node[2])
+
+                for n in valid_neighbors:
+                    nodes_queue.append(n)
+
+                if len(valid_neighbors) > 0:
+                    max_distance = max(max_distance, node[3] + 1)
+
+        return max_distance
+
+    def _get_and_update_neighbors(self, position, target_nr, current_distance, enforce_target_direction=-1):
+        """
+        Utility function used by _distance_map_walker to perform a BFS walk over the rail, filling in the
+        minimum distances from each target cell.
+        """
+        neighbors = []
+
+        possible_directions = [0, 1, 2, 3]
+        if enforce_target_direction >= 0:
+            # The agent must land into the current cell with orientation `enforce_target_direction'.
+            # This is only possible if the agent has arrived from the cell in the opposite direction!
+            possible_directions = [(enforce_target_direction + 2) % 4]
+
+        for neigh_direction in possible_directions:
+            new_cell = self._new_position(position, neigh_direction)
+
+            if new_cell[0] >= 0 and new_cell[0] < self.env.height and new_cell[1] >= 0 and new_cell[1] < self.env.width:
+
+                desired_movement_from_new_cell = (neigh_direction + 2) % 4
+
+                """
+                # Is the next cell a dead-end?
+                isNextCellDeadEnd = False
+                nbits = 0
+                tmp = self.env.rail.get_transitions((new_cell[0], new_cell[1]))
+                while tmp > 0:
+                    nbits += (tmp & 1)
+                    tmp = tmp >> 1
+                if nbits == 1:
+                    # Dead-end!
+                    isNextCellDeadEnd = True
+                """
+
+                # Check all possible transitions in new_cell
+                for agent_orientation in range(4):
+                    # Is a transition along movement `desired_movement_from_new_cell' to the current cell possible?
+                    isValid = self.env.rail.get_transition((new_cell[0], new_cell[1], agent_orientation),
+                                                           desired_movement_from_new_cell)
+
+                    if isValid:
+                        """
+                        # TODO: check that it works with deadends! -- still bugged!
+                        movement = desired_movement_from_new_cell
+                        if isNextCellDeadEnd:
+                            movement = (desired_movement_from_new_cell+2) % 4
+                        """
+                        new_distance = min(self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation],
+                                           current_distance + 1)
+                        neighbors.append((new_cell[0], new_cell[1], agent_orientation, new_distance))
+                        self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation] = new_distance
+
+        return neighbors
+
+    def _new_position(self, position, movement):
+        """
+        Utility function that converts a compass movement over a 2D grid to new positions (r, c).
+        """
+        if movement == 0:  # NORTH
+            return (position[0] - 1, position[1])
+        elif movement == 1:  # EAST
+            return (position[0], position[1] + 1)
+        elif movement == 2:  # SOUTH
+            return (position[0] + 1, position[1])
+        elif movement == 3:  # WEST
+            return (position[0], position[1] - 1)
+
+    def get(self, handle):
+        """
+        Computes the current observation for agent `handle' in env
+
+        The observation vector is composed of 4 sequential parts, corresponding to data from the up to 4 possible
+        movements in a RailEnv (up to because only a subset of possible transitions are allowed in RailEnv).
+        The possible movements are sorted relative to the current orientation of the agent, rather than NESW as for
+        the transitions. The order is:
+            [data from 'left'] + [data from 'forward'] + [data from 'right'] + [data from 'back']
+
+
+
+
+
+        Each branch data is organized as:
+            [root node information] +
+            [recursive branch data from 'left'] +
+            [... from 'forward'] +
+            [... from 'right] +
+            [... from 'back']
+
+        Finally, each node information is composed of 5 floating point values:
+
+        #1:
+
+        #2: 1 if a target of another agent is detected between the previous node and the current one.
+
+        #3: 1 if another agent is detected between the previous node and the current one.
+
+        #4: distance of agent to the current branch node
+
+        #5: minimum distance from node to the agent's target (when landing to the node following the corresponding
+            branch.
+
+        Missing/padding nodes are filled in with -inf (truncated).
+        Missing values in present node are filled in with +inf (truncated).
+
+
+        In case of the root node, the values are [0, 0, 0, 0, distance from agent to target].
+        In case the target node is reached, the values are [0, 0, 0, 0, 0].
+        """
+
+        # Update local lookup table for all agents' positions
+        # self.location_has_agent = {}
+        # for loc in self.env.agents_position:
+        #    self.location_has_agent[(loc[0], loc[1])] = 1
+        self.location_has_agent = {tuple(agent.position): 1 for agent in self.env.agents}
+
+        agent = self.env.agents[handle]  # TODO: handle being treated as index
+        # position = self.env.agents_position[handle]
+        # orientation = self.env.agents_direction[handle]
+        possible_transitions = self.env.rail.get_transitions((*agent.position, agent.direction))
+        num_transitions = np.count_nonzero(possible_transitions)
+        # Root node - current position
+        # observation = [0, 0, 0, 0, self.distance_map[handle, position[0], position[1], orientation]]
+        observation = [0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)]]
+        root_observation = observation[:]
+
+        # Start from the current orientation, and see which transitions are available;
+        # organize them as [left, forward, right, back], relative to the current orientation
+        # If only one transition is possible, the tree is oriented with this transition as the forward branch.
+        # TODO: Test if this works as desired!
+        orientation = agent.direction
+        if num_transitions == 1:
+            orientation == np.argmax(possible_transitions)
+
+        # for branch_direction in [(orientation + 4 + i) % 4 for i in range(-1, 3)]:
+        for branch_direction in [(orientation + i) % 4 for i in range(-1, 3)]:
+            if possible_transitions[branch_direction]:
+                new_cell = self._new_position(agent.position, branch_direction)
+
+                branch_observation = self._explore_branch(handle, new_cell, branch_direction, root_observation, 1)
+                observation = observation + branch_observation
+            else:
+                num_cells_to_fill_in = 0
+                pow4 = 1
+                for i in range(self.max_depth):
+                    num_cells_to_fill_in += pow4
+                    pow4 *= 4
+                observation = observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf] * num_cells_to_fill_in
+        return observation
+
+    def _explore_branch(self, handle, position, direction, root_observation, depth):
+        """
+        Utility function to compute tree-based observations.
+        """
+        # [Recursive branch opened]
+        if depth >= self.max_depth + 1:
+            return []
+
+        # Continue along direction until next switch or
+        # until no transitions are possible along the current direction (i.e., dead-ends)
+        # We treat dead-ends as nodes, instead of going back, to avoid loops
+        exploring = True
+        last_isSwitch = False
+        last_isDeadEnd = False
+        last_isTerminal = False  # wrong cell OR cycle;  either way, we don't want the agent to land here
+        last_isTarget = False
+
+        visited = set()
+
+        # other_agent_encountered = False
+        # other_target_encountered = False
+        other_agent_encountered = np.inf
+        other_target_encountered = np.inf
+
+        num_steps = 1
+        while exploring:
+            # #############################
+            # #############################
+            # Modify here to compute any useful data required to build the end node's features. This code is called
+            # for each cell visited between the previous branching node and the next switch / target / dead-end.
+            if position in self.location_has_agent:
+                # other_agent_encountered = True
+                if num_steps < other_agent_encountered:
+                    other_agent_encountered = num_steps
+
+            if position in self.location_has_target:
+                # other_target_encountered = True
+                if num_steps < other_target_encountered:
+                    other_target_encountered = num_steps
+            # #############################
+            # #############################
+
+            if (position[0], position[1], direction) in visited:
+                last_isTerminal = True
+                break
+            visited.add((position[0], position[1], direction))
+
+            # If the target node is encountered, pick that as node. Also, no further branching is possible.
+            # if position[0] == self.env.agents_target[handle][0] and position[1] == self.env.agents_target[handle][1]:
+            if np.array_equal(position, self.env.agents[handle].target):
+                last_isTarget = True
+                break
+
+            cell_transitions = self.env.rail.get_transitions((*position, direction))
+            num_transitions = np.count_nonzero(cell_transitions)
+            exploring = False
+            if num_transitions == 1:
+                # Check if dead-end, or if we can go forward along direction
+                nbits = 0
+                tmp = self.env.rail.get_transitions(tuple(position))
+                while tmp > 0:
+                    nbits += (tmp & 1)
+                    tmp = tmp >> 1
+                if nbits == 1:
+                    # Dead-end!
+                    last_isDeadEnd = True
+
+                if not last_isDeadEnd:
+                    # Keep walking through the tree along `direction'
+                    exploring = True
+                    direction = np.argmax(cell_transitions)
+                    position = self._new_position(position, direction)
+                    num_steps += 1
+
+            elif num_transitions > 0:
+                # Switch detected
+                last_isSwitch = True
+                break
+
+            elif num_transitions == 0:
+                # Wrong cell type, but let's cover it and treat it as a dead-end, just in case
+                print("WRONG CELL TYPE detected in tree-search (0 transitions possible) at cell", position[0],
+                      position[1], direction)
+                last_isTerminal = True
+                break
+
+        # `position' is either a terminal node or a switch
+
+        observation = []
+
+        # #############################
+        # #############################
+        # Modify here to append new / different features for each visited cell!
+        """
+        if last_isTarget:
+            observation = [0,
+                           1 if other_target_encountered else 0,
+                           1 if other_agent_encountered else 0,
+                           root_observation[3] + num_steps,
+                           0]
+
+        elif last_isTerminal:
+            observation = [0,
+                           1 if other_target_encountered else 0,
+                           1 if other_agent_encountered else 0,
+                           np.inf,
+                           np.inf]
+        else:
+            observation = [0,
+                           1 if other_target_encountered else 0,
+                           1 if other_agent_encountered else 0,
+                           root_observation[3] + num_steps,
+                           self.distance_map[handle, position[0], position[1], direction]]
+        """
+        if last_isTarget:
+            observation = [0,
+                           other_target_encountered,
+                           other_agent_encountered,
+                           root_observation[3] + num_steps,
+                           0]
+
+        elif last_isTerminal:
+            observation = [0,
+                           other_target_encountered,
+                           other_agent_encountered,
+                           np.inf,
+                           np.inf]
+        else:
+            observation = [0,
+                           other_target_encountered,
+                           other_agent_encountered,
+                           root_observation[3] + num_steps,
+                           self.distance_map[handle, position[0], position[1], direction]]
+        # #############################
+        # #############################
+
+        new_root_observation = observation[:]
+        # Start from the current orientation, and see which transitions are available;
+        # organize them as [left, forward, right, back], relative to the current orientation
+        # Get the possible transitions
+        possible_transitions = self.env.rail.get_transitions((*position, direction))
+        for branch_direction in [(direction + 4 + i) % 4 for i in range(-1, 3)]:
+            if last_isDeadEnd and self.env.rail.get_transition((*position, direction),
+                                                               (branch_direction + 2) % 4):
+                # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes
+                # it back
+                new_cell = self._new_position(position, (branch_direction + 2) % 4)
+                branch_observation = self._explore_branch(handle,
+                                                          new_cell,
+                                                          (branch_direction + 2) % 4,
+                                                          new_root_observation,
+                                                          depth + 1)
+                observation = observation + branch_observation
+
+            elif last_isSwitch and possible_transitions[branch_direction]:
+                new_cell = self._new_position(position, branch_direction)
+                branch_observation = self._explore_branch(handle,
+                                                          new_cell,
+                                                          branch_direction,
+                                                          new_root_observation,
+                                                          depth + 1)
+                observation = observation + branch_observation
+
+            else:
+                num_cells_to_fill_in = 0
+                pow4 = 1
+                for i in range(self.max_depth - depth):
+                    num_cells_to_fill_in += pow4
+                    pow4 *= 4
+                observation = observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf] * num_cells_to_fill_in
+
+        return observation
+
+    def util_print_obs_subtree(self, tree, num_features_per_node=5, prompt='', current_depth=0):
+        """
+        Utility function to pretty-print tree observations returned by this object.
+        """
+        if len(tree) < num_features_per_node:
+            return
+
+        depth = 0
+        tmp = len(tree) / num_features_per_node - 1
+        pow4 = 4
+        while tmp > 0:
+            tmp -= pow4
+            depth += 1
+            pow4 *= 4
+
+        prompt_ = ['L:', 'F:', 'R:', 'B:']
+
+        print("  " * current_depth + prompt, tree[0:num_features_per_node])
+        child_size = (len(tree) - num_features_per_node) // 4
+        for children in range(4):
+            child_tree = tree[(num_features_per_node + children * child_size):
+                              (num_features_per_node + (children + 1) * child_size)]
+            self.util_print_obs_subtree(child_tree,
+                                        num_features_per_node,
+                                        prompt=prompt_[children],
+                                        current_depth=current_depth + 1)
+
+    def split_tree(self, tree, num_features_per_node=5, current_depth=0):
+        """
+
+        :param tree:
+        :param num_features_per_node:
+        :param prompt:
+        :param current_depth:
+        :return:
+        """
+
+        if len(tree) < num_features_per_node:
+            return [], []
+
+        depth = 0
+        tmp = len(tree) / num_features_per_node - 1
+        pow4 = 4
+        while tmp > 0:
+            tmp -= pow4
+            depth += 1
+            pow4 *= 4
+        child_size = (len(tree) - num_features_per_node) // 4
+        tree_data = tree[0:num_features_per_node - 1].tolist()
+        distance_data = [tree[num_features_per_node - 1]]
+        for children in range(4):
+            child_tree = tree[(num_features_per_node + children * child_size):
+                              (num_features_per_node + (children + 1) * child_size)]
+            tmp_tree_data, tmp_distance_data = self.split_tree(child_tree,
+                                                               num_features_per_node,
+                                                               current_depth=current_depth + 1)
+            if len(tmp_tree_data) > 0:
+                tree_data.extend(tmp_tree_data)
+                distance_data.extend(tmp_distance_data)
+        return tree_data, distance_data
+
+
+class GlobalObsForRailEnv(ObservationBuilder):
+    """
+    Gives a global observation of the entire rail environment.
+    The observation is composed of the following elements:
+
+        - transition map array with dimensions (env.height, env.width, 16),
+          assuming 16 bits encoding of transitions.
+
+        - Four 2D arrays containing respectively the position of the given agent,
+          the position of its target, the positions of the other agents and of
+          their target.
+
+        - A 4 elements array with one of encoding of the direction.
+    """
+
+    def __init__(self):
+        super(GlobalObsForRailEnv, self).__init__()
+
+    def reset(self):
+        self.rail_obs = np.zeros((self.env.height, self.env.width, 16))
+        for i in range(self.rail_obs.shape[0]):
+            for j in range(self.rail_obs.shape[1]):
+                self.rail_obs[i, j] = np.array(
+                    list(f'{self.env.rail.get_transitions((i, j)):016b}')).astype(int)
+
+        # self.targets = np.zeros(self.env.height, self.env.width)
+        # for target_pos in self.env.agents_target:
+        #     self.targets[target_pos] += 1
+
+    def get(self, handle):
+        obs = np.zeros((4, self.env.height, self.env.width))
+        agents = self.env.agents
+        agent = agents[handle]
+
+        agent_pos = agents[handle].position
+        obs[0][agent_pos] += 1
+        obs[1][agent.target] += 1
+
+        for i in range(len(agents)):
+            if i != handle:  # TODO: handle used as index...?
+                agent2 = agents[i]
+                obs[3][agent2.position] += 1
+                obs[2][agent2.target] += 1
+
+        direction = np.zeros(4)
+        direction[agent.direction] = 1
+
+        return self.rail_obs, obs, direction
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index adb40f03b2ee100d8e12cfead660f9600dd9968b..a5248a8fb3b5a7ee2d433aa12c5daca84753e6e9 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -5,10 +5,10 @@ Generator functions are functions that take width, height and num_resets as argu
 a GridTransitionMap object.
 """
 import numpy as np
-import pickle
+import msgpack
 
 from flatland.core.env import Environment
-from flatland.core.env_observation_builder import TreeObsForRailEnv
+from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.generators import random_rail_generator
 from flatland.envs.env_utils import get_new_position
 from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent
@@ -157,7 +157,7 @@ class RailEnv(Environment):
         #    self.dones[handle] = False
         self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)
         # perhaps dones should be part of each agent.
-        
+
         # Reset the state of the observation builder with the new environment
         self.obs_builder.reset()
 
@@ -215,7 +215,7 @@ class RailEnv(Environment):
                 # 1) transition allows the new_direction in the cell,
                 # 2) the new cell is not empty (case 0),
                 # 3) the cell is free, i.e., no agent is currently in that cell
-                
+
                 # if (
                 #        new_position[1] >= self.width or
                 #        new_position[0] >= self.height or
@@ -226,11 +226,11 @@ class RailEnv(Environment):
                 #     new_cell_isValid = False
 
                 new_cell_isValid = (
-                        np.array_equal(  # Check the new position is still in the grid
-                            new_position,
-                            np.clip(new_position, [0, 0], [self.height-1, self.width-1]))
-                        and  # check the new position has some transitions (ie is not an empty cell)
-                        self.rail.get_transitions(new_position) > 0)
+                    np.array_equal(  # Check the new position is still in the grid
+                        new_position,
+                        np.clip(new_position, [0, 0], [self.height - 1, self.width - 1]))
+                    and  # check the new position has some transitions (ie is not an empty cell)
+                    self.rail.get_transitions(new_position) > 0)
 
                 # If transition validity hasn't been checked yet.
                 if transition_isValid is None:
@@ -246,7 +246,7 @@ class RailEnv(Environment):
                 # Check the new position is not the same as any of the existing agent positions
                 # (including itself, for simplicity, since it is moving)
                 cell_isFree = not np.any(
-                        np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1))
+                    np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1))
 
                 if all([new_cell_isValid, transition_isValid, cell_isFree]):
                     # move and change direction to face the new_direction that was
@@ -278,7 +278,7 @@ class RailEnv(Environment):
         # if num_agents_in_target_position == self.number_of_agents:
         if np.all([np.array_equal(agent2.position, agent2.target) for agent2 in self.agents]):
             self.dones["__all__"] = True
-            self.rewards_dict = [0*r+global_reward for r in self.rewards_dict]
+            self.rewards_dict = [0 * r + global_reward for r in self.rewards_dict]
 
         # Reset the step actions (in case some agent doesn't 'register_action'
         # on the next step)
@@ -324,20 +324,37 @@ class RailEnv(Environment):
         # TODO:
         pass
 
-    def save(self, sFilename):
-        dSave = {
-            "grid": self.rail.grid,
-            "agents_static": self.agents_static
-            }
-        with open(sFilename, "wb") as fOut:
-            pickle.dump(dSave, fOut)
-
-    def load(self, sFilename):
-        with open(sFilename, "rb") as fIn:
-            dLoad = pickle.load(fIn)
-            self.rail.grid = dLoad["grid"]
-            self.height, self.width = self.rail.grid.shape
-            self.agents_static = dLoad["agents_static"]
-            self.agents = [None] * self.get_num_agents()
-            self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)
-            
+    def get_full_state_msg(self):
+        grid_data = self.rail.grid.tolist()
+        agent_static_data = [agent.to_list() for agent in self.agents_static]
+        agent_data = [agent.to_list() for agent in self.agents]
+        msg_data = {
+            "grid": grid_data,
+            "agents_static": agent_static_data,
+            "agents": agent_data}
+        return msgpack.packb(msg_data, use_bin_type=True)
+
+    def get_agent_state_msg(self):
+        agent_data = [agent.to_list() for agent in self.agents]
+        msg_data = {
+            "agents": agent_data}
+        return msgpack.packb(msg_data, use_bin_type=True)
+
+    def set_full_state_msg(self, msg_data):
+        data = msgpack.unpackb(msg_data, use_list=False)
+        self.rail.grid = np.array(data[b"grid"])
+        self.agents_static = [EnvAgentStatic(d[0], d[1], d[2]) for d in data[b"agents_static"]]
+        self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4]) for d in data[b"agents"]]
+        # setup with loaded data
+        self.height, self.width = self.rail.grid.shape
+        # self.agents = [None] * self.get_num_agents()
+        self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)
+
+    def save(self, filename):
+        with open(filename, "wb") as file_out:
+            file_out.write(self.get_full_state_msg())
+
+    def load(self, filename):
+        with open(filename, "rb") as file_in:
+            load_data = file_in.read()
+            self.set_full_state_msg(load_data)
diff --git a/flatland/utils/editor.py b/flatland/utils/editor.py
index 3c7f47012e383e1e2bca9af120152b65cfa43e81..a33a0116a93b9c1aa25c9da8103925f7c5ed51ae 100644
--- a/flatland/utils/editor.py
+++ b/flatland/utils/editor.py
@@ -17,7 +17,7 @@ import os
 from flatland.envs.rail_env import RailEnv, random_rail_generator
 from flatland.envs.generators import complex_rail_generator
 # from flatland.core.transitions import RailEnvTransitions
-from flatland.core.env_observation_builder import TreeObsForRailEnv
+from flatland.envs.observations import TreeObsForRailEnv
 import flatland.utils.rendertools as rt
 from examples.play_model import Player
 from flatland.envs.env_utils import mirror
@@ -110,8 +110,7 @@ class View(object):
             self.wTab.set_title(i, title)
         self.wTab.children = [
             VBox([self.wDebug, self.wDebug_move]),
-            VBox([self.wRegenMethod, self.wReplaceAgents])
-            ]
+            VBox([self.wRegenMethod, self.wReplaceAgents])]
 
         # Progress bar intended for stepping in the background (not yet working)
         self.wProg_steps = ipywidgets.IntProgress(value=0, min=0, max=20, step=1, description="Step")
@@ -121,21 +120,20 @@ class View(object):
             dict(name="Refresh", method=self.controller.refresh, tip="Redraw only"),
             dict(name="Clear", method=self.controller.clear, tip="Clear rails and agents"),
             dict(name="Reset", method=self.controller.reset,
-                tip="Standard env reset, including regen rail + agents"),
+                 tip="Standard env reset, including regen rail + agents"),
             dict(name="Restart Agents", method=self.controller.restartAgents,
-                tip="Move agents back to start positions"),
+                 tip="Move agents back to start positions"),
             dict(name="Regenerate", method=self.controller.regenerate,
-                tip="Regenerate the rails using the method selected below"),
+                 tip="Regenerate the rails using the method selected below"),
             dict(name="Load", method=self.controller.load),
             dict(name="Save", method=self.controller.save),
             dict(name="Step", method=self.controller.step),
-            dict(name="Run Steps", method=self.controller.start_run),
-        ]
+            dict(name="Run Steps", method=self.controller.start_run)]
 
         self.lwButtons = []
         for dButton in ldButtons:
             wButton = ipywidgets.Button(description=dButton["name"],
-                tooltip=dButton["tip"] if "tip" in dButton else dButton["name"])
+                                        tooltip=dButton["tip"] if "tip" in dButton else dButton["name"])
             wButton.on_click(dButton["method"])
             self.lwButtons.append(wButton)
 
@@ -145,8 +143,7 @@ class View(object):
             self.wSize,
             self.wNAgents,
             self.wProg_steps,
-            self.wTab
-            ])
+            self.wTab])
 
         self.wMain = HBox([self.wImage, self.wVbox_controls])
 
@@ -164,7 +161,7 @@ class View(object):
         with self.wOutput:
             # plt.figure(figsize=(10, 10))
             self.oRT.renderEnv(spacing=False, arrows=False, sRailColor="gray",
-                show=False, iSelectedAgent=self.model.iSelectedAgent)
+                               show=False, iSelectedAgent=self.model.iSelectedAgent)
             img = self.oRT.getImage()
             # plt.clf()
             # plt.close()
@@ -607,7 +604,7 @@ class EditorModel(object):
 
         # Has the user clicked on an existing agent?
         iAgent = self.find_agent_at(rcCell)
-        
+
         if iAgent is None:
             # No
             if self.iSelectedAgent is None:
@@ -656,7 +653,7 @@ class EditorModel(object):
                 # self.log("step ", i)
                 self.step()
                 time.sleep(0.2)
-                wProg_steps.value = i+1   # indicate progress on bar
+                wProg_steps.value = i + 1   # indicate progress on bar
         finally:
             self.thread = None
 
@@ -683,6 +680,3 @@ class EditorModel(object):
                    binTrans,
                    sbinTrans,
                    [sbinTrans[i:(i + 4)] for i in range(0, len(sbinTrans), 4)])
-
-
-    
\ No newline at end of file
diff --git a/flatland/utils/graphics_layer.py b/flatland/utils/graphics_layer.py
index c0d390c34c66567b7ae216fa53e03abb6baf8a06..4cfcc64bffb82f91a0f36822188db297bc1ed37e 100644
--- a/flatland/utils/graphics_layer.py
+++ b/flatland/utils/graphics_layer.py
@@ -54,7 +54,7 @@ class GraphicsLayer(object):
                 color = tuple((gcolor[:3] * 255).astype(int))
         else:
             color = self.tColGrid
-        
+
         if lighten:
             color = tuple([int(255 - (255 - iRGB) / 3) for iRGB in color])
 
@@ -65,6 +65,6 @@ class GraphicsLayer(object):
 
     def setRailAt(self, row, col, binTrans):
         pass
-    
+
     def setAgentAt(self, iAgent, row, col, iDirIn, iDirOut):
         pass
diff --git a/flatland/utils/graphics_pil.py b/flatland/utils/graphics_pil.py
index 01cc5f03c7d8f015ff24a57c4ba53f65ddc00684..41516fd94737556a4b8abbc7ccfce0fd503a3d6e 100644
--- a/flatland/utils/graphics_pil.py
+++ b/flatland/utils/graphics_pil.py
@@ -39,7 +39,7 @@ class PILGL(GraphicsLayer):
         r = np.sqrt(s)
         gPoints = np.stack([np.atleast_1d(gX), -np.atleast_1d(gY)]).T * self.nPixCell
         for x, y in gPoints:
-            self.draw.rectangle([(x-r, y-r), (x+r, y+r)], fill=color, outline=color)
+            self.draw.rectangle([(x - r, y - r), (x + r, y + r)], fill=color, outline=color)
 
     def text(self, *args, **kwargs):
         pass
diff --git a/flatland/utils/render_qt.py b/flatland/utils/render_qt.py
index af40c05cbcf3628a9a6bdb370dae4161d70a8a31..ea9613968c473c351849509bfc3e277dd7fe0701 100644
--- a/flatland/utils/render_qt.py
+++ b/flatland/utils/render_qt.py
@@ -111,7 +111,7 @@ class QTSVG(GraphicsLayer):
         self.wWinMain.resize(1000, 1000)
         self.wWinMain.show()
         self.wWinMain.setFocus()
-        
+
         self.track = self.track = Track()
         self.lwTrack = []
         self.zug = Zug()
@@ -152,21 +152,21 @@ class QTSVG(GraphicsLayer):
     def processEvents(self):
         self.app.processEvents()
         time.sleep(0.001)
-    
+
     def clear_rails(self):
         # print("Clear rails: ", len(self.lwTrack))
         for wRail in self.lwTrack:
             self.layout.removeWidget(wRail)
         self.lwTrack = []
         self.clear_agents()
-    
+
     def clear_agents(self):
         # print("Clear Agents: ", len(self.lwAgents))
         for wAgent in self.lwAgents:
             self.layout.removeWidget(wAgent)
         self.lwAgents = []
         self.agents_prev = []
-        
+
     def setRailAt(self, row, col, binTrans):
         if binTrans in self.track.dSvg:
             sSVG = self.track.dSvg[binTrans].to_string()
@@ -204,7 +204,7 @@ class QTSVG(GraphicsLayer):
                     return
 
         # Ensure we have adequate slots in the list lwAgents
-        for i in range(len(self.lwAgents), iAgent+1):
+        for i in range(len(self.lwAgents), iAgent + 1):
             self.lwAgents.append(None)
             self.agents_prev.append(None)
 
@@ -226,7 +226,7 @@ def main2():
     gl = QTGL(10, 10)
     for i in range(10):
         gl.beginFrame()
-        gl.plot([3+i, 4], [-4-i, -5], color="r")
+        gl.plot([3 + i, 4], [-4 - i, -5], color="r")
         gl.endFrame()
         time.sleep(1)
 
diff --git a/flatland/utils/svg.py b/flatland/utils/svg.py
index b219560c00f6a1e08b7384260274101ee5b430e6..32d5631839f43964cf5ff7d94520aacf67db0488 100644
--- a/flatland/utils/svg.py
+++ b/flatland/utils/svg.py
@@ -30,10 +30,10 @@ class SVG(object):
 
     def merge(self, svg2):
         svg3 = svg2.copy()
-        
+
         svg3.renumber_styles(offset=10)
         svg3.eStyle.text = self.eStyle.text + "\n" + svg3.eStyle.text
-        
+
         for child in self.svg.root:
             if not child.tag.endswith("style"):
                 svg3.svg.root.append(child)
@@ -48,12 +48,12 @@ class SVG(object):
             lEl = self.svg.root.xpath("//*[@class='{}']".format(sClass))
             for el in lEl:
                 el.attrib["class"] = "st{}".format(iStyle + offset)
-        
-            sStyle2 = str(iStyle+offset)
+
+            sStyle2 = str(iStyle + offset)
 
             sNewStyle = "\t.st" + sStyle2 + "{" + self.dStyles[sStyle] + "}\n"
             sNewStyles += sNewStyle
-        
+
         self.eStyle.text = sNewStyles
 
     def set_style_color(self, style_name, color):
@@ -63,7 +63,7 @@ class SVG(object):
                 sValue = "fill:#" + "".join([format(col, "#04x")[2:] for col in color]) + ";"
             sNewStyle = "\t.st" + sKey + "{" + sValue + "}\n"
             sNewStyles += sNewStyle
-        
+
         self.eStyle.text = sNewStyles
 
     def set_rotate(self, angle):
@@ -87,7 +87,7 @@ class Zug(object):
         if delta_dir in (0, 2):
             svg = self.svg_straight.copy()
             svg.set_rotate(iDirIn * 90)
-        
+
         if delta_dir == 1:  # bend to right, eg N->E, E->S
             svg = self.svg_curve1.copy()
             svg.set_rotate((iDirIn - 1) * 90)
@@ -124,8 +124,7 @@ class Track(object):
             "NN SS EN SW": "Weiche_vertikal_oben_links.svg",
             "NN SS SE WN": "Weiche_vertikal_oben_rechts.svg",
             "NN SS NW ES": "Weiche_vertikal_unten_links.svg",
-            "NN SS NE WS": "Weiche_vertikal_unten_rechts.svg",
-            }
+            "NN SS NE WS": "Weiche_vertikal_unten_rechts.svg"}
 
         self.dSvg = {}
 
@@ -136,8 +135,7 @@ class Track(object):
         svgBG = SVG("./svg/Background_#91D1DD.svg")
 
         for sTrans, sFile in dFiles.items():
-            
-            svg = SVG("./svg/"+sFile)
+            svg = SVG("./svg/" + sFile)
 
             lTrans16 = ["0"] * 16
             for sTran in sTrans.split(" "):
@@ -165,7 +163,7 @@ class Track(object):
 def main():
     # svg1 = SVG("./svg/Gleis_vertikal.svg")
     # svg2 = SVG("./svg/Zug_1_Weiche_#0091ea.svg")
-    
+
     # svg3 = svg2.merge(svg1)
     # svg3.set_rotate(90)
 
@@ -188,4 +186,3 @@ def main():
 
 if __name__ == "__main__":
     main()
-
diff --git a/requirements_dev.txt b/requirements_dev.txt
index b0dafbffc9d1b507e4c67fdae91b5ed0c0d452f7..69a03322b1aa316db9e406235e4fe2c39b5282f5 100644
--- a/requirements_dev.txt
+++ b/requirements_dev.txt
@@ -19,5 +19,10 @@ matplotlib==3.0.2
 PyQt5==5.12
 Pillow==5.4.1
 
+<<<<<<< HEAD
 svgutils==0.3.1
 
+=======
+msgpack==0.6.1
+svgutils==0.3.1
+>>>>>>> 19fac96634f312ac3c014c0517fad0c1fe273655
diff --git a/tests/test_environments.py b/tests/test_environments.py
index f12dfa3d6b57f76ce490c2c748fefc008ba371a5..4c55eac7afb44d95f0e49d665eeeb4bc36becea9 100644
--- a/tests/test_environments.py
+++ b/tests/test_environments.py
@@ -4,6 +4,7 @@ import numpy as np
 
 from flatland.envs.rail_env import RailEnv
 from flatland.envs.generators import rail_from_GridTransitionMap_generator
+from flatland.envs.generators import complex_rail_generator
 from flatland.core.transitions import Grid4Transitions
 from flatland.core.transition_map import GridTransitionMap
 from flatland.core.env_observation_builder import GlobalObsForRailEnv
@@ -12,6 +13,30 @@ from flatland.envs.agent_utils import EnvAgent
 """Tests for `flatland` package."""
 
 
+def test_save_load():
+    env = RailEnv(width=10, height=10,
+                  rail_generator=complex_rail_generator(nr_start_goal=2, nr_extra=5, min_dist=6, seed=0),
+                  number_of_agents=2)
+    env.reset()
+    agent_1_pos = env.agents_static[0].position
+    agent_1_dir = env.agents_static[0].direction
+    agent_1_tar = env.agents_static[0].target
+    agent_2_pos = env.agents_static[1].position
+    agent_2_dir = env.agents_static[1].direction
+    agent_2_tar = env.agents_static[1].target
+    env.save("test_save.dat")
+    env.load("test_save.dat")
+    assert(env.width == 10)
+    assert(env.height == 10)
+    assert(len(env.agents) == 2)
+    assert(agent_1_pos == env.agents_static[0].position)
+    assert(agent_1_dir == env.agents_static[0].direction)
+    assert(agent_1_tar == env.agents_static[0].target)
+    assert(agent_2_pos == env.agents_static[1].position)
+    assert(agent_2_dir == env.agents_static[1].direction)
+    assert(agent_2_tar == env.agents_static[1].target)
+
+
 def test_rail_environment_single_agent():
 
     cells = [int('0000000000000000', 2),  # empty cell - Case 0
@@ -205,4 +230,4 @@ def test_dead_end():
 
 if __name__ == "__main__":
     test_rail_environment_single_agent()
-    test_dead_end()
\ No newline at end of file
+    test_dead_end()
diff --git a/tests/test_rendertools.py b/tests/test_rendertools.py
index ea4ad3c46f3465c2f30aabaf8ae9cff7e2efd7f7..f6defb2de400f1a007da237b2f91ec38df5db07b 100644
--- a/tests/test_rendertools.py
+++ b/tests/test_rendertools.py
@@ -6,13 +6,16 @@ Tests for `flatland` package.
 
 from flatland.envs.rail_env import RailEnv, random_rail_generator
 import numpy as np
-import os
+<<<<<<< HEAD
+=======
+# import os
+>>>>>>> dc2fa1ee0244b15c76d89ab768c5e1bbd2716147
 import sys
 
 import matplotlib.pyplot as plt
 
 import flatland.utils.rendertools as rt
-from flatland.core.env_observation_builder import TreeObsForRailEnv
+from flatland.envs.observations import TreeObsForRailEnv
 
 
 def checkFrozenImage(oRT, sFileImage, resave=False):
@@ -49,7 +52,7 @@ def test_render_env(save_new_images=False):
     oEnv.rail.load_transition_map(sfTestEnv)
     oRT = rt.RenderTool(oEnv)
     oRT.renderEnv()
-    
+
     checkFrozenImage(oRT, "basic-env.npz", resave=save_new_images)
 
     oRT = rt.RenderTool(oEnv, gl="PIL")
@@ -86,4 +89,4 @@ def main():
 
 
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()