From 236eba12e1b01561a99e357e5894dabbec220bba Mon Sep 17 00:00:00 2001
From: spiglerg <spiglerg@gmail.com>
Date: Fri, 19 Apr 2019 12:28:23 +0200
Subject: [PATCH] initial (working) tree observations, railenv+generators moved
 to envs/rail_env.py, fixes

---
 examples/temporary_example.py            |  23 +-
 flatland/core/env.py                     | 328 -----------
 flatland/core/env_observation_builder.py | 460 ++++++++-------
 flatland/envs/__init__.py                |   0
 flatland/envs/rail_env.py                | 676 +++++++++++++++++++++++
 flatland/utils/rail_env_generator.py     | 341 ------------
 flatland/utils/rendertools.py            |   2 +
 tests/test_environments.py               |  12 +-
 tests/test_rendertools.py                |   5 +-
 9 files changed, 948 insertions(+), 899 deletions(-)
 create mode 100644 flatland/envs/__init__.py
 create mode 100644 flatland/envs/rail_env.py
 delete mode 100644 flatland/utils/rail_env_generator.py

diff --git a/examples/temporary_example.py b/examples/temporary_example.py
index 8b53b1dd..96acb30f 100644
--- a/examples/temporary_example.py
+++ b/examples/temporary_example.py
@@ -2,20 +2,22 @@ import random
 import numpy as np
 import matplotlib.pyplot as plt
 
-from flatland.core.env import RailEnv
-from flatland.utils.rail_env_generator import *
+from flatland.envs.rail_env import *
+from flatland.core.env_observation_builder import TreeObsForRailEnv
 from flatland.utils.rendertools import *
 
 random.seed(1)
 np.random.seed(1)
 
+
 # Example generate a random rail
-env = RailEnv(width=20, height=20, rail_generator=generate_random_rail, number_of_agents=10)
+env = RailEnv(width=20, height=20, rail_generator=random_rail_generator, number_of_agents=10)
 env.reset()
 
 env_renderer = RenderTool(env)
 env_renderer.renderEnv(show=True)
 
+
 # Example generate a rail given a manual specification,
 # a map of tuples (cell_type, rotation)
 specs = [[(0, 0), (0, 0), (0, 0), (0, 0), (7, 0), (0, 0)],
@@ -23,13 +25,12 @@ specs = [[(0, 0), (0, 0), (0, 0), (0, 0), (7, 0), (0, 0)],
 
 env = RailEnv(width=6,
               height=2,
-              rail_generator=generate_rail_from_manual_specifications(specs),
-              number_of_agents=1)
+              rail_generator=rail_from_manual_specifications_generator(specs),
+              number_of_agents=1,
+              obs_builder_object=TreeObsForRailEnv(max_depth=1))
 
 handle = env.get_agent_handles()
 
-obs = env.reset()
-
 env.agents_position = [[1, 4]]
 env.agents_target = [[1, 1]]
 env.agents_direction = [1]
@@ -37,13 +38,15 @@ env.agents_direction = [1]
 env.obs_builder.reset()
 
 # TODO: delete next line
-#print(env.obs_builder.distance_map[0,:,:])
-#print(env.obs_builder.max_dist)
+#for i in range(4):
+#    print(env.obs_builder.distance_map[0, :, :, i])
+
+obs, all_rewards, done, _ = env.step({0:0})
+env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5)
 
 env_renderer = RenderTool(env)
 env_renderer.renderEnv(show=True)
 
-
 print("Manual control: s=perform step, q=quit, [agent id] [1-2-3 action] \
        (turnleft+move, move to front, turnright+move)")
 for step in range(100):
diff --git a/flatland/core/env.py b/flatland/core/env.py
index 02d912a3..284afdff 100644
--- a/flatland/core/env.py
+++ b/flatland/core/env.py
@@ -3,10 +3,6 @@ The env module defines the base Environment class.
 The base Environment class is adapted from rllib.env.MultiAgentEnv
 (https://github.com/ray-project/ray).
 """
-import random
-
-from .env_observation_builder import TreeObsForRailEnv
-from flatland.utils.rail_env_generator import generate_random_rail
 
 
 class Environment:
@@ -94,327 +90,3 @@ class Environment:
         function.
         """
         raise NotImplementedError()
-
-
-class RailEnv:
-    """
-    RailEnv environment class.
-
-    RailEnv is an environment inspired by a (simplified version of) a rail
-    network, in which agents (trains) have to navigate to their target
-    locations in the shortest time possible, while at the same time cooperating
-    to avoid bottlenecks.
-
-    The valid actions in the environment are:
-        0: do nothing
-        1: turn left and move to the next cell
-        2: move to the next cell in front of the agent
-        3: turn right and move to the next cell
-
-    Moving forward in a dead-end cell makes the agent turn 180 degrees and step
-    to the cell it came from.
-
-    The actions of the agents are executed in order of their handle to prevent
-    deadlocks and to allow them to learn relative priorities.
-
-    TODO: WRITE ABOUT THE REWARD FUNCTION, and possibly allow for alpha and
-    beta to be passed as parameters to __init__().
-    """
-
-    def __init__(self,
-                 width,
-                 height,
-                 rail_generator=generate_random_rail,
-                 number_of_agents=1,
-                 obs_builder_object=TreeObsForRailEnv(max_depth=2)):
-        """
-        Environment init.
-
-        Parameters
-        -------
-        rail_generator : function
-            The rail_generator function is a function that takes the width and
-            height of a  rail map along with the number of times the env has
-            been reset, and returns a GridTransitionMap object.
-            Implemented functions are:
-                generate_random_rail : generate a random rail of given size
-                TODO: generate_rail_from_saved_list ---
-        width : int
-            The width of the rail map. Potentially in the future,
-            a range of widths to sample from.
-        height : int
-            The height of the rail map. Potentially in the future,
-            a range of heights to sample from.
-        number_of_agents : int
-            Number of agents to spawn on the map. Potentially in the future,
-            a range of number of agents to sample from.
-        obs_builder_object: ObservationBuilder object
-            ObservationBuilder-derived object that takes builds observation
-            vectors for each agent.
-        """
-
-        self.rail_generator = rail_generator
-        self.num_resets = 0
-        self.rail = None
-        self.width = width
-        self.height = height
-
-        self.number_of_agents = number_of_agents
-
-        self.obs_builder = obs_builder_object
-        self.obs_builder.set_env(self)
-
-        self.actions = [0]*self.number_of_agents
-        self.rewards = [0]*self.number_of_agents
-        self.done = False
-
-        self.agents_position = []
-        self.agents_target = []
-        self.agents_direction = []
-
-        self.dones = {"__all__": False}
-        self.obs_dict = {}
-        self.rewards_dict = {}
-
-        self.agents_handles = list(range(self.number_of_agents))
-
-    def get_agent_handles(self):
-        return self.agents_handles
-
-    def reset(self):
-        self.rail = self.rail_generator(self.width, self.height, self.num_resets)
-        self.num_resets += 1
-
-        self.dones = {"__all__": False}
-        for handle in self.agents_handles:
-            self.dones[handle] = False
-
-        re_generate = True
-        while re_generate:
-            valid_positions = []
-            for r in range(self.height):
-                for c in range(self.width):
-                    if self.rail.get_transitions((r, c)) > 0:
-                        valid_positions.append((r, c))
-
-            self.agents_position = random.sample(valid_positions,
-                                                 self.number_of_agents)
-            self.agents_target = random.sample(valid_positions,
-                                               self.number_of_agents)
-
-            # agents_direction must be a direction for which a solution is
-            # guaranteed.
-            self.agents_direction = [0]*self.number_of_agents
-            re_generate = False
-            for i in range(self.number_of_agents):
-                valid_movements = []
-                for direction in range(4):
-                    position = self.agents_position[i]
-                    moves = self.rail.get_transitions(
-                            (position[0], position[1], direction))
-                    for move_index in range(4):
-                        if moves[move_index]:
-                            valid_movements.append((direction, move_index))
-
-                valid_starting_directions = []
-                for m in valid_movements:
-                    new_position = self._new_position(self.agents_position[i],
-                                                      m[1])
-                    if m[0] not in valid_starting_directions and \
-                       self._path_exists(new_position, m[0],
-                                         self.agents_target[i]):
-                        valid_starting_directions.append(m[0])
-
-                if len(valid_starting_directions) == 0:
-                    re_generate = True
-                else:
-                    self.agents_direction[i] = random.sample(
-                                               valid_starting_directions, 1)[0]
-
-        # Reset the state of the observation builder with the new environment
-        self.obs_builder.reset()
-
-        # Return the new observation vectors for each agent
-        return self._get_observations()
-
-    def step(self, action_dict):
-        alpha = 1.0
-        beta = 1.0
-
-        invalid_action_penalty = -2
-        step_penalty = -1 * alpha
-        global_reward = 1 * beta
-
-        # Reset the step rewards
-        self.rewards_dict = {}
-        for handle in self.agents_handles:
-            self.rewards_dict[handle] = 0
-
-        if self.dones["__all__"]:
-            return self._get_observations(), self.rewards_dict, self.dones, {}
-
-        for i in range(len(self.agents_handles)):
-            handle = self.agents_handles[i]
-
-            if handle not in action_dict:
-                continue
-
-            action = action_dict[handle]
-
-            if action < 0 or action > 3:
-                print('ERROR: illegal action=', action,
-                      'for agent with handle=', handle)
-                return
-
-            if action > 0:
-                pos = self.agents_position[i]
-                direction = self.agents_direction[i]
-
-                movement = direction
-                if action == 1:
-                    movement = direction - 1
-                elif action == 3:
-                    movement = direction + 1
-
-                if movement < 0:
-                    movement += 4
-                if movement >= 4:
-                    movement -= 4
-
-                is_deadend = False
-                if action == 2:
-                    # compute number of possible transitions in the current
-                    # cell
-                    nbits = 0
-                    tmp = self.rail.get_transitions((pos[0], pos[1]))
-                    while tmp > 0:
-                        nbits += (tmp & 1)
-                        tmp = tmp >> 1
-                    if nbits == 1:
-                        # dead-end;  assuming the rail network is consistent,
-                        # this should match the direction the agent has come
-                        # from. But it's better to check in any case.
-                        reverse_direction = 0
-                        if direction == 0:
-                            reverse_direction = 2
-                        elif direction == 1:
-                            reverse_direction = 3
-                        elif direction == 2:
-                            reverse_direction = 0
-                        elif direction == 3:
-                            reverse_direction = 1
-
-                        valid_transition = self.rail.get_transition(
-                                            (pos[0], pos[1], direction),
-                                            reverse_direction)
-                        if valid_transition:
-                            direction = reverse_direction
-                            movement = reverse_direction
-                            is_deadend = True
-
-                new_position = self._new_position(pos, movement)
-
-                # Is it a legal move?  1) transition allows the movement in the
-                # cell,  2) the new cell is not empty (case 0),  3) the cell is
-                # free, i.e., no agent is currently in that cell
-                if new_position[1] >= self.width or\
-                   new_position[0] >= self.height or\
-                   new_position[0] < 0 or new_position[1] < 0:
-                    new_cell_isValid = False
-
-                elif self.rail.get_transitions((new_position[0], new_position[1])) > 0:
-                    new_cell_isValid = True
-                else:
-                    new_cell_isValid = False
-
-                transition_isValid = self.rail.get_transition(
-                     (pos[0], pos[1], direction),
-                     movement) or is_deadend
-
-                cell_isFree = True
-                for j in range(self.number_of_agents):
-                    if self.agents_position[j] == new_position:
-                        cell_isFree = False
-                        break
-
-                if new_cell_isValid and transition_isValid and cell_isFree:
-                    # move and change direction to face the movement that was
-                    # performed
-                    self.agents_position[i] = new_position
-                    self.agents_direction[i] = movement
-                else:
-                    # the action was not valid, add penalty
-                    self.rewards_dict[handle] += invalid_action_penalty
-
-            # if agent is not in target position, add step penalty
-            if self.agents_position[i][0] == self.agents_target[i][0] and \
-               self.agents_position[i][1] == self.agents_target[i][1]:
-                self.dones[handle] = True
-            else:
-                self.rewards_dict[handle] += step_penalty
-
-        # Check for end of episode + add global reward to all rewards!
-        num_agents_in_target_position = 0
-        for i in range(self.number_of_agents):
-            if self.agents_position[i][0] == self.agents_target[i][0] and \
-               self.agents_position[i][1] == self.agents_target[i][1]:
-                num_agents_in_target_position += 1
-
-        if num_agents_in_target_position == self.number_of_agents:
-            self.dones["__all__"] = True
-            self.rewards_dict = [r+global_reward for r in self.rewards_dict]
-
-        # Reset the step actions (in case some agent doesn't 'register_action'
-        # on the next step)
-        self.actions = [0]*self.number_of_agents
-
-        return self._get_observations(), self.rewards_dict, self.dones, {}
-
-    def _new_position(self, position, movement):
-        if movement == 0:    # NORTH
-            return (position[0]-1, position[1])
-        elif movement == 1:  # EAST
-            return (position[0], position[1] + 1)
-        elif movement == 2:  # SOUTH
-            return (position[0]+1, position[1])
-        elif movement == 3:  # WEST
-            return (position[0], position[1] - 1)
-
-    def _path_exists(self, start, direction, end):
-        # BFS - Check if a path exists between the 2 nodes
-
-        visited = set()
-        stack = [(start, direction)]
-        while stack:
-            node = stack.pop()
-            if node[0][0] == end[0] and node[0][1] == end[1]:
-                return 1
-            if node not in visited:
-                visited.add(node)
-                moves = self.rail.get_transitions((node[0][0], node[0][1], node[1]))
-                for move_index in range(4):
-                    if moves[move_index]:
-                        stack.append((self._new_position(node[0], move_index),
-                                      move_index))
-
-                # If cell is a dead-end, append previous node with reversed
-                # orientation!
-                nbits = 0
-                tmp = self.rail.get_transitions((node[0][0], node[0][1]))
-                while tmp > 0:
-                    nbits += (tmp & 1)
-                    tmp = tmp >> 1
-                if nbits == 1:
-                    stack.append((node[0], (node[1] + 2) % 4))
-
-        return 0
-
-    def _get_observations(self):
-        self.obs_dict = {}
-        for handle in self.agents_handles:
-            self.obs_dict[handle] = self.obs_builder.get(handle)
-        return self.obs_dict
-
-    def render(self):
-        # TODO:
-        pass
diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py
index f8341e27..4d6da5b5 100644
--- a/flatland/core/env_observation_builder.py
+++ b/flatland/core/env_observation_builder.py
@@ -1,3 +1,13 @@
+"""
+ObservationBuilder objects are objects that can be passed to environments designed for customizability.
+The ObservationBuilder-derived custom classes implement 2 functions, reset() and get() or get(handle).
+
++ Reset() is called after each environment reset, to allow for pre-computing relevant data.
+
++ Get() is called whenever an observation has to be computed, potentially for each agent independently in
+case of multi-agent environments.
+"""
+
 import numpy as np
 from collections import deque
 
@@ -5,47 +15,82 @@ from collections import deque
 
 
 class ObservationBuilder:
-    def __init__(self, env):
+    """
+    ObservationBuilder base class.
+    """
+    def __init__(self):
+        pass
+
+    def _set_env(self, env):
         self.env = env
 
     def reset(self):
+        """
+        Called after each environment reset.
+        """
         raise NotImplementedError()
 
-    def get(self, handle):
+    def get(self, handle=0):
+        """
+        Called whenever an observation has to be computed for the `env' environment, possibly
+        for each agent independently (agent id `handle').
+
+        Parameters
+        -------
+        handle : int (optional)
+            Handle of the agent for which to compute the observation vector.
+
+        Returns
+        -------
+        function
+            An observation structure, specific to the corresponding environment.
+        """
         raise NotImplementedError()
 
 
 class TreeObsForRailEnv(ObservationBuilder):
-    def __init__(self, env):
-        self.env = env
+    """
+    TreeObsForRailEnv object.
+
+    This object returns observation vectors for agents in the RailEnv environment.
+    The information is local to each agent and exploits the tree structure of the rail
+    network to simplify the representation of the state of the environment for each agent.
+    """
+    def __init__(self, max_depth):
+        self.max_depth = max_depth
 
     def reset(self):
         self.distance_map = np.inf * np.ones(shape=(self.env.number_of_agents,
                                                     self.env.height,
-                                                    self.env.width))
+                                                    self.env.width,
+                                                    4))
         self.max_dist = np.zeros(self.env.number_of_agents)
 
         for i in range(self.env.number_of_agents):
             self.max_dist[i] = self._distance_map_walker(self.env.agents_target[i], i)
 
-
     def _distance_map_walker(self, position, target_nr):
+        """
+        Utility function to compute distance maps from each cell in the rail network (and each possible
+        orientation within it) to each agent's target cell.
+        """
         # Returns max distance to target, from the farthest away node, while filling in distance_map
 
         for ori in range(4):
-            self.distance_map[target_nr, position[0], position[1]] = 0
+            self.distance_map[target_nr, position[0], position[1], ori] = 0
 
         # Fill in the (up to) 4 neighboring nodes
         # nodes_queue = []  # list of tuples (row, col, direction, distance);
-        # direction is the direction of movement, meaning that at least a possible orientation
-        # of an agent in cell (row,col) allows a movement in direction `direction'
-        nodes_queue = deque(self._get_and_update_neighbors(position,
-                                                           target_nr, 0, enforce_target_direction=-1))
+        # direction is the direction of movement, meaning that at least a possible orientation of an agent
+        # in cell (row,col) allows a movement in direction `direction'
+        nodes_queue = deque(self._get_and_update_neighbors(position, target_nr, 0, enforce_target_direction=-1))
 
         # BFS from target `position' to all the reachable nodes in the grid
         # Stop the search if the target position is re-visited, in any direction
-        visited = set([(position[0], position[1], 0), (position[0], position[1], 1),
-                       (position[0], position[1], 2), (position[0], position[1], 3)])
+        visited = set([(position[0], position[1], 0),
+                       (position[0], position[1], 1),
+                       (position[0], position[1], 2),
+                       (position[0], position[1], 3)])
 
         max_distance = 0
 
@@ -54,35 +99,34 @@ class TreeObsForRailEnv(ObservationBuilder):
 
             node_id = (node[0], node[1], node[2])
 
-            #print(node_id, visited, (node_id in visited))
-            #print(nodes_queue)
-
             if node_id not in visited:
                 visited.add(node_id)
 
-                # From the list of possible neighbors that have at least a path to the
-                # current node, only keep those whose new orientation in the current cell
-                # would allow a transition to direction node[2]
-                valid_neighbors = self._get_and_update_neighbors(
-                    (node[0], node[1]), target_nr, node[3], node[2])
+                # From the list of possible neighbors that have at least a path to the current node, only keep those
+                # whose new orientation in the current cell would allow a transition to direction node[2]
+                valid_neighbors = self._get_and_update_neighbors((node[0], node[1]), target_nr, node[3], node[2])
 
                 for n in valid_neighbors:
                     nodes_queue.append(n)
 
-                if len(valid_neighbors)>0:
+                if len(valid_neighbors) > 0:
                     max_distance = max(max_distance, node[3]+1)
 
         return max_distance
 
-
     def _get_and_update_neighbors(self, position, target_nr, current_distance, enforce_target_direction=-1):
+        """
+        Utility function used by _distance_map_walker to perform a BFS walk over the rail, filling in the
+        minimum distances from each target cell.
+        """
         neighbors = []
 
         for direction in range(4):
-            new_cell = self._new_position(position, (direction+2)%4)
+            new_cell = self._new_position(position, (direction+2) % 4)
+
+            if new_cell[0] >= 0 and new_cell[0] < self.env.height and \
+               new_cell[1] >= 0 and new_cell[1] < self.env.width:
 
-            if new_cell[0]>=0 and new_cell[0]<self.env.height and\
-                new_cell[1]>=0 and new_cell[1]<self.env.width:
                 # Check if the two cells are connected by a valid transition
                 transitionValid = False
                 for orientation in range(4):
@@ -94,17 +138,16 @@ class TreeObsForRailEnv(ObservationBuilder):
                 if not transitionValid:
                     continue
 
-                # Check if a transition in direction node[2] is possible if an agent
-                # lands in the current cell with orientation `direction'; this only
-                # applies to cells that are not dead-ends!
+                # Check if a transition in direction node[2] is possible if an agent lands in the current
+                # cell with orientation `direction'; this only applies to cells that are not dead-ends!
                 directionMatch = True
-                if enforce_target_direction>=0:
-                    directionMatch = self.env.rail.get_transition(
-                        (new_cell[0], new_cell[1], direction), enforce_target_direction)
+                if enforce_target_direction >= 0:
+                    directionMatch = self.env.rail.get_transition((new_cell[0], new_cell[1], direction),
+                                                                  enforce_target_direction)
 
-                # If transition is found to invalid, check if perhaps it
-                # is a dead-end, in which case the direction of movement is rotated
-                # 180 degrees (moving forward turns the agents and makes it step in the previous cell)
+                # If transition is found to invalid, check if perhaps it is a dead-end, in which case the
+                # direction of movement is rotated 180 degrees (moving forward turns the agents and makes
+                # it step in the previous cell)
                 if not directionMatch:
                     # If cell is a dead-end, append previous node with reversed
                     # orientation!
@@ -115,20 +158,28 @@ class TreeObsForRailEnv(ObservationBuilder):
                         tmp = tmp >> 1
                     if nbits == 1:
                         # Dead-end!
-                        # Check if transition is possible in new_cell
-                        # with orientation (direction+2)%4 in direction `direction'
-                        directionMatch = directionMatch or self.env.rail.get_transition(
-                            (new_cell[0], new_cell[1], (direction+2)%4), direction)
+                        # Check if transition is possible in new_cell with orientation
+                        # (direction+2)%4 in direction `direction'
+                        directionMatch = directionMatch or \
+                                         self.env.rail.get_transition((new_cell[0], new_cell[1], (direction+2) % 4),
+                                                                      direction)
 
                 if transitionValid and directionMatch:
-                    new_distance = min(self.distance_map[target_nr,
-                                                         new_cell[0], new_cell[1]], current_distance+1)
-                    neighbors.append((new_cell[0], new_cell[1], direction, new_distance))
-                    self.distance_map[target_nr, new_cell[0], new_cell[1]] = new_distance
+                    # Append all possible orientations in new_cell that allow a transition to direction!
+                    for orientation in range(4):
+                        moves = self.env.rail.get_transitions((new_cell[0], new_cell[1], orientation))
+                        if moves[direction]:
+                            new_distance = min(self.distance_map[target_nr, new_cell[0], new_cell[1], orientation],
+                                               current_distance+1)
+                            neighbors.append((new_cell[0], new_cell[1], orientation, new_distance))
+                            self.distance_map[target_nr, new_cell[0], new_cell[1], orientation] = new_distance
 
         return neighbors
 
     def _new_position(self, position, movement):
+        """
+        Utility function that converts a compass movement over a 2D grid to new positions (r, c).
+        """
         if movement == 0:    # NORTH
             return (position[0]-1, position[1])
         elif movement == 1:  # EAST
@@ -138,202 +189,187 @@ class TreeObsForRailEnv(ObservationBuilder):
         elif movement == 3:  # WEST
             return (position[0], position[1] - 1)
 
-
     def get(self, handle):
-        # TODO: compute the observation for agent `handle'
-        return []
+        """
+        Computes the current observation for agent `handle' in env
 
+        The observation vector is composed of 4 sequential parts, corresponding to data from the up to 4 possible
+        movements in a RailEnv (up to because only a subset of possible transitions are allowed in RailEnv).
+        The possible movements are sorted relative to the current orientation of the agent, rather than NESW as for
+        the transitions. The order is:
+            [data from 'left'] + [data from 'forward'] + [data from 'right'] + [data from 'back']
 
+        Each branch data is organized as:
+            [root node information] +
+            [recursive branch data from 'left'] +
+            [... from 'forward'] +
+            [... from 'right] +
+            [... from 'back']
 
-"""
+        Finally, each node information is composed of 5 floating point values:
 
-    def get_observation(self, agent):
-        # Get the current observation for an agent
-        current_position = self.internal_position[agent]
-        #target_heading = self._compass(agent).tolist()
-        coordinate = tuple(np.transpose(self._position_to_coordinate([current_position])))
-        agent_distance = self.distance_map[agent][coordinate][0]
-        # Start tree search
-        if current_position == self.target[agent]:
-            agent_tree = Node(current_position, [-np.inf, -np.inf, -np.inf, -np.inf, -1])
-        else:
-            agent_tree = Node(current_position, [0, 0, 0, 0, agent_distance])
-
-        initial_tree_state = Tree_State(agent, current_position, -1, 0, 0)
-        self._tree_search(initial_tree_state, agent_tree, agent)
-        observation = []
-        distance_data = []
+        #1:
 
-        self._flatten_tree(agent_tree, observation, distance_data,  self.max_depth+1)
-        # This is probably very slow!!!!
-        #max_obs = np.max([i for i in observation if i < np.inf])
-        #if max_obs != 0:
-        #    observation = np.array(observation)/ max_obs
+        #2:
 
-        #print([i for i in distance_data if i >= 0])
-        observation = np.concatenate((observation, distance_data))
-        #observation = np.concatenate((observation, np.identity(5)[int(self.last_action[agent])]))
-        #return np.clip(observation / self.max_dist[agent], -1, 1)
-        return np.clip(observation / 15., -1, 1)
+        #3:
 
+        #4:
 
+        #5: minimum distance from node to the agent's target
 
+        Missing/padding nodes are filled in with -inf (truncated).
+        Missing values in present node are filled in with +inf (truncated).
 
-    def _tree_search(self, in_tree_state, parent_node, agent):
-        if in_tree_state.depth >= self.max_depth:
-            return
-        target_distance = np.inf
-        other_target = np.inf
-        other_agent = np.inf
-        coordinate = tuple(np.transpose(self._position_to_coordinate([in_tree_state.position])))
-        curr_target_dist = self.distance_map[agent][coordinate][0]
-        forbidden_action = (in_tree_state.direction + 2) % 4
-        # Update the position
-        failed_move = 0
-        leaf_distance = in_tree_state.distance
-        for child_idx in range(4):
-            if child_idx != forbidden_action or in_tree_state.direction == -1:
-                tree_state = copy.deepcopy(in_tree_state)
-                tree_state.direction = child_idx
-                current_position, invalid_move = self._detect_path(
-                tree_state.position, tree_state.direction)
-                if tree_state.initial_direction == None:
-                    tree_state.initial_direction = child_idx
-                if not invalid_move:
-                    coordinate = tuple(np.transpose(self._position_to_coordinate([current_position])))
-                    curr_target_dist = self.distance_map[agent][coordinate][0]
-                    #if tree_state.initial_direction == None:
-                    #    tree_state.initial_direction = child_idx
-                    tree_state.position = current_position
-                    tree_state.distance += 1
-
-
-                    # Collect information at the current position
-                    detection_distance = tree_state.distance
-                    if current_position == self.target[tree_state.agent]:
-                        target_distance = detection_distance
-
-                    elif current_position in self.target:
-                        other_target = detection_distance
-
-                    if current_position in self.internal_position:
-                        other_agent = detection_distance
-
-                    tree_state.data[0] = self._min_greater_zero(target_distance, tree_state.data[0])
-                    tree_state.data[1] = self._min_greater_zero(other_target, tree_state.data[1])
-                    tree_state.data[2] = self._min_greater_zero(other_agent, tree_state.data[2])
-                    tree_state.data[3] = tree_state.distance
-                    tree_state.data[4] = self._min_greater_zero(curr_target_dist, tree_state.data[4])
-
-                    if self._switch_detection(tree_state.position):
-                        tree_state.depth += 1
-                        new_tree_state = copy.deepcopy(tree_state)
-                        new_node = parent_node.insert(tree_state.position,
-                         tree_state.data, tree_state.initial_direction)
-                        new_tree_state.initial_direction = None
-                        new_tree_state.data = [np.inf, np.inf, np.inf, np.inf, np.inf]
-                        self._tree_search(new_tree_state, new_node, agent)
-                    else:
-                        self._tree_search(tree_state, parent_node, agent)
-                else:
-                    failed_move += 1
-            if failed_move == 3 and in_tree_state.direction != -1:
-                tree_state.data[4] = self._min_greater_zero(curr_target_dist, tree_state.data[4])
-                parent_node.insert(tree_state.position, tree_state.data, tree_state.initial_direction)
-                return
-        return
-
-    def _flatten_tree(self, node, observation_vector, distance_sensor, depth):
-        if depth <= 0:
-            return
-        if node != None:
-            observation_vector.extend(node.data[:-1])
-            distance_sensor.extend([node.data[-1]])
-        else:
-            observation_vector.extend([-np.inf, -np.inf, -np.inf, -np.inf])
-            distance_sensor.extend([-np.inf])
-        for child_idx in range(4):
-            if node != None:
-                child = node.children[child_idx]
+
+        In case of the root node, the values are [0, 0, 0, 0, distance from agent to target].
+        In case the target node is reached, the values are [0, 0, 0, 0, 0].
+        """
+
+        position = self.env.agents_position[handle]
+        orientation = self.env.agents_direction[handle]
+
+        # Root node - current position
+        observation = [0, 0, 0, 0, self.distance_map[handle, position[0], position[1], orientation]]
+
+        # Start from the current orientation, and see which transitions are available;
+        # organize them as [left, forward, right, back], relative to the current orientation
+        for branch_direction in [(orientation+4+i) % 4 for i in range(-1, 3)]:
+            if self.env.rail.get_transition((position[0], position[1], orientation), branch_direction):
+                new_cell = self._new_position(position, branch_direction)
+
+                branch_observation = self._explore_branch(handle, new_cell, branch_direction, 1)
+                observation = observation + branch_observation
             else:
-                child = None
-            self._flatten_tree(child, observation_vector, distance_sensor,  depth -1)
+                num_cells_to_fill_in = 0
+                pow4 = 1
+                for i in range(self.max_depth):
+                    num_cells_to_fill_in += pow4
+                    pow4 *= 4
+                observation = observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf]*num_cells_to_fill_in
 
+        return observation
 
+    def _explore_branch(self, handle, position, direction, depth):
+        """
+        Utility function to compute tree-based observations.
+        """
+        # [Recursive branch opened]
+        if depth >= self.max_depth+1:
+            return []
+
+        # Continue along direction until next switch or
+        # until no transitions are possible along the current direction (i.e., dead-ends)
+        # We treat dead-ends as nodes, instead of going back, to avoid loops
+        exploring = True
+        # TODO: last_isSwitch = False
+        # TODO: last_isTerminal = False  # dead-end
+        # TODO: last_isTarget = False
+        while exploring:
+            # #############################
+            # #############################
+            # Modify here to compute any useful data required to build the end node's features. This code is called
+            # for each cell visited between the previous branching node and the next switch / target / dead-end.
+
+            # TODO: update the current variables according to the current cell in the path
+            # (store info about other agents and targets)
+
+            # TODO: [[[for efficiency, [make dict for hashed-lookup of coords] -- do it in the reset function!]]]
+
+            # #############################
+            # #############################
+
+            # If the target node is encountered, pick that as node. Also, no further branching is possible.
+            if position[0] == self.env.agents_target[handle][0] and position[1] == self.env.agents_target[handle][1]:
+                # TODO: last_isTarget = True
+                break
+
+            cell_transitions = self.env.rail.get_transitions((position[0], position[1], direction))
+            num_transitions = 0
+            for i in range(4):
+                if cell_transitions[i]:
+                    num_transitions += 1
+
+            exploring = False
+            if num_transitions == 1:
+                # Check if dead-end, or if we can go forward along direction
+                if cell_transitions[direction]:
+                    position = self._new_position(position, direction)
+
+                    # Keep walking through the tree along `direction'
+                    exploring = True
 
-    def _switch_detection(self, position):
-        # Hack to detect switches
-        # This can later directly be derived from the transition matrix
-        paths = 0
-        for i in range(4):
-            _, invalid_move = self._detect_path(position, i)
-            if not invalid_move:
-                paths +=1
-            if paths >= 3:
-                return True
-        return False
+                else:
+                    # If a dead-end is reached, pick that as node. Also, no further branching is possible.
+                    # TODO: last_isTerminal = True
+                    break
 
+            elif num_transitions > 0:
+                # Switch detected
+                # TODO: last_isSwitch = True
+                break
 
+            elif num_transitions == 0:
+                # Wrong cell type, but let's cover it and treat it as a dead-end, just in case
+                # TODO: last_isTerminal = True
+                break
 
+        # `position' is either a terminal node or a switch
 
-    def _min_greater_zero(self, x, y):
-        if x <= 0 and y <= 0:
-            return 0
-        if x < 0:
-            return y
-        if y < 0:
-            return x
-        return min(x, y)
+        observation = []
 
+        # #############################
+        # #############################
+        # Modify here to append new / different features for each visited cell!
 
+        observation = [0, 0, 0, 0, self.distance_map[handle, position[0], position[1], direction]]
+        # TODO:
 
-"""
+        # #############################
+        # #############################
 
+        # Start from the current orientation, and see which transitions are available;
+        # organize them as [left, forward, right, back], relative to the current orientation
+        for branch_direction in [(direction+4+i) % 4 for i in range(-1, 3)]:
+            if self.env.rail.get_transition((position[0], position[1], direction), branch_direction):
+                new_cell = self._new_position(position, branch_direction)
 
-class Tree_State:
-    """
-    Keep track of the current state while building the tree
-    """
-    def __init__(self, agent, position, direction, depth, distance):
-        self.agent = agent
-        self.position = position
-        self.direction = direction
-        self.depth = depth
-        self.initial_direction = None
-        self.distance = distance
-        self.data = [np.inf, np.inf, np.inf, np.inf, np.inf]
-
-class Node():
-    """
-    Define a tree node to get populated during search
-    """
-    def __init__(self, position, data):
-        self.n_children = 4
-        self.children = [None]*self.n_children
-        self.data = data
-        self.position = position
+                branch_observation = self._explore_branch(handle, new_cell, branch_direction, depth+1)
+                observation = observation + branch_observation
 
-    def insert(self, position, data, child_idx):
-        """
-        Insert new node with data
+            else:
+                num_cells_to_fill_in = 0
+                pow4 = 1
+                for i in range(self.max_depth-depth):
+                    num_cells_to_fill_in += pow4
+                    pow4 *= 4
+                observation = observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf]*num_cells_to_fill_in
 
-        @param data node data object to insert
-        """
-        new_node = Node(position, data)
-        self.children[child_idx] = new_node
-        return new_node
+        return observation
 
-    def print_tree(self, i=0, depth = 0):
+    def util_print_obs_subtree(self, tree, num_elements_per_node=5, prompt='', current_depth=0):
         """
-        Print tree content inorder
+        Utility function to pretty-print tree observations returned by this object.
         """
-        current_i = i
-        curr_depth = depth+1
-        if i < self.n_children:
-            if self.children[i] != None:
-                self.children[i].print_tree(depth=curr_depth)
-            current_i += 1
-            if self.children[i] != None:
-                self.children[i].print_tree(i, depth=curr_depth)
-
+        if len(tree) < num_elements_per_node:
+            return
 
+        depth = 0
+        tmp = len(tree)/num_elements_per_node-1
+        pow4 = 4
+        while tmp > 0:
+            tmp -= pow4
+            depth += 1
+            pow4 *= 4
+
+        prompt_ = ['L:', 'F:', 'R:', 'B:']
+
+        print("  "*current_depth + prompt, tree[0:num_elements_per_node])
+        child_size = (len(tree)-num_elements_per_node)//4
+        for children in range(4):
+            child_tree = tree[(num_elements_per_node+children*child_size):
+                              (num_elements_per_node+(children+1)*child_size)]
+            self.util_print_obs_subtree(child_tree,
+                                        num_elements_per_node,
+                                        prompt=prompt_[children],
+                                        current_depth=current_depth+1)
diff --git a/flatland/envs/__init__.py b/flatland/envs/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
new file mode 100644
index 00000000..8b34bf3d
--- /dev/null
+++ b/flatland/envs/rail_env.py
@@ -0,0 +1,676 @@
+"""
+Definition of the RailEnv environment and related level-generation functions.
+
+Generator functions are functions that take width, height and num_resets as arguments and return
+a GridTransitionMap object.
+"""
+import random
+import numpy as np
+
+from flatland.core.env import Environment
+from flatland.core.env_observation_builder import TreeObsForRailEnv
+
+from flatland.core.transitions import RailEnvTransitions
+from flatland.core.transition_map import GridTransitionMap
+
+
+def rail_from_manual_specifications_generator(rail_spec):
+    """
+    Utility to convert a rail given by manual specification as a map of tuples
+    (cell_type, rotation), to a transition map with the correct 16-bit
+    transitions specifications.
+
+    Parameters
+    -------
+    rail_spec : list of list of tuples
+        List (rows) of lists (columns) of tuples, each specifying a cell for
+        the RailEnv environment as (cell_type, rotation), with rotation being
+        clock-wise and in [0, 90, 180, 270].
+
+    Returns
+    -------
+    function
+        Generator function that always returns a GridTransitionMap object with
+        the matrix of correct 16-bit bitmaps for each cell.
+    """
+    def generator(width, height, num_resets=0):
+        t_utils = RailEnvTransitions()
+
+        height = len(rail_spec)
+        width = len(rail_spec[0])
+        rail = GridTransitionMap(width=width, height=height, transitions=t_utils)
+
+        for r in range(height):
+            for c in range(width):
+                cell = rail_spec[r][c]
+                if cell[0] < 0 or cell[0] >= len(t_utils.transitions):
+                    print("ERROR - invalid cell type=", cell[0])
+                    return []
+                rail.set_transitions((r, c), t_utils.rotate_transition(
+                              t_utils.transitions[cell[0]], cell[1]))
+
+        return rail
+
+    return generator
+
+
+def rail_from_GridTransitionMap_generator(rail_map):
+    """
+    Utility to convert a rail given by a GridTransitionMap map with the correct
+    16-bit transitions specifications.
+
+    Parameters
+    -------
+    rail_map : GridTransitionMap object
+        GridTransitionMap object to return when the generator is called.
+
+    Returns
+    -------
+    function
+        Generator function that always returns the given `rail_map' object.
+    """
+    def generator(width, height, num_resets=0):
+        return rail_map
+
+    return generator
+
+
+"""
+def generate_rail_from_list_of_manual_specifications(list_of_specifications)
+    def generator(width, height, num_resets=0):
+        return generate_rail_from_manual_specifications(list_of_specifications)
+
+    return generator
+"""
+
+
+def random_rail_generator(width, height, num_resets=0):
+    """
+    Dummy random level generator:
+    - fill in cells at random in [width-2, height-2]
+    - keep filling cells in among the unfilled ones, such that all transitions
+      are legit;  if no cell can be filled in without violating some
+      transitions, pick one among those that can satisfy most transitions
+      (1,2,3 or 4), and delete (+mark to be re-filled) the cells that were
+      incompatible.
+    - keep trying for a total number of insertions
+      (e.g., (W-2)*(H-2)*MAX_REPETITIONS ); if no solution is found, empty the
+      board and try again from scratch.
+    - finally pad the border of the map with dead-ends to avoid border issues.
+
+    Dead-ends are not allowed inside the grid, only at the border; however, if
+    no cell type can be inserted in a given cell (because of the neighboring
+    transitions), deadends are allowed if they solve the problem. This was
+    found to turn most un-genereatable levels into valid ones.
+
+    Parameters
+    -------
+    width : int
+        The width (number of cells) of the grid to generate.
+    height : int
+        The height (number of cells) of the grid to generate.
+
+    Returns
+    -------
+    numpy.ndarray of type numpy.uint16
+        The matrix with the correct 16-bit bitmaps for each cell.
+    """
+
+    t_utils = RailEnvTransitions()
+
+    transitions_templates_ = []
+    for i in range(len(t_utils.transitions)-1):  # don't include dead-ends
+        all_transitions = 0
+        for dir_ in range(4):
+            trans = t_utils.get_transitions(t_utils.transitions[i], dir_)
+            all_transitions |= (trans[0] << 3) | \
+                               (trans[1] << 2) | \
+                               (trans[2] << 1) | \
+                               (trans[3])
+
+        template = [int(x) for x in bin(all_transitions)[2:]]
+        template = [0]*(4-len(template)) + template
+
+        # add all rotations
+        for rot in [0, 90, 180, 270]:
+            transitions_templates_.append((template,
+                                          t_utils.rotate_transition(
+                                           t_utils.transitions[i],
+                                           rot)))
+            template = [template[-1]]+template[:-1]
+
+    def get_matching_templates(template):
+        ret = []
+        for i in range(len(transitions_templates_)):
+            is_match = True
+            for j in range(4):
+                if template[j] >= 0 and \
+                   template[j] != transitions_templates_[i][0][j]:
+                    is_match = False
+                    break
+            if is_match:
+                ret.append(transitions_templates_[i][1])
+        return ret
+
+    MAX_INSERTIONS = (width-2) * (height-2) * 10
+    MAX_ATTEMPTS_FROM_SCRATCH = 10
+
+    attempt_number = 0
+    while attempt_number < MAX_ATTEMPTS_FROM_SCRATCH:
+        cells_to_fill = []
+        rail = []
+        for r in range(height):
+            rail.append([None]*width)
+            if r > 0 and r < height-1:
+                cells_to_fill = cells_to_fill \
+                                + [(r, c) for c in range(1, width-1)]
+
+        num_insertions = 0
+        while num_insertions < MAX_INSERTIONS and len(cells_to_fill) > 0:
+            cell = random.sample(cells_to_fill, 1)[0]
+            cells_to_fill.remove(cell)
+            row = cell[0]
+            col = cell[1]
+
+            # look at its neighbors and see what are the possible transitions
+            # that can be chosen from, if any.
+            valid_template = [-1, -1, -1, -1]
+
+            for el in [(0, 2, (-1, 0)),
+                       (1, 3, (0, 1)),
+                       (2, 0, (1, 0)),
+                       (3, 1, (0, -1))]:  # N, E, S, W
+                neigh_trans = rail[row+el[2][0]][col+el[2][1]]
+                if neigh_trans is not None:
+                    # select transition coming from facing direction el[1] and
+                    # moving to direction el[1]
+                    max_bit = 0
+                    for k in range(4):
+                        max_bit |= \
+                         t_utils.get_transition(neigh_trans, k, el[1])
+
+                    if max_bit:
+                        valid_template[el[0]] = 1
+                    else:
+                        valid_template[el[0]] = 0
+
+            possible_cell_transitions = get_matching_templates(valid_template)
+
+            if len(possible_cell_transitions) == 0:  # NO VALID TRANSITIONS
+                # no cell can be filled in without violating some transitions
+                # can a dead-end solve the problem?
+                if valid_template.count(1) == 1:
+                    for k in range(4):
+                        if valid_template[k] == 1:
+                            rot = 0
+                            if k == 0:
+                                rot = 180
+                            elif k == 1:
+                                rot = 270
+                            elif k == 2:
+                                rot = 0
+                            elif k == 3:
+                                rot = 90
+
+                            rail[row][col] = t_utils.rotate_transition(
+                                              int('0000000000100000', 2), rot)
+                            num_insertions += 1
+
+                            break
+
+                else:
+                    # can I get valid transitions by removing a single
+                    # neighboring cell?
+                    bestk = -1
+                    besttrans = []
+                    for k in range(4):
+                        tmp_template = valid_template[:]
+                        tmp_template[k] = -1
+                        possible_cell_transitions = get_matching_templates(
+                                                     tmp_template)
+                        if len(possible_cell_transitions) > len(besttrans):
+                            besttrans = possible_cell_transitions
+                            bestk = k
+
+                    if bestk >= 0:
+                        # Replace the corresponding cell with None, append it
+                        # to cells to fill, fill in a transition in the current
+                        # cell.
+                        replace_row = row - 1
+                        replace_col = col
+                        if bestk == 1:
+                            replace_row = row
+                            replace_col = col + 1
+                        elif bestk == 2:
+                            replace_row = row + 1
+                            replace_col = col
+                        elif bestk == 3:
+                            replace_row = row
+                            replace_col = col - 1
+
+                        cells_to_fill.append((replace_row, replace_col))
+                        rail[replace_row][replace_col] = None
+
+                        rail[row][col] = random.sample(
+                                                     besttrans, 1)[0]
+                        num_insertions += 1
+
+                    else:
+                        print('WARNING: still nothing!')
+                        rail[row][col] = int('0000000000000000', 2)
+                        num_insertions += 1
+                        pass
+
+            else:
+                rail[row][col] = random.sample(
+                                             possible_cell_transitions, 1)[0]
+                num_insertions += 1
+
+        if num_insertions == MAX_INSERTIONS:
+            # Failed to generate a valid level; try again for a number of times
+            attempt_number += 1
+        else:
+            break
+
+    if attempt_number == MAX_ATTEMPTS_FROM_SCRATCH:
+        print('ERROR: failed to generate level')
+
+    # Finally pad the border of the map with dead-ends to avoid border issues;
+    # at most 1 transition in the neigh cell
+    for r in range(height):
+        # Check for transitions coming from [r][1] to WEST
+        max_bit = 0
+        neigh_trans = rail[r][1]
+        if neigh_trans is not None:
+            for k in range(4):
+                neigh_trans_from_direction = (neigh_trans >> ((3-k) * 4)) \
+                                             & (2**4-1)
+                max_bit = max_bit | (neigh_trans_from_direction & 1)
+        if max_bit:
+            rail[r][0] = t_utils.rotate_transition(
+                           int('0000000000100000', 2), 270)
+        else:
+            rail[r][0] = int('0000000000000000', 2)
+
+        # Check for transitions coming from [r][-2] to EAST
+        max_bit = 0
+        neigh_trans = rail[r][-2]
+        if neigh_trans is not None:
+            for k in range(4):
+                neigh_trans_from_direction = (neigh_trans >> ((3-k) * 4)) \
+                                             & (2**4-1)
+                max_bit = max_bit | (neigh_trans_from_direction & (1 << 2))
+        if max_bit:
+            rail[r][-1] = t_utils.rotate_transition(int('0000000000100000', 2),
+                                                    90)
+        else:
+            rail[r][-1] = int('0000000000000000', 2)
+
+    for c in range(width):
+        # Check for transitions coming from [1][c] to NORTH
+        max_bit = 0
+        neigh_trans = rail[1][c]
+        if neigh_trans is not None:
+            for k in range(4):
+                neigh_trans_from_direction = (neigh_trans >> ((3-k) * 4)) \
+                                              & (2**4-1)
+                max_bit = max_bit | (neigh_trans_from_direction & (1 << 3))
+        if max_bit:
+            rail[0][c] = int('0000000000100000', 2)
+        else:
+            rail[0][c] = int('0000000000000000', 2)
+
+        # Check for transitions coming from [-2][c] to SOUTH
+        max_bit = 0
+        neigh_trans = rail[-2][c]
+        if neigh_trans is not None:
+            for k in range(4):
+                neigh_trans_from_direction = (neigh_trans >> ((3-k) * 4)) \
+                                             & (2**4-1)
+                max_bit = max_bit | (neigh_trans_from_direction & (1 << 1))
+        if max_bit:
+            rail[-1][c] = t_utils.rotate_transition(
+                            int('0000000000100000', 2), 180)
+        else:
+            rail[-1][c] = int('0000000000000000', 2)
+
+    # For display only, wrong levels
+    for r in range(height):
+        for c in range(width):
+            if rail[r][c] is None:
+                rail[r][c] = int('0000000000000000', 2)
+
+    tmp_rail = np.asarray(rail, dtype=np.uint16)
+    return_rail = GridTransitionMap(width=width, height=height, transitions=t_utils)
+    return_rail.grid = tmp_rail
+    return return_rail
+
+
+class RailEnv(Environment):
+    """
+    RailEnv environment class.
+
+    RailEnv is an environment inspired by a (simplified version of) a rail
+    network, in which agents (trains) have to navigate to their target
+    locations in the shortest time possible, while at the same time cooperating
+    to avoid bottlenecks.
+
+    The valid actions in the environment are:
+        0: do nothing
+        1: turn left and move to the next cell
+        2: move to the next cell in front of the agent
+        3: turn right and move to the next cell
+
+    Moving forward in a dead-end cell makes the agent turn 180 degrees and step
+    to the cell it came from.
+
+    The actions of the agents are executed in order of their handle to prevent
+    deadlocks and to allow them to learn relative priorities.
+
+    TODO: WRITE ABOUT THE REWARD FUNCTION, and possibly allow for alpha and
+    beta to be passed as parameters to __init__().
+    """
+
+    def __init__(self,
+                 width,
+                 height,
+                 rail_generator=random_rail_generator,
+                 number_of_agents=1,
+                 obs_builder_object=TreeObsForRailEnv(max_depth=2)):
+        """
+        Environment init.
+
+        Parameters
+        -------
+        rail_generator : function
+            The rail_generator function is a function that takes the width and
+            height of a  rail map along with the number of times the env has
+            been reset, and returns a GridTransitionMap object.
+            Implemented functions are:
+                random_rail_generator : generate a random rail of given size
+                rail_from_GridTransitionMap_generator(rail_map) : generate a rail from
+                                        a GridTransitionMap object
+                rail_from_manual_specifications_generator(rail_spec) : generate a rail from
+                                        a rail specifications array
+                TODO: generate_rail_from_saved_list or from list of ndarray bitmaps ---
+        width : int
+            The width of the rail map. Potentially in the future,
+            a range of widths to sample from.
+        height : int
+            The height of the rail map. Potentially in the future,
+            a range of heights to sample from.
+        number_of_agents : int
+            Number of agents to spawn on the map. Potentially in the future,
+            a range of number of agents to sample from.
+        obs_builder_object: ObservationBuilder object
+            ObservationBuilder-derived object that takes builds observation
+            vectors for each agent.
+        """
+
+        self.rail_generator = rail_generator
+        self.rail = None
+        self.width = width
+        self.height = height
+
+        self.number_of_agents = number_of_agents
+
+        self.obs_builder = obs_builder_object
+        self.obs_builder._set_env(self)
+
+        self.actions = [0]*self.number_of_agents
+        self.rewards = [0]*self.number_of_agents
+        self.done = False
+
+        self.dones = {"__all__": False}
+        self.obs_dict = {}
+        self.rewards_dict = {}
+
+        self.agents_handles = list(range(self.number_of_agents))
+
+        # self.agents_position = []
+        # self.agents_target = []
+        # self.agents_direction = []
+        self.num_resets = 0
+        self.reset()
+        self.num_resets = 0
+
+    def get_agent_handles(self):
+        return self.agents_handles
+
+    def reset(self):
+        self.rail = self.rail_generator(self.width, self.height, self.num_resets)
+        self.num_resets += 1
+
+        self.dones = {"__all__": False}
+        for handle in self.agents_handles:
+            self.dones[handle] = False
+
+        re_generate = True
+        while re_generate:
+            valid_positions = []
+            for r in range(self.height):
+                for c in range(self.width):
+                    if self.rail.get_transitions((r, c)) > 0:
+                        valid_positions.append((r, c))
+
+            self.agents_position = random.sample(valid_positions,
+                                                 self.number_of_agents)
+            self.agents_target = random.sample(valid_positions,
+                                               self.number_of_agents)
+
+            # agents_direction must be a direction for which a solution is
+            # guaranteed.
+            self.agents_direction = [0]*self.number_of_agents
+            re_generate = False
+            for i in range(self.number_of_agents):
+                valid_movements = []
+                for direction in range(4):
+                    position = self.agents_position[i]
+                    moves = self.rail.get_transitions(
+                            (position[0], position[1], direction))
+                    for move_index in range(4):
+                        if moves[move_index]:
+                            valid_movements.append((direction, move_index))
+
+                valid_starting_directions = []
+                for m in valid_movements:
+                    new_position = self._new_position(self.agents_position[i],
+                                                      m[1])
+                    if m[0] not in valid_starting_directions and \
+                       self._path_exists(new_position, m[0],
+                                         self.agents_target[i]):
+                        valid_starting_directions.append(m[0])
+
+                if len(valid_starting_directions) == 0:
+                    re_generate = True
+                else:
+                    self.agents_direction[i] = random.sample(
+                                               valid_starting_directions, 1)[0]
+
+        # Reset the state of the observation builder with the new environment
+        self.obs_builder.reset()
+
+        # Return the new observation vectors for each agent
+        return self._get_observations()
+
+    def step(self, action_dict):
+        alpha = 1.0
+        beta = 1.0
+
+        invalid_action_penalty = -2
+        step_penalty = -1 * alpha
+        global_reward = 1 * beta
+
+        # Reset the step rewards
+        self.rewards_dict = {}
+        for handle in self.agents_handles:
+            self.rewards_dict[handle] = 0
+
+        if self.dones["__all__"]:
+            return self._get_observations(), self.rewards_dict, self.dones, {}
+
+        for i in range(len(self.agents_handles)):
+            handle = self.agents_handles[i]
+
+            if handle not in action_dict:
+                continue
+
+            action = action_dict[handle]
+
+            if action < 0 or action > 3:
+                print('ERROR: illegal action=', action,
+                      'for agent with handle=', handle)
+                return
+
+            if action > 0:
+                pos = self.agents_position[i]
+                direction = self.agents_direction[i]
+
+                movement = direction
+                if action == 1:
+                    movement = direction - 1
+                elif action == 3:
+                    movement = direction + 1
+
+                if movement < 0:
+                    movement += 4
+                if movement >= 4:
+                    movement -= 4
+
+                is_deadend = False
+                if action == 2:
+                    # compute number of possible transitions in the current
+                    # cell
+                    nbits = 0
+                    tmp = self.rail.get_transitions((pos[0], pos[1]))
+                    while tmp > 0:
+                        nbits += (tmp & 1)
+                        tmp = tmp >> 1
+                    if nbits == 1:
+                        # dead-end;  assuming the rail network is consistent,
+                        # this should match the direction the agent has come
+                        # from. But it's better to check in any case.
+                        reverse_direction = 0
+                        if direction == 0:
+                            reverse_direction = 2
+                        elif direction == 1:
+                            reverse_direction = 3
+                        elif direction == 2:
+                            reverse_direction = 0
+                        elif direction == 3:
+                            reverse_direction = 1
+
+                        valid_transition = self.rail.get_transition(
+                                            (pos[0], pos[1], direction),
+                                            reverse_direction)
+                        if valid_transition:
+                            direction = reverse_direction
+                            movement = reverse_direction
+                            is_deadend = True
+
+                new_position = self._new_position(pos, movement)
+
+                # Is it a legal move?  1) transition allows the movement in the
+                # cell,  2) the new cell is not empty (case 0),  3) the cell is
+                # free, i.e., no agent is currently in that cell
+                if new_position[1] >= self.width or\
+                   new_position[0] >= self.height or\
+                   new_position[0] < 0 or new_position[1] < 0:
+                    new_cell_isValid = False
+
+                elif self.rail.get_transitions((new_position[0], new_position[1])) > 0:
+                    new_cell_isValid = True
+                else:
+                    new_cell_isValid = False
+
+                transition_isValid = self.rail.get_transition(
+                     (pos[0], pos[1], direction),
+                     movement) or is_deadend
+
+                cell_isFree = True
+                for j in range(self.number_of_agents):
+                    if self.agents_position[j] == new_position:
+                        cell_isFree = False
+                        break
+
+                if new_cell_isValid and transition_isValid and cell_isFree:
+                    # move and change direction to face the movement that was
+                    # performed
+                    self.agents_position[i] = new_position
+                    self.agents_direction[i] = movement
+                else:
+                    # the action was not valid, add penalty
+                    self.rewards_dict[handle] += invalid_action_penalty
+
+            # if agent is not in target position, add step penalty
+            if self.agents_position[i][0] == self.agents_target[i][0] and \
+               self.agents_position[i][1] == self.agents_target[i][1]:
+                self.dones[handle] = True
+            else:
+                self.rewards_dict[handle] += step_penalty
+
+        # Check for end of episode + add global reward to all rewards!
+        num_agents_in_target_position = 0
+        for i in range(self.number_of_agents):
+            if self.agents_position[i][0] == self.agents_target[i][0] and \
+               self.agents_position[i][1] == self.agents_target[i][1]:
+                num_agents_in_target_position += 1
+
+        if num_agents_in_target_position == self.number_of_agents:
+            self.dones["__all__"] = True
+            self.rewards_dict = [r+global_reward for r in self.rewards_dict]
+
+        # Reset the step actions (in case some agent doesn't 'register_action'
+        # on the next step)
+        self.actions = [0]*self.number_of_agents
+
+        return self._get_observations(), self.rewards_dict, self.dones, {}
+
+    def _new_position(self, position, movement):
+        if movement == 0:    # NORTH
+            return (position[0]-1, position[1])
+        elif movement == 1:  # EAST
+            return (position[0], position[1] + 1)
+        elif movement == 2:  # SOUTH
+            return (position[0]+1, position[1])
+        elif movement == 3:  # WEST
+            return (position[0], position[1] - 1)
+
+    def _path_exists(self, start, direction, end):
+        # BFS - Check if a path exists between the 2 nodes
+
+        visited = set()
+        stack = [(start, direction)]
+        while stack:
+            node = stack.pop()
+            if node[0][0] == end[0] and node[0][1] == end[1]:
+                return 1
+            if node not in visited:
+                visited.add(node)
+                moves = self.rail.get_transitions((node[0][0], node[0][1], node[1]))
+                for move_index in range(4):
+                    if moves[move_index]:
+                        stack.append((self._new_position(node[0], move_index),
+                                      move_index))
+
+                # If cell is a dead-end, append previous node with reversed
+                # orientation!
+                nbits = 0
+                tmp = self.rail.get_transitions((node[0][0], node[0][1]))
+                while tmp > 0:
+                    nbits += (tmp & 1)
+                    tmp = tmp >> 1
+                if nbits == 1:
+                    stack.append((node[0], (node[1] + 2) % 4))
+
+        return 0
+
+    def _get_observations(self):
+        self.obs_dict = {}
+        for handle in self.agents_handles:
+            self.obs_dict[handle] = self.obs_builder.get(handle)
+        return self.obs_dict
+
+    def render(self):
+        # TODO:
+        pass
diff --git a/flatland/utils/rail_env_generator.py b/flatland/utils/rail_env_generator.py
deleted file mode 100644
index ab11ac5e..00000000
--- a/flatland/utils/rail_env_generator.py
+++ /dev/null
@@ -1,341 +0,0 @@
-"""
-The rail_env_generator module defines provides utilities to generate env
-bitmaps for the RailEnv environment.
-"""
-import random
-import numpy as np
-
-from flatland.core.transitions import RailEnvTransitions
-from flatland.core.transition_map import GridTransitionMap
-
-
-def generate_rail_from_manual_specifications(rail_spec):
-    """
-    Utility to convert a rail given by manual specification as a map of tuples
-    (cell_type, rotation), to a transition map with the correct 16-bit
-    transitions specifications.
-
-    Parameters
-    -------
-    rail_spec : list of list of tuples
-        List (rows) of lists (columns) of tuples, each specifying a cell for
-        the RailEnv environment as (cell_type, rotation), with rotation being
-        clock-wise and in [0, 90, 180, 270].
-
-    Returns
-    -------
-    function
-        Generator function that always returns a GridTransitionMap object with
-        the matrix of correct 16-bit bitmaps for each cell.
-    """
-    def generator(width, height, num_resets=0):
-        t_utils = RailEnvTransitions()
-
-        height = len(rail_spec)
-        width = len(rail_spec[0])
-        rail = GridTransitionMap(width=width, height=height, transitions=t_utils)
-
-        for r in range(height):
-            for c in range(width):
-                cell = rail_spec[r][c]
-                if cell[0] < 0 or cell[0] >= len(t_utils.transitions):
-                    print("ERROR - invalid cell type=", cell[0])
-                    return []
-                rail.set_transitions((r, c), t_utils.rotate_transition(
-                              t_utils.transitions[cell[0]], cell[1]))
-
-        return rail
-
-    return generator
-
-
-def generate_rail_from_GridTransitionMap(rail_map):
-    """
-    Utility to convert a rail given by a GridTransitionMap map with the correct
-    16-bit transitions specifications.
-
-    Parameters
-    -------
-    rail_map : GridTransitionMap object
-        GridTransitionMap object to return when the generator is called.
-
-    Returns
-    -------
-    function
-        Generator function that always returns the given `rail_map' object.
-    """
-    def generator(width, height, num_resets=0):
-        return rail_map
-
-    return generator
-
-
-"""
-def generate_rail_from_list_of_manual_specifications(list_of_specifications)
-    def generator(width, height, num_resets=0):
-        return generate_rail_from_manual_specifications(list_of_specifications)
-
-    return generator
-"""
-
-
-def generate_random_rail(width, height, num_resets=0):
-    """
-    Dummy random level generator:
-    - fill in cells at random in [width-2, height-2]
-    - keep filling cells in among the unfilled ones, such that all transitions
-      are legit;  if no cell can be filled in without violating some
-      transitions, pick one among those that can satisfy most transitions
-      (1,2,3 or 4), and delete (+mark to be re-filled) the cells that were
-      incompatible.
-    - keep trying for a total number of insertions
-      (e.g., (W-2)*(H-2)*MAX_REPETITIONS ); if no solution is found, empty the
-      board and try again from scratch.
-    - finally pad the border of the map with dead-ends to avoid border issues.
-
-    Dead-ends are not allowed inside the grid, only at the border; however, if
-    no cell type can be inserted in a given cell (because of the neighboring
-    transitions), deadends are allowed if they solve the problem. This was
-    found to turn most un-genereatable levels into valid ones.
-
-    Parameters
-    -------
-    width : int
-        The width (number of cells) of the grid to generate.
-    height : int
-        The height (number of cells) of the grid to generate.
-
-    Returns
-    -------
-    numpy.ndarray of type numpy.uint16
-        The matrix with the correct 16-bit bitmaps for each cell.
-    """
-
-    t_utils = RailEnvTransitions()
-
-    transitions_templates_ = []
-    for i in range(len(t_utils.transitions)-1):  # don't include dead-ends
-        all_transitions = 0
-        for dir_ in range(4):
-            trans = t_utils.get_transitions(t_utils.transitions[i], dir_)
-            all_transitions |= (trans[0] << 3) | \
-                               (trans[1] << 2) | \
-                               (trans[2] << 1) | \
-                               (trans[3])
-
-        template = [int(x) for x in bin(all_transitions)[2:]]
-        template = [0]*(4-len(template)) + template
-
-        # add all rotations
-        for rot in [0, 90, 180, 270]:
-            transitions_templates_.append((template,
-                                          t_utils.rotate_transition(
-                                           t_utils.transitions[i],
-                                           rot)))
-            template = [template[-1]]+template[:-1]
-
-    def get_matching_templates(template):
-        ret = []
-        for i in range(len(transitions_templates_)):
-            is_match = True
-            for j in range(4):
-                if template[j] >= 0 and \
-                   template[j] != transitions_templates_[i][0][j]:
-                    is_match = False
-                    break
-            if is_match:
-                ret.append(transitions_templates_[i][1])
-        return ret
-
-    MAX_INSERTIONS = (width-2) * (height-2) * 10
-    MAX_ATTEMPTS_FROM_SCRATCH = 10
-
-    attempt_number = 0
-    while attempt_number < MAX_ATTEMPTS_FROM_SCRATCH:
-        cells_to_fill = []
-        rail = []
-        for r in range(height):
-            rail.append([None]*width)
-            if r > 0 and r < height-1:
-                cells_to_fill = cells_to_fill \
-                                + [(r, c) for c in range(1, width-1)]
-
-        num_insertions = 0
-        while num_insertions < MAX_INSERTIONS and len(cells_to_fill) > 0:
-            cell = random.sample(cells_to_fill, 1)[0]
-            cells_to_fill.remove(cell)
-            row = cell[0]
-            col = cell[1]
-
-            # look at its neighbors and see what are the possible transitions
-            # that can be chosen from, if any.
-            valid_template = [-1, -1, -1, -1]
-
-            for el in [(0, 2, (-1, 0)),
-                       (1, 3, (0, 1)),
-                       (2, 0, (1, 0)),
-                       (3, 1, (0, -1))]:  # N, E, S, W
-                neigh_trans = rail[row+el[2][0]][col+el[2][1]]
-                if neigh_trans is not None:
-                    # select transition coming from facing direction el[1] and
-                    # moving to direction el[1]
-                    max_bit = 0
-                    for k in range(4):
-                        max_bit |= \
-                         t_utils.get_transition(neigh_trans, k, el[1])
-
-                    if max_bit:
-                        valid_template[el[0]] = 1
-                    else:
-                        valid_template[el[0]] = 0
-
-            possible_cell_transitions = get_matching_templates(valid_template)
-
-            if len(possible_cell_transitions) == 0:  # NO VALID TRANSITIONS
-                # no cell can be filled in without violating some transitions
-                # can a dead-end solve the problem?
-                if valid_template.count(1) == 1:
-                    for k in range(4):
-                        if valid_template[k] == 1:
-                            rot = 0
-                            if k == 0:
-                                rot = 180
-                            elif k == 1:
-                                rot = 270
-                            elif k == 2:
-                                rot = 0
-                            elif k == 3:
-                                rot = 90
-
-                            rail[row][col] = t_utils.rotate_transition(
-                                              int('0000000000100000', 2), rot)
-                            num_insertions += 1
-
-                            break
-
-                else:
-                    # can I get valid transitions by removing a single
-                    # neighboring cell?
-                    bestk = -1
-                    besttrans = []
-                    for k in range(4):
-                        tmp_template = valid_template[:]
-                        tmp_template[k] = -1
-                        possible_cell_transitions = get_matching_templates(
-                                                     tmp_template)
-                        if len(possible_cell_transitions) > len(besttrans):
-                            besttrans = possible_cell_transitions
-                            bestk = k
-
-                    if bestk >= 0:
-                        # Replace the corresponding cell with None, append it
-                        # to cells to fill, fill in a transition in the current
-                        # cell.
-                        replace_row = row - 1
-                        replace_col = col
-                        if bestk == 1:
-                            replace_row = row
-                            replace_col = col + 1
-                        elif bestk == 2:
-                            replace_row = row + 1
-                            replace_col = col
-                        elif bestk == 3:
-                            replace_row = row
-                            replace_col = col - 1
-
-                        cells_to_fill.append((replace_row, replace_col))
-                        rail[replace_row][replace_col] = None
-
-                        rail[row][col] = random.sample(
-                                                     besttrans, 1)[0]
-                        num_insertions += 1
-
-                    else:
-                        print('WARNING: still nothing!')
-                        rail[row][col] = int('0000000000000000', 2)
-                        num_insertions += 1
-                        pass
-
-            else:
-                rail[row][col] = random.sample(
-                                             possible_cell_transitions, 1)[0]
-                num_insertions += 1
-
-        if num_insertions == MAX_INSERTIONS:
-            # Failed to generate a valid level; try again for a number of times
-            attempt_number += 1
-        else:
-            break
-
-    if attempt_number == MAX_ATTEMPTS_FROM_SCRATCH:
-        print('ERROR: failed to generate level')
-
-    # Finally pad the border of the map with dead-ends to avoid border issues;
-    # at most 1 transition in the neigh cell
-    for r in range(height):
-        # Check for transitions coming from [r][1] to WEST
-        max_bit = 0
-        neigh_trans = rail[r][1]
-        if neigh_trans is not None:
-            for k in range(4):
-                neigh_trans_from_direction = (neigh_trans >> ((3-k) * 4)) \
-                                             & (2**4-1)
-                max_bit = max_bit | (neigh_trans_from_direction & 1)
-        if max_bit:
-            rail[r][0] = t_utils.rotate_transition(
-                           int('0000000000100000', 2), 270)
-        else:
-            rail[r][0] = int('0000000000000000', 2)
-
-        # Check for transitions coming from [r][-2] to EAST
-        max_bit = 0
-        neigh_trans = rail[r][-2]
-        if neigh_trans is not None:
-            for k in range(4):
-                neigh_trans_from_direction = (neigh_trans >> ((3-k) * 4)) \
-                                             & (2**4-1)
-                max_bit = max_bit | (neigh_trans_from_direction & (1 << 2))
-        if max_bit:
-            rail[r][-1] = t_utils.rotate_transition(int('0000000000100000', 2),
-                                                    90)
-        else:
-            rail[r][-1] = int('0000000000000000', 2)
-
-    for c in range(width):
-        # Check for transitions coming from [1][c] to NORTH
-        max_bit = 0
-        neigh_trans = rail[1][c]
-        if neigh_trans is not None:
-            for k in range(4):
-                neigh_trans_from_direction = (neigh_trans >> ((3-k) * 4)) \
-                                              & (2**4-1)
-                max_bit = max_bit | (neigh_trans_from_direction & (1 << 3))
-        if max_bit:
-            rail[0][c] = int('0000000000100000', 2)
-        else:
-            rail[0][c] = int('0000000000000000', 2)
-
-        # Check for transitions coming from [-2][c] to SOUTH
-        max_bit = 0
-        neigh_trans = rail[-2][c]
-        if neigh_trans is not None:
-            for k in range(4):
-                neigh_trans_from_direction = (neigh_trans >> ((3-k) * 4)) \
-                                             & (2**4-1)
-                max_bit = max_bit | (neigh_trans_from_direction & (1 << 1))
-        if max_bit:
-            rail[-1][c] = t_utils.rotate_transition(
-                            int('0000000000100000', 2), 180)
-        else:
-            rail[-1][c] = int('0000000000000000', 2)
-
-    # For display only, wrong levels
-    for r in range(height):
-        for c in range(width):
-            if rail[r][c] is None:
-                rail[r][c] = int('0000000000000000', 2)
-
-    tmp_rail = np.asarray(rail, dtype=np.uint16)
-    return_rail = GridTransitionMap(width=width, height=height, transitions=t_utils)
-    return_rail.grid = tmp_rail
-    return return_rail
diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py
index d0a78917..5365800b 100644
--- a/flatland/utils/rendertools.py
+++ b/flatland/utils/rendertools.py
@@ -6,6 +6,8 @@ import xarray as xr
 import matplotlib.pyplot as plt
 
 
+# TODO: suggested renaming to RailEnvRenderTool, as it will only work with RailEnv!
+
 class RenderTool(object):
     Visit = recordtype("Visit", ["rc", "iDir", "iDepth", "prev"])
 
diff --git a/tests/test_environments.py b/tests/test_environments.py
index ce9fbd4f..b62a2e60 100644
--- a/tests/test_environments.py
+++ b/tests/test_environments.py
@@ -1,10 +1,9 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
-from flatland.core.env import RailEnv
+from flatland.core.env import RailEnv, rail_from_GridTransitionMap_generator
 from flatland.core.transitions import Grid4Transitions
 from flatland.core.transition_map import GridTransitionMap
-from flatland.utils.rail_env_generator import generate_rail_from_GridTransitionMap
 import numpy as np
 
 """Tests for `flatland` package."""
@@ -47,7 +46,10 @@ def test_rail_environment_single_agent():
 
     rail = GridTransitionMap(width=3, height=3, transitions=transitions)
     rail.grid = rail_map
-    rail_env = RailEnv(width=3, height=3, rail_generator=generate_rail_from_GridTransitionMap(rail), number_of_agents=1)
+    rail_env = RailEnv(width=3,
+                       height=3,
+                       rail_generator=rail_from_GridTransitionMap_generator(rail),
+                       number_of_agents=1)
     for _ in range(200):
         _ = rail_env.reset()
 
@@ -121,7 +123,7 @@ def test_dead_end():
     rail.grid = rail_map
     rail_env = RailEnv(width=rail_map.shape[1],
                        height=rail_map.shape[0],
-                       rail_generator=generate_rail_from_GridTransitionMap(rail),
+                       rail_generator=rail_from_GridTransitionMap_generator(rail),
                        number_of_agents=1)
 
     def check_consistency(rail_env):
@@ -170,7 +172,7 @@ def test_dead_end():
     rail.grid = rail_map
     rail_env = RailEnv(width=rail_map.shape[1],
                        height=rail_map.shape[0],
-                       rail_generator=generate_rail_from_GridTransitionMap(rail),
+                       rail_generator=rail_from_GridTransitionMap_generator(rail),
                        number_of_agents=1)
 
     rail_env.reset()
diff --git a/tests/test_rendertools.py b/tests/test_rendertools.py
index 5fecd085..427ff2f7 100644
--- a/tests/test_rendertools.py
+++ b/tests/test_rendertools.py
@@ -4,14 +4,13 @@
 Tests for `flatland` package.
 """
 
-from flatland.core.env import RailEnv
+from flatland.envs.rail_env import RailEnv, random_rail_generator
 import numpy as np
 import random
 import os
 
 import matplotlib.pyplot as plt
 
-from flatland.utils import rail_env_generator
 import flatland.utils.rendertools as rt
 
 
@@ -37,7 +36,7 @@ def checkFrozenImage(sFileImage):
 
 def test_render_env():
     random.seed(100)
-    oEnv = RailEnv(width=10, height=10, rail_generator=rail_env_generator.generate_random_rail, number_of_agents=2)
+    oEnv = RailEnv(width=10, height=10, rail_generator=random_rail_generator, number_of_agents=2)
     oEnv.reset()
     oRT = rt.RenderTool(oEnv)
     plt.figure(figsize=(10, 10))
-- 
GitLab