diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..71c9cbb42e95985dc7927602d0017479eb75dfc7
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2019 SBB AG
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/examples/sample_10_10_rail.npy b/examples/sample_10_10_rail.npy
new file mode 100644
index 0000000000000000000000000000000000000000..a8dc0d41ecfff0c5c3a8b7446b1dd6246573608e
Binary files /dev/null and b/examples/sample_10_10_rail.npy differ
diff --git a/examples/temporary_example.py b/examples/temporary_example.py
index 2444d3895cf43499ae5108ea16981e55a5989156..02c282cb374914651d063a2b118fb688257e7631 100644
--- a/examples/temporary_example.py
+++ b/examples/temporary_example.py
@@ -1,41 +1,75 @@
-from flatland.core.env import RailEnv
-from flatland.utils.rail_env_generator import *
+import random
+import numpy as np
+import matplotlib.pyplot as plt
+
+from flatland.envs.rail_env import *
+from flatland.core.env_observation_builder import TreeObsForRailEnv
 from flatland.utils.rendertools import *
 
-random.seed(1)
-np.random.seed(1)
+random.seed(0)
+np.random.seed(0)
 
+transition_probability = [1.0,  # empty cell - Case 0
+                          1.0,  # Case 1 - straight
+                          1.0,  # Case 2 - simple switch
+                          0.3,  # Case 3 - diamond drossing
+                          0.5,  # Case 4 - single slip
+                          0.5,  # Case 5 - double slip
+                          0.2,  # Case 6 - symmetrical
+                          0.0]  # Case 7 - dead end
 
 # Example generate a random rail
-rail = generate_random_rail(20, 20)
+env = RailEnv(width=20,
+              height=20,
+              rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
+              number_of_agents=10)
+
+# env = RailEnv(width=20,
+#               height=20,
+#               rail_generator=rail_from_list_of_saved_GridTransitionMap_generator(['examples/sample_10_10_rail.npy']),
+#               number_of_agents=10)
 
-env = RailEnv(rail, number_of_agents=10)
 env.reset()
 
 env_renderer = RenderTool(env)
 env_renderer.renderEnv(show=True)
 
-
+"""
 # Example generate a rail given a manual specification,
 # a map of tuples (cell_type, rotation)
 specs = [[(0, 0), (0, 0), (0, 0), (0, 0), (7, 0), (0, 0)],
          [(7, 270), (1, 90), (1, 90), (1, 90), (2, 90), (7, 90)]]
 
-rail = generate_rail_from_manual_specifications(specs)
-env = RailEnv(rail, number_of_agents=1)
+env = RailEnv(width=6,
+              height=2,
+              rail_generator=rail_from_manual_specifications_generator(specs),
+              number_of_agents=1,
+              obs_builder_object=TreeObsForRailEnv(max_depth=2))
 
 handle = env.get_agent_handles()
-
-env.reset()
-
-env.agents_position = [[1, 4]]
-env.agents_target = [[1, 1]]
-env.agents_direction = [1]
+env.agents_position[0] = [1, 4]
+env.agents_target[0] = [1, 1]
+env.agents_direction[0] = 1
+# TODO: watch out: if these variables are overridden, the obs_builder object has to be reset, too!
+env.obs_builder.reset()
+"""
+env = RailEnv(width=7,
+              height=7,
+              rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
+              number_of_agents=2)
+
+# Print the distance map of each cell to the target of the first agent
+# for i in range(4):
+#     print(env.obs_builder.distance_map[0, :, :, i])
+
+# Print the observation vector for agent 0
+obs, all_rewards, done, _ = env.step({0:0})
+for i in range(env.number_of_agents):
+    env.obs_builder.util_print_obs_subtree(tree=obs[i], num_features_per_node=5)
 
 env_renderer = RenderTool(env)
 env_renderer.renderEnv(show=True)
 
-
 print("Manual control: s=perform step, q=quit, [agent id] [1-2-3 action] \
        (turnleft+move, move to front, turnright+move)")
 for step in range(100):
diff --git a/examples/training_navigation.py b/examples/training_navigation.py
new file mode 100644
index 0000000000000000000000000000000000000000..975d33fb3139ebc2040b941dbe73d8aeb3b225eb
--- /dev/null
+++ b/examples/training_navigation.py
@@ -0,0 +1,104 @@
+from flatland.envs.rail_env import *
+from flatland.core.env_observation_builder import TreeObsForRailEnv
+from flatland.utils.rendertools import *
+from flatland.baselines.dueling_double_dqn import Agent
+from collections import deque
+import torch,random
+
+random.seed(1)
+np.random.seed(1)
+
+
+# Example generate a rail given a manual specification,
+# a map of tuples (cell_type, rotation)
+transition_probability = [1.0,  # empty cell - Case 0
+                          1.0,  # Case 1 - straight
+                          1.0,  # Case 2 - simple switch
+                          0.3,  # Case 3 - diamond drossing
+                          0.5,  # Case 4 - single slip
+                          0.5,  # Case 5 - double slip
+                          0.2,  # Case 6 - symmetrical
+                          0.0]  # Case 7 - dead end
+
+# Example generate a random rail
+env = RailEnv(width=7,
+              height=7,
+              rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
+              number_of_agents=1)
+env_renderer = RenderTool(env)
+handle = env.get_agent_handles()
+
+state_size = 105
+action_size = 4
+n_trials = 5000
+eps = 1.
+eps_end = 0.005
+eps_decay = 0.998
+action_dict = dict()
+scores_window = deque(maxlen=100)
+done_window = deque(maxlen=100)
+scores = []
+dones_list = []
+
+agent = Agent(state_size, action_size, "FC", 0)
+
+for trials in range(1, n_trials + 1):
+
+    # Reset environment
+    obs = env.reset()
+    # env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5)
+
+
+    score = 0
+    env_done = 0
+
+    # Run episode
+    for step in range(100):
+        if trials >= 114:
+            env_renderer.renderEnv(show=True)
+
+        # Action
+        for a in range(env.number_of_agents):
+            action = agent.act(np.array(obs[a]), eps=eps)
+            action_dict.update({a: action})
+
+        # Environment step
+        next_obs, all_rewards, done, _ = env.step(action_dict)
+
+
+        # Update replay buffer and train agent
+        for a in range(env.number_of_agents):
+            agent.step(obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a])
+            score += all_rewards[a]
+
+        obs = next_obs.copy()
+
+        if done['__all__']:
+            env_done = 1
+            break
+    # Epsioln decay
+    eps = max(eps_end, eps_decay * eps)  # decrease epsilon
+
+    done_window.append(env_done)
+    scores_window.append(score)  # save most recent score
+    scores.append(np.mean(scores_window))
+    dones_list.append((np.mean(done_window)))
+
+    print('\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%\tEpsilon: {:.2f}'.format(env.number_of_agents,
+                                                                                                             trials,
+                                                                                                             np.mean(
+                                                                                                                 scores_window),
+                                                                                                             100 * np.mean(
+                                                                                                                 done_window),
+                                                                                                             eps),
+          end=" ")
+    if trials % 100 == 0:
+        print(
+            '\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%\tEpsilon: {:.2f}'.format(env.number_of_agents,
+                                                                                                               trials,
+                                                                                                               np.mean(
+                                                                                                                   scores_window),
+                                                                                                               100 * np.mean(
+                                                                                                                   done_window),
+                                                                                                               eps))
+        torch.save(agent.qnetwork_local.state_dict(), '../flatland/baselines/Nets/avoid_checkpoint' + str(trials) + '.pth')
diff --git a/flatland/baselines/__init__.py b/flatland/baselines/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/flatland/baselines/dueling_double_dqn.py b/flatland/baselines/dueling_double_dqn.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee75a615cda4d26a06810d2b7d109fe5691d5ac4
--- /dev/null
+++ b/flatland/baselines/dueling_double_dqn.py
@@ -0,0 +1,189 @@
+import numpy as np
+import random
+from collections import namedtuple, deque
+import os
+from flatland.baselines.model import QNetwork, QNetwork2
+import torch
+import torch.nn.functional as F
+import torch.optim as optim
+import copy
+
+BUFFER_SIZE = int(1e5)  # replay buffer size
+BATCH_SIZE = 512  # minibatch size
+GAMMA = 0.99  # discount factor 0.99
+TAU = 1e-3  # for soft update of target parameters
+LR = 0.5e-4  # learning rate 5
+UPDATE_EVERY = 10  # how often to update the network
+double_dqn = True  # If using double dqn algorithm
+input_channels = 5  # Number of Input channels
+
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+device = torch.device("cpu")
+print(device)
+
+
+class Agent:
+    """Interacts with and learns from the environment."""
+
+    def __init__(self, state_size, action_size, net_type, seed, double_dqn=True, input_channels=5):
+        """Initialize an Agent object.
+
+        Params
+        ======
+            state_size (int): dimension of each state
+            action_size (int): dimension of each action
+            seed (int): random seed
+        """
+        self.state_size = state_size
+        self.action_size = action_size
+        self.seed = random.seed(seed)
+        self.version = net_type
+        self.double_dqn = double_dqn
+        # Q-Network
+        if self.version == "Conv":
+            self.qnetwork_local = QNetwork2(state_size, action_size, seed, input_channels).to(device)
+            self.qnetwork_target = copy.deepcopy(self.qnetwork_local)
+        else:
+            self.qnetwork_local = QNetwork(state_size, action_size, seed).to(device)
+            self.qnetwork_target = copy.deepcopy(self.qnetwork_local)
+
+        self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=LR)
+
+        # Replay memory
+        self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed)
+        # Initialize time step (for updating every UPDATE_EVERY steps)
+        self.t_step = 0
+
+    def save(self, filename):
+        torch.save(self.qnetwork_local.state_dict(), filename + ".local")
+        torch.save(self.qnetwork_target.state_dict(), filename + ".target")
+
+    def load(self, filename):
+        if os.path.exists(filename + ".local"):
+            self.qnetwork_local.load_state_dict(torch.load(filename + ".local"))
+        if os.path.exists(filename + ".target"):
+            self.qnetwork_target.load_state_dict(torch.load(filename + ".target"))
+
+    def step(self, state, action, reward, next_state, done):
+        # Save experience in replay memory
+        self.memory.add(state, action, reward, next_state, done)
+
+        # Learn every UPDATE_EVERY time steps.
+        self.t_step = (self.t_step + 1) % UPDATE_EVERY
+        if self.t_step == 0:
+            # If enough samples are available in memory, get random subset and learn
+            if len(self.memory) > BATCH_SIZE:
+                experiences = self.memory.sample()
+                self.learn(experiences, GAMMA)
+
+    def act(self, state, eps=0.):
+        """Returns actions for given state as per current policy.
+
+        Params
+        ======
+            state (array_like): current state
+            eps (float): epsilon, for epsilon-greedy action selection
+        """
+        state = torch.from_numpy(state).float().unsqueeze(0).to(device)
+        self.qnetwork_local.eval()
+        with torch.no_grad():
+            action_values = self.qnetwork_local(state)
+        self.qnetwork_local.train()
+
+        # Epsilon-greedy action selection
+        if random.random() > eps:
+            return np.argmax(action_values.cpu().data.numpy())
+        else:
+            return random.choice(np.arange(self.action_size))
+
+    def learn(self, experiences, gamma):
+
+        """Update value parameters using given batch of experience tuples.
+
+        Params
+        ======
+            experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tuples
+            gamma (float): discount factor
+        """
+        states, actions, rewards, next_states, dones = experiences
+
+        # Get expected Q values from local model
+        Q_expected = self.qnetwork_local(states).gather(1, actions)
+
+        if self.double_dqn:
+            # Double DQN
+            q_best_action = self.qnetwork_local(next_states).max(1)[1]
+            Q_targets_next = self.qnetwork_target(next_states).gather(1, q_best_action.unsqueeze(-1))
+        else:
+            # DQN
+            Q_targets_next = self.qnetwork_target(next_states).detach().max(1)[0].unsqueeze(-1)
+
+            # Compute Q targets for current states
+
+        Q_targets = rewards + (gamma * Q_targets_next * (1 - dones))
+
+        # Compute loss
+        loss = F.mse_loss(Q_expected, Q_targets)
+        # Minimize the loss
+        self.optimizer.zero_grad()
+        loss.backward()
+        self.optimizer.step()
+
+        # ------------------- update target network ------------------- #
+        self.soft_update(self.qnetwork_local, self.qnetwork_target, TAU)
+
+    def soft_update(self, local_model, target_model, tau):
+        """Soft update model parameters.
+        θ_target = τ*θ_local + (1 - τ)*θ_target
+
+        Params
+        ======
+            local_model (PyTorch model): weights will be copied from
+            target_model (PyTorch model): weights will be copied to
+            tau (float): interpolation parameter
+        """
+        for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
+            target_param.data.copy_(tau * local_param.data + (1.0 - tau) * target_param.data)
+
+
+class ReplayBuffer:
+    """Fixed-size buffer to store experience tuples."""
+
+    def __init__(self, action_size, buffer_size, batch_size, seed):
+        """Initialize a ReplayBuffer object.
+
+        Params
+        ======
+            action_size (int): dimension of each action
+            buffer_size (int): maximum size of buffer
+            batch_size (int): size of each training batch
+            seed (int): random seed
+        """
+        self.action_size = action_size
+        self.memory = deque(maxlen=buffer_size)
+        self.batch_size = batch_size
+        self.experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"])
+        self.seed = random.seed(seed)
+
+    def add(self, state, action, reward, next_state, done):
+        """Add a new experience to memory."""
+        e = self.experience(np.expand_dims(state, 0), action, reward, np.expand_dims(next_state, 0), done)
+        self.memory.append(e)
+
+    def sample(self):
+        """Randomly sample a batch of experiences from memory."""
+        experiences = random.sample(self.memory, k=self.batch_size)
+
+        states = torch.from_numpy(np.vstack([e.state for e in experiences if e is not None])).float().to(device)
+        actions = torch.from_numpy(np.vstack([e.action for e in experiences if e is not None])).long().to(device)
+        rewards = torch.from_numpy(np.vstack([e.reward for e in experiences if e is not None])).float().to(device)
+        next_states = torch.from_numpy(np.vstack([e.next_state for e in experiences if e is not None])).float().to(
+            device)
+        dones = torch.from_numpy(np.vstack([e.done for e in experiences if e is not None]).astype(np.uint8)).float().to(
+            device)
+
+        return (states, actions, rewards, next_states, dones)
+
+    def __len__(self):
+        """Return the current size of internal memory."""
+        return len(self.memory)
diff --git a/flatland/baselines/model.py b/flatland/baselines/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a5b3d613342a4fba8e2c8f1f45df21381e12684
--- /dev/null
+++ b/flatland/baselines/model.py
@@ -0,0 +1,61 @@
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class QNetwork(nn.Module):
+    def __init__(self, state_size, action_size, seed, hidsize1=128, hidsize2=128):
+        super(QNetwork, self).__init__()
+
+        self.fc1_val = nn.Linear(state_size, hidsize1)
+        self.fc2_val = nn.Linear(hidsize1, hidsize2)
+        self.fc3_val = nn.Linear(hidsize2, 1)
+
+        self.fc1_adv = nn.Linear(state_size, hidsize1)
+        self.fc2_adv = nn.Linear(hidsize1, hidsize2)
+        self.fc3_adv = nn.Linear(hidsize2, action_size)
+
+    def forward(self, x):
+        val = F.relu(self.fc1_val(x))
+        val = F.relu(self.fc2_val(val))
+        val = self.fc3_val(val)
+
+        # advantage calculation
+        adv = F.relu(self.fc1_adv(x))
+        adv = F.relu(self.fc2_adv(adv))
+        adv = self.fc3_adv(adv)
+        return val + adv - adv.mean()
+
+
+class QNetwork2(nn.Module):
+    def __init__(self, state_size, action_size, seed, input_channels, hidsize1=128, hidsize2=64):
+        super(QNetwork2, self).__init__()
+        self.conv1 = nn.Conv2d(input_channels, 16, kernel_size=3, stride=1)
+        self.bn1 = nn.BatchNorm2d(16)
+        self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=3)
+        self.bn2 = nn.BatchNorm2d(32)
+        self.conv3 = nn.Conv2d(32, 64, kernel_size=5, stride=3)
+        self.bn3 = nn.BatchNorm2d(64)
+
+        self.fc1_val = nn.Linear(6400, hidsize1)
+        self.fc2_val = nn.Linear(hidsize1, hidsize2)
+        self.fc3_val = nn.Linear(hidsize2, 1)
+
+        self.fc1_adv = nn.Linear(6400, hidsize1)
+        self.fc2_adv = nn.Linear(hidsize1, hidsize2)
+        self.fc3_adv = nn.Linear(hidsize2, action_size)
+
+    def forward(self, x):
+        x = F.relu(self.conv1(x))
+        x = F.relu(self.conv2(x))
+        x = F.relu(self.conv3(x))
+
+        # value function approximation
+        val = F.relu(self.fc1_val(x.view(x.size(0), -1)))
+        val = F.relu(self.fc2_val(val))
+        val = self.fc3_val(val)
+
+        # advantage calculation
+        adv = F.relu(self.fc1_adv(x.view(x.size(0), -1)))
+        adv = F.relu(self.fc2_adv(adv))
+        adv = self.fc3_adv(adv)
+        return val + adv - adv.mean()
diff --git a/flatland/core/env.py b/flatland/core/env.py
index a7e63fd4e2697996a4dbe0a684735fd46153fb1b..284afdffb6ce46ac481018af469c7d2e024fc792 100644
--- a/flatland/core/env.py
+++ b/flatland/core/env.py
@@ -3,9 +3,6 @@ The env module defines the base Environment class.
 The base Environment class is adapted from rllib.env.MultiAgentEnv
 (https://github.com/ray-project/ray).
 """
-import random
-
-from .env_observation_builder import TreeObsForRailEnv
 
 
 class Environment:
@@ -93,303 +90,3 @@ class Environment:
         function.
         """
         raise NotImplementedError()
-
-
-class RailEnv:
-    """
-    RailEnv environment class.
-
-    RailEnv is an environment inspired by a (simplified version of) a rail
-    network, in which agents (trains) have to navigate to their target
-    locations in the shortest time possible, while at the same time cooperating
-    to avoid bottlenecks.
-
-    The valid actions in the environment are:
-        0: do nothing
-        1: turn left and move to the next cell
-        2: move to the next cell in front of the agent
-        3: turn right and move to the next cell
-
-    Moving forward in a dead-end cell makes the agent turn 180 degrees and step
-    to the cell it came from.
-
-    The actions of the agents are executed in order of their handle to prevent
-    deadlocks and to allow them to learn relative priorities.
-
-    TODO: WRITE ABOUT THE REWARD FUNCTION, and possibly allow for alpha and
-    beta to be passed as parameters to __init__().
-    """
-
-    def __init__(self,
-                 rail,
-                 number_of_agents=1,
-                 custom_observation_builder=TreeObsForRailEnv):
-        """
-        Environment init.
-
-        Parameters
-        -------
-        rail : numpy.ndarray of type numpy.uint16
-            The transition matrix that defines the environment.
-        number_of_agents : int
-            Number of agents to spawn on the map.
-        custom_observation_builder: ObservationBuilder object
-            ObservationBuilder-derived object that takes this env object
-            as input as provides observation vectors for each agent.
-        """
-
-        self.rail = rail
-        self.width = rail.width
-        self.height = rail.height
-
-        self.number_of_agents = number_of_agents
-
-        self.obs_builder = custom_observation_builder(env=self)
-
-        self.actions = [0]*self.number_of_agents
-        self.rewards = [0]*self.number_of_agents
-        self.done = False
-
-        self.dones = {"__all__": False}
-        self.obs_dict = {}
-        self.rewards_dict = {}
-
-        self.agents_handles = list(range(self.number_of_agents))
-
-    def get_agent_handles(self):
-        return self.agents_handles
-
-    def reset(self):
-        self.dones = {"__all__": False}
-        for handle in self.agents_handles:
-            self.dones[handle] = False
-
-        re_generate = True
-        while re_generate:
-            valid_positions = []
-            for r in range(self.height):
-                for c in range(self.width):
-                    if self.rail.get_transitions((r, c)) > 0:
-                        valid_positions.append((r, c))
-
-            self.agents_position = random.sample(valid_positions,
-                                                 self.number_of_agents)
-            self.agents_target = random.sample(valid_positions,
-                                               self.number_of_agents)
-
-            # agents_direction must be a direction for which a solution is
-            # guaranteed.
-            self.agents_direction = [0]*self.number_of_agents
-            re_generate = False
-            for i in range(self.number_of_agents):
-                valid_movements = []
-                for direction in range(4):
-                    position = self.agents_position[i]
-                    moves = self.rail.get_transitions(
-                            (position[0], position[1], direction))
-                    for move_index in range(4):
-                        if moves[move_index]:
-                            valid_movements.append((direction, move_index))
-
-                valid_starting_directions = []
-                for m in valid_movements:
-                    new_position = self._new_position(self.agents_position[i],
-                                                      m[1])
-                    if m[0] not in valid_starting_directions and \
-                       self._path_exists(new_position, m[0],
-                                         self.agents_target[i]):
-                        valid_starting_directions.append(m[0])
-
-                if len(valid_starting_directions) == 0:
-                    re_generate = True
-                else:
-                    self.agents_direction[i] = random.sample(
-                                               valid_starting_directions, 1)[0]
-
-        # Reset the state of the observation builder with the new environment
-        self.obs_builder.reset()
-
-        # Return the new observation vectors for each agent
-        return self._get_observations()
-
-    def step(self, action_dict):
-        alpha = 1.0
-        beta = 1.0
-
-        invalid_action_penalty = -2
-        step_penalty = -1 * alpha
-        global_reward = 1 * beta
-
-        # Reset the step rewards
-        self.rewards_dict = {}
-        for handle in self.agents_handles:
-            self.rewards_dict[handle] = 0
-
-        if self.dones["__all__"]:
-            return self._get_observations(), self.rewards_dict, self.dones, {}
-
-        for i in range(len(self.agents_handles)):
-            handle = self.agents_handles[i]
-
-            if handle not in action_dict:
-                continue
-
-            action = action_dict[handle]
-
-            if action < 0 or action > 3:
-                print('ERROR: illegal action=', action,
-                      'for agent with handle=', handle)
-                return
-
-            if action > 0:
-                pos = self.agents_position[i]
-                direction = self.agents_direction[i]
-
-                movement = direction
-                if action == 1:
-                    movement = direction - 1
-                elif action == 3:
-                    movement = direction + 1
-
-                if movement < 0:
-                    movement += 4
-                if movement >= 4:
-                    movement -= 4
-
-                is_deadend = False
-                if action == 2:
-                    # compute number of possible transitions in the current
-                    # cell
-                    nbits = 0
-                    tmp = self.rail.get_transitions((pos[0], pos[1]))
-                    while tmp > 0:
-                        nbits += (tmp & 1)
-                        tmp = tmp >> 1
-                    if nbits == 1:
-                        # dead-end;  assuming the rail network is consistent,
-                        # this should match the direction the agent has come
-                        # from. But it's better to check in any case.
-                        reverse_direction = 0
-                        if direction == 0:
-                            reverse_direction = 2
-                        elif direction == 1:
-                            reverse_direction = 3
-                        elif direction == 2:
-                            reverse_direction = 0
-                        elif direction == 3:
-                            reverse_direction = 1
-
-                        valid_transition = self.rail.get_transition(
-                                            (pos[0], pos[1], direction),
-                                            reverse_direction)
-                        if valid_transition:
-                            direction = reverse_direction
-                            movement = reverse_direction
-                            is_deadend = True
-
-                new_position = self._new_position(pos, movement)
-
-                # Is it a legal move?  1) transition allows the movement in the
-                # cell,  2) the new cell is not empty (case 0),  3) the cell is
-                # free, i.e., no agent is currently in that cell
-                if new_position[1] >= self.width or\
-                   new_position[0] >= self.height or\
-                   new_position[0] < 0 or new_position[1] < 0:
-                    new_cell_isValid = False
-
-                elif self.rail.get_transitions((new_position[0], new_position[1])) > 0:
-                    new_cell_isValid = True
-                else:
-                    new_cell_isValid = False
-
-                transition_isValid = self.rail.get_transition(
-                     (pos[0], pos[1], direction),
-                     movement) or is_deadend
-
-                cell_isFree = True
-                for j in range(self.number_of_agents):
-                    if self.agents_position[j] == new_position:
-                        cell_isFree = False
-                        break
-
-                if new_cell_isValid and transition_isValid and cell_isFree:
-                    # move and change direction to face the movement that was
-                    # performed
-                    self.agents_position[i] = new_position
-                    self.agents_direction[i] = movement
-                else:
-                    # the action was not valid, add penalty
-                    self.rewards_dict[handle] += invalid_action_penalty
-
-            # if agent is not in target position, add step penalty
-            if self.agents_position[i][0] == self.agents_target[i][0] and \
-               self.agents_position[i][1] == self.agents_target[i][1]:
-                self.dones[handle] = True
-            else:
-                self.rewards_dict[handle] += step_penalty
-
-        # Check for end of episode + add global reward to all rewards!
-        num_agents_in_target_position = 0
-        for i in range(self.number_of_agents):
-            if self.agents_position[i][0] == self.agents_target[i][0] and \
-               self.agents_position[i][1] == self.agents_target[i][1]:
-                num_agents_in_target_position += 1
-
-        if num_agents_in_target_position == self.number_of_agents:
-            self.dones["__all__"] = True
-            self.rewards_dict = [r+global_reward for r in self.rewards_dict]
-
-        # Reset the step actions (in case some agent doesn't 'register_action'
-        # on the next step)
-        self.actions = [0]*self.number_of_agents
-
-        return self._get_observations(), self.rewards_dict, self.dones, {}
-
-    def _new_position(self, position, movement):
-        if movement == 0:    # NORTH
-            return (position[0]-1, position[1])
-        elif movement == 1:  # EAST
-            return (position[0], position[1] + 1)
-        elif movement == 2:  # SOUTH
-            return (position[0]+1, position[1])
-        elif movement == 3:  # WEST
-            return (position[0], position[1] - 1)
-
-    def _path_exists(self, start, direction, end):
-        # BFS - Check if a path exists between the 2 nodes
-
-        visited = set()
-        stack = [(start, direction)]
-        while stack:
-            node = stack.pop()
-            if node[0][0] == end[0] and node[0][1] == end[1]:
-                return 1
-            if node not in visited:
-                visited.add(node)
-                moves = self.rail.get_transitions((node[0][0], node[0][1], node[1]))
-                for move_index in range(4):
-                    if moves[move_index]:
-                        stack.append((self._new_position(node[0], move_index),
-                                      move_index))
-
-                # If cell is a dead-end, append previous node with reversed
-                # orientation!
-                nbits = 0
-                tmp = self.rail.get_transitions((node[0][0], node[0][1]))
-                while tmp > 0:
-                    nbits += (tmp & 1)
-                    tmp = tmp >> 1
-                if nbits == 1:
-                    stack.append((node[0], (node[1] + 2) % 4))
-
-        return 0
-
-    def _get_observations(self):
-        self.obs_dict = {}
-        for handle in self.agents_handles:
-            self.obs_dict[handle] = self.obs_builder.get(handle)
-        return self.obs_dict
-
-    def render(self):
-        # TODO:
-        pass
diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py
index 1f97ff2b75674026ae98334a74aea6f9d1b60dbd..38e66423509e622d7cd0685b14c7c1b28f8806ca 100644
--- a/flatland/core/env_observation_builder.py
+++ b/flatland/core/env_observation_builder.py
@@ -1,3 +1,13 @@
+"""
+ObservationBuilder objects are objects that can be passed to environments designed for customizability.
+The ObservationBuilder-derived custom classes implement 2 functions, reset() and get() or get(handle).
+
++ Reset() is called after each environment reset, to allow for pre-computing relevant data.
+
++ Get() is called whenever an observation has to be computed, potentially for each agent independently in
+case of multi-agent environments.
+"""
+
 import numpy as np
 
 from collections import deque
@@ -6,46 +16,87 @@ from collections import deque
 
 
 class ObservationBuilder:
-    def __init__(self, env):
+    """
+    ObservationBuilder base class.
+    """
+    def __init__(self):
+        pass
+
+    def _set_env(self, env):
         self.env = env
 
     def reset(self):
+        """
+        Called after each environment reset.
+        """
         raise NotImplementedError()
 
-    def get(self, handle):
+    def get(self, handle=0):
+        """
+        Called whenever an observation has to be computed for the `env' environment, possibly
+        for each agent independently (agent id `handle').
+
+        Parameters
+        -------
+        handle : int (optional)
+            Handle of the agent for which to compute the observation vector.
+
+        Returns
+        -------
+        function
+            An observation structure, specific to the corresponding environment.
+        """
         raise NotImplementedError()
 
 
 class TreeObsForRailEnv(ObservationBuilder):
-    def __init__(self, env):
-        self.env = env
+    """
+    TreeObsForRailEnv object.
+
+    This object returns observation vectors for agents in the RailEnv environment.
+    The information is local to each agent and exploits the tree structure of the rail
+    network to simplify the representation of the state of the environment for each agent.
+    """
+    def __init__(self, max_depth):
+        self.max_depth = max_depth
 
     def reset(self):
         self.distance_map = np.inf * np.ones(shape=(self.env.number_of_agents,
                                                     self.env.height,
-                                                    self.env.width))
+                                                    self.env.width,
+                                                    4))
         self.max_dist = np.zeros(self.env.number_of_agents)
 
         for i in range(self.env.number_of_agents):
             self.max_dist[i] = self._distance_map_walker(self.env.agents_target[i], i)
 
+        # Update local lookup table for all agents' target locations
+        self.location_has_target = {}
+        for loc in self.env.agents_target:
+            self.location_has_target[(loc[0], loc[1])] = 1
+
     def _distance_map_walker(self, position, target_nr):
+        """
+        Utility function to compute distance maps from each cell in the rail network (and each possible
+        orientation within it) to each agent's target cell.
+        """
         # Returns max distance to target, from the farthest away node, while filling in distance_map
 
         for ori in range(4):
-            self.distance_map[target_nr, position[0], position[1]] = 0
+            self.distance_map[target_nr, position[0], position[1], ori] = 0
 
         # Fill in the (up to) 4 neighboring nodes
         # nodes_queue = []  # list of tuples (row, col, direction, distance);
-        # direction is the direction of movement, meaning that at least a possible orientation
-        # of an agent in cell (row,col) allows a movement in direction `direction'
-        nodes_queue = deque(self._get_and_update_neighbors(position,
-                                                           target_nr, 0, enforce_target_direction=-1))
+        # direction is the direction of movement, meaning that at least a possible orientation of an agent
+        # in cell (row,col) allows a movement in direction `direction'
+        nodes_queue = deque(self._get_and_update_neighbors(position, target_nr, 0, enforce_target_direction=-1))
 
         # BFS from target `position' to all the reachable nodes in the grid
         # Stop the search if the target position is re-visited, in any direction
-        visited = set([(position[0], position[1], 0), (position[0], position[1], 1),
-                       (position[0], position[1], 2), (position[0], position[1], 3)])
+        visited = set([(position[0], position[1], 0),
+                       (position[0], position[1], 1),
+                       (position[0], position[1], 2),
+                       (position[0], position[1], 3)])
 
         max_distance = 0
 
@@ -57,11 +108,9 @@ class TreeObsForRailEnv(ObservationBuilder):
             if node_id not in visited:
                 visited.add(node_id)
 
-                # From the list of possible neighbors that have at least a path to the
-                # current node, only keep those whose new orientation in the current cell
-                # would allow a transition to direction node[2]
-                valid_neighbors = self._get_and_update_neighbors(
-                    (node[0], node[1]), target_nr, node[3], node[2])
+                # From the list of possible neighbors that have at least a path to the current node, only keep those
+                # whose new orientation in the current cell would allow a transition to direction node[2]
+                valid_neighbors = self._get_and_update_neighbors((node[0], node[1]), target_nr, node[3], node[2])
 
                 for n in valid_neighbors:
                     nodes_queue.append(n)
@@ -72,6 +121,10 @@ class TreeObsForRailEnv(ObservationBuilder):
         return max_distance
 
     def _get_and_update_neighbors(self, position, target_nr, current_distance, enforce_target_direction=-1):
+        """
+        Utility function used by _distance_map_walker to perform a BFS walk over the rail, filling in the
+        minimum distances from each target cell.
+        """
         neighbors = []
 
         for direction in range(4):
@@ -121,9 +174,57 @@ class TreeObsForRailEnv(ObservationBuilder):
                     neighbors.append((new_cell[0], new_cell[1], direction, new_distance))
                     self.distance_map[target_nr, new_cell[0], new_cell[1]] = new_distance
 
+        possible_directions = [0, 1, 2, 3]
+        if enforce_target_direction >= 0:
+            # The agent must land into the current cell with orientation `enforce_target_direction'.
+            # This is only possible if the agent has arrived from the cell in the opposite direction!
+            possible_directions = [(enforce_target_direction+2) % 4]
+
+        for neigh_direction in possible_directions:
+            new_cell = self._new_position(position, neigh_direction)
+
+            if new_cell[0] >= 0 and new_cell[0] < self.env.height and \
+               new_cell[1] >= 0 and new_cell[1] < self.env.width:
+
+                desired_movement_from_new_cell = (neigh_direction+2) % 4
+
+                """
+                # Is the next cell a dead-end?
+                isNextCellDeadEnd = False
+                nbits = 0
+                tmp = self.env.rail.get_transitions((new_cell[0], new_cell[1]))
+                while tmp > 0:
+                    nbits += (tmp & 1)
+                    tmp = tmp >> 1
+                if nbits == 1:
+                    # Dead-end!
+                    isNextCellDeadEnd = True
+                """
+
+                # Check all possible transitions in new_cell
+                for agent_orientation in range(4):
+                    # Is a transition along movement `desired_movement_from_new_cell' to the current cell possible?
+                    isValid = self.env.rail.get_transition((new_cell[0], new_cell[1], agent_orientation),
+                                                           desired_movement_from_new_cell)
+
+                    if isValid:
+                        """
+                        # TODO: check that it works with deadends! -- still bugged!
+                        movement = desired_movement_from_new_cell
+                        if isNextCellDeadEnd:
+                            movement = (desired_movement_from_new_cell+2) % 4
+                        """
+                        new_distance = min(self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation],
+                                           current_distance+1)
+                        neighbors.append((new_cell[0], new_cell[1], agent_orientation, new_distance))
+                        self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation] = new_distance
+
         return neighbors
 
     def _new_position(self, position, movement):
+        """
+        Utility function that converts a compass movement over a 2D grid to new positions (r, c).
+        """
         if movement == 0:    # NORTH
             return (position[0]-1, position[1])
         elif movement == 1:  # EAST
@@ -134,8 +235,233 @@ class TreeObsForRailEnv(ObservationBuilder):
             return (position[0], position[1] - 1)
 
     def get(self, handle):
-        # TODO: compute the observation for agent `handle'
-        return []
+        """
+        Computes the current observation for agent `handle' in env
+
+        The observation vector is composed of 4 sequential parts, corresponding to data from the up to 4 possible
+        movements in a RailEnv (up to because only a subset of possible transitions are allowed in RailEnv).
+        The possible movements are sorted relative to the current orientation of the agent, rather than NESW as for
+        the transitions. The order is:
+            [data from 'left'] + [data from 'forward'] + [data from 'right'] + [data from 'back']
+
+
+
+
+
+        Each branch data is organized as:
+            [root node information] +
+            [recursive branch data from 'left'] +
+            [... from 'forward'] +
+            [... from 'right] +
+            [... from 'back']
+
+        Finally, each node information is composed of 5 floating point values:
+
+        #1:
+
+        #2: 1 if a target of another agent is detected between the previous node and the current one.
+
+        #3: 1 if another agent is detected between the previous node and the current one.
+
+        #4:
+
+        #5: minimum distance from node to the agent's target (when landing to the node following the corresponding
+            branch.
+
+        Missing/padding nodes are filled in with -inf (truncated).
+        Missing values in present node are filled in with +inf (truncated).
+
+
+        In case of the root node, the values are [0, 0, 0, 0, distance from agent to target].
+        In case the target node is reached, the values are [0, 0, 0, 0, 0].
+        """
+
+        # Update local lookup table for all agents' positions
+        self.location_has_agent = {}
+        for loc in self.env.agents_position:
+            self.location_has_agent[(loc[0], loc[1])] = 1
+
+        position = self.env.agents_position[handle]
+        orientation = self.env.agents_direction[handle]
+
+        # Root node - current position
+        observation = [0, 0, 0, 0, self.distance_map[handle, position[0], position[1], orientation]]
+
+        # Start from the current orientation, and see which transitions are available;
+        # organize them as [left, forward, right, back], relative to the current orientation
+        for branch_direction in [(orientation+4+i) % 4 for i in range(-1, 3)]:
+            if self.env.rail.get_transition((position[0], position[1], orientation), branch_direction):
+                new_cell = self._new_position(position, branch_direction)
+
+                branch_observation = self._explore_branch(handle, new_cell, branch_direction, 1)
+                observation = observation + branch_observation
+            else:
+                num_cells_to_fill_in = 0
+                pow4 = 1
+                for i in range(self.max_depth):
+                    num_cells_to_fill_in += pow4
+                    pow4 *= 4
+                observation = observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf]*num_cells_to_fill_in
+
+        return observation
+
+    def _explore_branch(self, handle, position, direction, depth):
+        """
+        Utility function to compute tree-based observations.
+        """
+        # [Recursive branch opened]
+        if depth >= self.max_depth+1:
+            return []
+
+        # Continue along direction until next switch or
+        # until no transitions are possible along the current direction (i.e., dead-ends)
+        # We treat dead-ends as nodes, instead of going back, to avoid loops
+        exploring = True
+        last_isSwitch = False
+        last_isDeadEnd = False
+        # TODO: last_isTerminal = False  # wrong cell encountered
+        last_isTarget = False
+
+        other_agent_encountered = False
+        other_target_encountered = False
+        while exploring:
+            # #############################
+            # #############################
+            # Modify here to compute any useful data required to build the end node's features. This code is called
+            # for each cell visited between the previous branching node and the next switch / target / dead-end.
+
+            if position in self.location_has_agent:
+                other_agent_encountered = True
+
+            if position in self.location_has_target:
+                other_target_encountered = True
+
+            # #############################
+            # #############################
+
+            # If the target node is encountered, pick that as node. Also, no further branching is possible.
+            if position[0] == self.env.agents_target[handle][0] and position[1] == self.env.agents_target[handle][1]:
+                last_isTarget = True
+                break
+
+            cell_transitions = self.env.rail.get_transitions((position[0], position[1], direction))
+            num_transitions = 0
+            for i in range(4):
+                if cell_transitions[i]:
+                    num_transitions += 1
+
+            exploring = False
+            if num_transitions == 1:
+                # Check if dead-end, or if we can go forward along direction
+                nbits = 0
+                tmp = self.env.rail.get_transitions((position[0], position[1]))
+                while tmp > 0:
+                    nbits += (tmp & 1)
+                    tmp = tmp >> 1
+                if nbits == 1:
+                    # Dead-end!
+                    last_isDeadEnd = True
+
+                if not last_isDeadEnd:
+                    # Keep walking through the tree along `direction'
+                    exploring = True
+
+                    for i in range(4):
+                        if cell_transitions[i]:
+                            position = self._new_position(position, i)
+                            direction = i
+                            break
+
+            elif num_transitions > 0:
+                # Switch detected
+                last_isSwitch = True
+                break
+
+            elif num_transitions == 0:
+                # Wrong cell type, but let's cover it and treat it as a dead-end, just in case
+                # TODO: last_isTerminal = True
+                break
+
+        # `position' is either a terminal node or a switch
+
+        observation = []
+
+        # #############################
+        # #############################
+        # Modify here to append new / different features for each visited cell!
+
+        if last_isTarget:
+            observation = [0,
+                           1 if other_target_encountered else 0,
+                           1 if other_agent_encountered else 0,
+                           0,
+                           0]
+
+        else:
+            observation = [0,
+                           1 if other_target_encountered else 0,
+                           1 if other_agent_encountered else 0,
+                           0,
+                           self.distance_map[handle, position[0], position[1], direction]]
+
+        # #############################
+        # #############################
+
+        # Start from the current orientation, and see which transitions are available;
+        # organize them as [left, forward, right, back], relative to the current orientation
+        for branch_direction in [(direction+4+i) % 4 for i in range(-1, 3)]:
+            if last_isDeadEnd and self.env.rail.get_transition((position[0], position[1], direction),
+                                                               (branch_direction+2) % 4):
+                # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes
+                # it back
+                new_cell = self._new_position(position, (branch_direction+2) % 4)
+
+                branch_observation = self._explore_branch(handle, new_cell, (branch_direction+2) % 4, depth+1)
+                observation = observation + branch_observation
+
+            elif last_isSwitch and self.env.rail.get_transition((position[0], position[1], direction),
+                                                                branch_direction):
+                new_cell = self._new_position(position, branch_direction)
+
+                branch_observation = self._explore_branch(handle, new_cell, branch_direction, depth+1)
+                observation = observation + branch_observation
+
+            else:
+                num_cells_to_fill_in = 0
+                pow4 = 1
+                for i in range(self.max_depth-depth):
+                    num_cells_to_fill_in += pow4
+                    pow4 *= 4
+                observation = observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf]*num_cells_to_fill_in
+
+        return observation
+
+    def util_print_obs_subtree(self, tree, num_features_per_node=5, prompt='', current_depth=0):
+        """
+        Utility function to pretty-print tree observations returned by this object.
+        """
+        if len(tree) < num_features_per_node:
+            return
+
+        depth = 0
+        tmp = len(tree)/num_features_per_node-1
+        pow4 = 4
+        while tmp > 0:
+            tmp -= pow4
+            depth += 1
+            pow4 *= 4
+
+        prompt_ = ['L:', 'F:', 'R:', 'B:']
+
+        print("  "*current_depth + prompt, tree[0:num_features_per_node])
+        child_size = (len(tree)-num_features_per_node)//4
+        for children in range(4):
+            child_tree = tree[(num_features_per_node+children*child_size):
+                              (num_features_per_node+(children+1)*child_size)]
+            self.util_print_obs_subtree(child_tree,
+                                        num_features_per_node,
+                                        prompt=prompt_[children],
+                                        current_depth=current_depth+1)
 
 
 class GlobalObsForRailEnv(ObservationBuilder):
@@ -152,8 +478,8 @@ class GlobalObsForRailEnv(ObservationBuilder):
 
         - A 4 elements array with one of encoding of the direction.
     """
-    def __init__(self, env):
-        super(GlobalObsForRailEnv, self).__init__(env)
+    def __init__(self):
+        super(GlobalObsForRailEnv, self).__init__()
 
     def reset(self):
         self.rail_obs = np.zeros((self.env.height, self.env.width, 16))
@@ -184,196 +510,3 @@ class GlobalObsForRailEnv(ObservationBuilder):
         direction[self.env.agents_direction[handle]] = 1
 
         return self.rail_obs, obs_agents_targets_pos, direction
-
-
-class Tree_State:
-    """
-    Keep track of the current state while building the tree
-    """
-    def __init__(self, agent, position, direction, depth, distance):
-        self.agent = agent
-        self.position = position
-        self.direction = direction
-        self.depth = depth
-        self.initial_direction = None
-        self.distance = distance
-        self.data = [np.inf, np.inf, np.inf, np.inf, np.inf]
-
-
-class Node():
-    """
-    Define a tree node to get populated during search
-    """
-    def __init__(self, position, data):
-        self.n_children = 4
-        self.children = [None]*self.n_children
-        self.data = data
-        self.position = position
-
-    def insert(self, position, data, child_idx):
-        """
-        Insert new node with data
-
-        @param data node data object to insert
-        """
-        new_node = Node(position, data)
-        self.children[child_idx] = new_node
-        return new_node
-
-    def print_tree(self, i=0, depth=0):
-        """
-        Print tree content inorder
-        """
-        current_i = i
-        curr_depth = depth+1
-        if i < self.n_children:
-            if self.children[i] is not None:
-                self.children[i].print_tree(depth=curr_depth)
-            current_i += 1
-            if self.children[i] is not None:
-                self.children[i].print_tree(i, depth=curr_depth)
-
-
-"""
-
-    def get_observation(self, agent):
-        # Get the current observation for an agent
-        current_position = self.internal_position[agent]
-        #target_heading = self._compass(agent).tolist()
-        coordinate = tuple(np.transpose(self._position_to_coordinate([current_position])))
-        agent_distance = self.distance_map[agent][coordinate][0]
-        # Start tree search
-        if current_position == self.target[agent]:
-            agent_tree = Node(current_position, [-np.inf, -np.inf, -np.inf, -np.inf, -1])
-        else:
-            agent_tree = Node(current_position, [0, 0, 0, 0, agent_distance])
-
-        initial_tree_state = Tree_State(agent, current_position, -1, 0, 0)
-        self._tree_search(initial_tree_state, agent_tree, agent)
-        observation = []
-        distance_data = []
-
-        self._flatten_tree(agent_tree, observation, distance_data,  self.max_depth+1)
-        # This is probably very slow!!!!
-        #max_obs = np.max([i for i in observation if i < np.inf])
-        #if max_obs != 0:
-        #    observation = np.array(observation)/ max_obs
-
-        #print([i for i in distance_data if i >= 0])
-        observation = np.concatenate((observation, distance_data))
-        #observation = np.concatenate((observation, np.identity(5)[int(self.last_action[agent])]))
-        #return np.clip(observation / self.max_dist[agent], -1, 1)
-        return np.clip(observation / 15., -1, 1)
-
-
-
-
-    def _tree_search(self, in_tree_state, parent_node, agent):
-        if in_tree_state.depth >= self.max_depth:
-            return
-        target_distance = np.inf
-        other_target = np.inf
-        other_agent = np.inf
-        coordinate = tuple(np.transpose(self._position_to_coordinate([in_tree_state.position])))
-        curr_target_dist = self.distance_map[agent][coordinate][0]
-        forbidden_action = (in_tree_state.direction + 2) % 4
-        # Update the position
-        failed_move = 0
-        leaf_distance = in_tree_state.distance
-        for child_idx in range(4):
-            if child_idx != forbidden_action or in_tree_state.direction == -1:
-                tree_state = copy.deepcopy(in_tree_state)
-                tree_state.direction = child_idx
-                current_position, invalid_move = self._detect_path(
-                tree_state.position, tree_state.direction)
-                if tree_state.initial_direction == None:
-                    tree_state.initial_direction = child_idx
-                if not invalid_move:
-                    coordinate = tuple(np.transpose(self._position_to_coordinate([current_position])))
-                    curr_target_dist = self.distance_map[agent][coordinate][0]
-                    #if tree_state.initial_direction == None:
-                    #    tree_state.initial_direction = child_idx
-                    tree_state.position = current_position
-                    tree_state.distance += 1
-
-
-                    # Collect information at the current position
-                    detection_distance = tree_state.distance
-                    if current_position == self.target[tree_state.agent]:
-                        target_distance = detection_distance
-
-                    elif current_position in self.target:
-                        other_target = detection_distance
-
-                    if current_position in self.internal_position:
-                        other_agent = detection_distance
-
-                    tree_state.data[0] = self._min_greater_zero(target_distance, tree_state.data[0])
-                    tree_state.data[1] = self._min_greater_zero(other_target, tree_state.data[1])
-                    tree_state.data[2] = self._min_greater_zero(other_agent, tree_state.data[2])
-                    tree_state.data[3] = tree_state.distance
-                    tree_state.data[4] = self._min_greater_zero(curr_target_dist, tree_state.data[4])
-
-                    if self._switch_detection(tree_state.position):
-                        tree_state.depth += 1
-                        new_tree_state = copy.deepcopy(tree_state)
-                        new_node = parent_node.insert(tree_state.position,
-                         tree_state.data, tree_state.initial_direction)
-                        new_tree_state.initial_direction = None
-                        new_tree_state.data = [np.inf, np.inf, np.inf, np.inf, np.inf]
-                        self._tree_search(new_tree_state, new_node, agent)
-                    else:
-                        self._tree_search(tree_state, parent_node, agent)
-                else:
-                    failed_move += 1
-            if failed_move == 3 and in_tree_state.direction != -1:
-                tree_state.data[4] = self._min_greater_zero(curr_target_dist, tree_state.data[4])
-                parent_node.insert(tree_state.position, tree_state.data, tree_state.initial_direction)
-                return
-        return
-
-    def _flatten_tree(self, node, observation_vector, distance_sensor, depth):
-        if depth <= 0:
-            return
-        if node != None:
-            observation_vector.extend(node.data[:-1])
-            distance_sensor.extend([node.data[-1]])
-        else:
-            observation_vector.extend([-np.inf, -np.inf, -np.inf, -np.inf])
-            distance_sensor.extend([-np.inf])
-        for child_idx in range(4):
-            if node != None:
-                child = node.children[child_idx]
-            else:
-                child = None
-            self._flatten_tree(child, observation_vector, distance_sensor,  depth -1)
-
-
-
-    def _switch_detection(self, position):
-        # Hack to detect switches
-        # This can later directly be derived from the transition matrix
-        paths = 0
-        for i in range(4):
-            _, invalid_move = self._detect_path(position, i)
-            if not invalid_move:
-                paths +=1
-            if paths >= 3:
-                return True
-        return False
-
-
-
-
-    def _min_greater_zero(self, x, y):
-        if x <= 0 and y <= 0:
-            return 0
-        if x < 0:
-            return y
-        if y < 0:
-            return x
-        return min(x, y)
-
-
-
-"""
diff --git a/flatland/core/transition_map.py b/flatland/core/transition_map.py
index d3fcf5c8467586053ca4ab624f9b8536bdfba2de..78dd9110c6ab61de1d5c38aab6e93d0180431cd5 100644
--- a/flatland/core/transition_map.py
+++ b/flatland/core/transition_map.py
@@ -118,7 +118,7 @@ class GridTransitionMap(TransitionMap):
             Width of the grid.
         height : int
             Height of the grid.
-        transitions_class : Transitions object
+        transitions : Transitions object
             The Transitions object to use to encode/decode transitions over the
             grid.
 
@@ -243,6 +243,54 @@ class GridTransitionMap(TransitionMap):
             return
         self.transitions.set_transition(self.grid[cell_id[0]][cell_id[1]], cell_id[2], transition_index, new_transition)
 
+    def save_transition_map(self, filename):
+        """
+        Save the transitions grid as `filename', in npy format.
+
+        Parameters
+        ----------
+        filename : string
+            Name of the file to which to save the transitions grid.
+
+        """
+        np.save(filename, self.grid)
+
+    def load_transition_map(self, filename, override_gridsize=True):
+        """
+        Load the transitions grid from `filename' (npy format).
+        The load function only updates the transitions grid, and possibly width and height, but the object has to be
+        initialized with the correct `transitions' object anyway.
+
+        Parameters
+        ----------
+        filename : string
+            Name of the file from which to load the transitions grid.
+        override_gridsize : bool
+            If override_gridsize=True, the width and height of the GridTransitionMap object are replaced with the size
+            of the map loaded from `filename'. If override_gridsize=False, the transitions grid is either cropped (if
+            the grid size is larger than (height,width) ) or padded with zeros (if the grid size is smaller than
+            (height,width) )
+
+        """
+        new_grid = np.load(filename)
+
+        new_height = new_grid.shape[0]
+        new_width = new_grid.shape[1]
+
+        if override_gridsize:
+            self.width = new_width
+            self.height = new_height
+            self.grid = new_grid
+
+        else:
+            if new_grid.dtype == np.uint16:
+                self.grid = np.zeros((self.height, self.width), dtype=np.uint16)
+            elif new_grid.dtype == np.uint64:
+                self.grid = np.zeros((self.height, self.width), dtype=np.uint64)
+
+            self.grid[0:min(self.height, new_height),
+                      0:min(self.width, new_width)] = new_grid[0:min(self.height, new_height),
+                                                               0:min(self.width, new_width)]
 
 # TODO: GIACOMO: is it better to provide those methods with lists of cell_ids
 # (most general implementation) or to make Grid-class specific methods for
diff --git a/flatland/envs/__init__.py b/flatland/envs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..36040d5cf66f1adc74faaade7e45acbd711676e0
--- /dev/null
+++ b/flatland/envs/rail_env.py
@@ -0,0 +1,721 @@
+"""
+Definition of the RailEnv environment and related level-generation functions.
+
+Generator functions are functions that take width, height and num_resets as arguments and return
+a GridTransitionMap object.
+"""
+import numpy as np
+
+from flatland.core.env import Environment
+from flatland.core.env_observation_builder import TreeObsForRailEnv
+
+from flatland.core.transitions import Grid8Transitions, RailEnvTransitions
+from flatland.core.transition_map import GridTransitionMap
+
+
+def rail_from_manual_specifications_generator(rail_spec):
+    """
+    Utility to convert a rail given by manual specification as a map of tuples
+    (cell_type, rotation), to a transition map with the correct 16-bit
+    transitions specifications.
+
+    Parameters
+    -------
+    rail_spec : list of list of tuples
+        List (rows) of lists (columns) of tuples, each specifying a cell for
+        the RailEnv environment as (cell_type, rotation), with rotation being
+        clock-wise and in [0, 90, 180, 270].
+
+    Returns
+    -------
+    function
+        Generator function that always returns a GridTransitionMap object with
+        the matrix of correct 16-bit bitmaps for each cell.
+    """
+    def generator(width, height, num_resets=0):
+        t_utils = RailEnvTransitions()
+
+        height = len(rail_spec)
+        width = len(rail_spec[0])
+        rail = GridTransitionMap(width=width, height=height, transitions=t_utils)
+
+        for r in range(height):
+            for c in range(width):
+                cell = rail_spec[r][c]
+                if cell[0] < 0 or cell[0] >= len(t_utils.transitions):
+                    print("ERROR - invalid cell type=", cell[0])
+                    return []
+                rail.set_transitions((r, c), t_utils.rotate_transition(
+                              t_utils.transitions[cell[0]], cell[1]))
+
+        return rail
+
+    return generator
+
+
+def rail_from_GridTransitionMap_generator(rail_map):
+    """
+    Utility to convert a rail given by a GridTransitionMap map with the correct
+    16-bit transitions specifications.
+
+    Parameters
+    -------
+    rail_map : GridTransitionMap object
+        GridTransitionMap object to return when the generator is called.
+
+    Returns
+    -------
+    function
+        Generator function that always returns the given `rail_map' object.
+    """
+    def generator(width, height, num_resets=0):
+        return rail_map
+
+    return generator
+
+
+def rail_from_list_of_saved_GridTransitionMap_generator(list_of_filenames):
+    """
+    Utility to sequentially and cyclically return GridTransitionMap-s from a list of files, on each environment reset.
+
+    Parameters
+    -------
+    list_of_filenames : list
+        List of filenames with the saved grids to load.
+
+    Returns
+    -------
+    function
+        Generator function that always returns the given `rail_map' object.
+    """
+    def generator(width, height, num_resets=0):
+        t_utils = RailEnvTransitions()
+        rail_map = GridTransitionMap(width=width, height=height, transitions=t_utils)
+        rail_map.load_transition_map(list_of_filenames[num_resets % len(list_of_filenames)], override_gridsize=False)
+
+        if rail_map.grid.dtype == np.uint64:
+            rail_map.transitions = Grid8Transitions()
+
+        return rail_map
+
+    return generator
+
+
+"""
+def generate_rail_from_list_of_manual_specifications(list_of_specifications)
+    def generator(width, height, num_resets=0):
+        return generate_rail_from_manual_specifications(list_of_specifications)
+
+    return generator
+"""
+
+
+def random_rail_generator(cell_type_relative_proportion=[1.0]*8):
+    """
+    Dummy random level generator:
+    - fill in cells at random in [width-2, height-2]
+    - keep filling cells in among the unfilled ones, such that all transitions
+      are legit;  if no cell can be filled in without violating some
+      transitions, pick one among those that can satisfy most transitions
+      (1,2,3 or 4), and delete (+mark to be re-filled) the cells that were
+      incompatible.
+    - keep trying for a total number of insertions
+      (e.g., (W-2)*(H-2)*MAX_REPETITIONS ); if no solution is found, empty the
+      board and try again from scratch.
+    - finally pad the border of the map with dead-ends to avoid border issues.
+
+    Dead-ends are not allowed inside the grid, only at the border; however, if
+    no cell type can be inserted in a given cell (because of the neighboring
+    transitions), deadends are allowed if they solve the problem. This was
+    found to turn most un-genereatable levels into valid ones.
+
+    Parameters
+    -------
+    width : int
+        The width (number of cells) of the grid to generate.
+    height : int
+        The height (number of cells) of the grid to generate.
+
+    Returns
+    -------
+    numpy.ndarray of type numpy.uint16
+        The matrix with the correct 16-bit bitmaps for each cell.
+    """
+
+    def generator(width, height, num_resets=0):
+        t_utils = RailEnvTransitions()
+
+        transition_probability = cell_type_relative_proportion
+
+        transitions_templates_ = []
+        transition_probabilities = []
+        for i in range(len(t_utils.transitions)-1):  # don't include dead-ends
+            all_transitions = 0
+            for dir_ in range(4):
+                trans = t_utils.get_transitions(t_utils.transitions[i], dir_)
+                all_transitions |= (trans[0] << 3) | \
+                                   (trans[1] << 2) | \
+                                   (trans[2] << 1) | \
+                                   (trans[3])
+
+            template = [int(x) for x in bin(all_transitions)[2:]]
+            template = [0]*(4-len(template)) + template
+
+            # add all rotations
+            for rot in [0, 90, 180, 270]:
+                transitions_templates_.append((template,
+                                              t_utils.rotate_transition(
+                                               t_utils.transitions[i],
+                                               rot)))
+                transition_probabilities.append(transition_probability[i])
+                template = [template[-1]]+template[:-1]
+
+        def get_matching_templates(template):
+            ret = []
+            for i in range(len(transitions_templates_)):
+                is_match = True
+                for j in range(4):
+                    if template[j] >= 0 and \
+                       template[j] != transitions_templates_[i][0][j]:
+                        is_match = False
+                        break
+                if is_match:
+                    ret.append((transitions_templates_[i][1], transition_probabilities[i]))
+            return ret
+
+        MAX_INSERTIONS = (width-2) * (height-2) * 10
+        MAX_ATTEMPTS_FROM_SCRATCH = 10
+
+        attempt_number = 0
+        while attempt_number < MAX_ATTEMPTS_FROM_SCRATCH:
+            cells_to_fill = []
+            rail = []
+            for r in range(height):
+                rail.append([None]*width)
+                if r > 0 and r < height-1:
+                    cells_to_fill = cells_to_fill \
+                                    + [(r, c) for c in range(1, width-1)]
+
+            num_insertions = 0
+            while num_insertions < MAX_INSERTIONS and len(cells_to_fill) > 0:
+                # cell = random.sample(cells_to_fill, 1)[0]
+                cell = cells_to_fill[np.random.choice(len(cells_to_fill), 1)[0]]
+                cells_to_fill.remove(cell)
+                row = cell[0]
+                col = cell[1]
+
+                # look at its neighbors and see what are the possible transitions
+                # that can be chosen from, if any.
+                valid_template = [-1, -1, -1, -1]
+
+                for el in [(0, 2, (-1, 0)),
+                           (1, 3, (0, 1)),
+                           (2, 0, (1, 0)),
+                           (3, 1, (0, -1))]:  # N, E, S, W
+                    neigh_trans = rail[row+el[2][0]][col+el[2][1]]
+                    if neigh_trans is not None:
+                        # select transition coming from facing direction el[1] and
+                        # moving to direction el[1]
+                        max_bit = 0
+                        for k in range(4):
+                            max_bit |= \
+                             t_utils.get_transition(neigh_trans, k, el[1])
+
+                        if max_bit:
+                            valid_template[el[0]] = 1
+                        else:
+                            valid_template[el[0]] = 0
+
+                possible_cell_transitions = get_matching_templates(valid_template)
+
+                if len(possible_cell_transitions) == 0:  # NO VALID TRANSITIONS
+                    # no cell can be filled in without violating some transitions
+                    # can a dead-end solve the problem?
+                    if valid_template.count(1) == 1:
+                        for k in range(4):
+                            if valid_template[k] == 1:
+                                rot = 0
+                                if k == 0:
+                                    rot = 180
+                                elif k == 1:
+                                    rot = 270
+                                elif k == 2:
+                                    rot = 0
+                                elif k == 3:
+                                    rot = 90
+
+                                rail[row][col] = t_utils.rotate_transition(
+                                                  int('0010000000000000', 2), rot)
+                                num_insertions += 1
+
+                                break
+
+                    else:
+                        # can I get valid transitions by removing a single
+                        # neighboring cell?
+                        bestk = -1
+                        besttrans = []
+                        for k in range(4):
+                            tmp_template = valid_template[:]
+                            tmp_template[k] = -1
+                            possible_cell_transitions = get_matching_templates(
+                                                         tmp_template)
+                            if len(possible_cell_transitions) > len(besttrans):
+                                besttrans = possible_cell_transitions
+                                bestk = k
+
+                        if bestk >= 0:
+                            # Replace the corresponding cell with None, append it
+                            # to cells to fill, fill in a transition in the current
+                            # cell.
+                            replace_row = row - 1
+                            replace_col = col
+                            if bestk == 1:
+                                replace_row = row
+                                replace_col = col + 1
+                            elif bestk == 2:
+                                replace_row = row + 1
+                                replace_col = col
+                            elif bestk == 3:
+                                replace_row = row
+                                replace_col = col - 1
+
+                            cells_to_fill.append((replace_row, replace_col))
+                            rail[replace_row][replace_col] = None
+
+                            possible_transitions, possible_probabilities = zip(*besttrans)
+                            possible_probabilities = [p/sum(possible_probabilities) for p in possible_probabilities]
+
+                            rail[row][col] = np.random.choice(possible_transitions,
+                                                              p=possible_probabilities)
+                            num_insertions += 1
+
+                        else:
+                            print('WARNING: still nothing!')
+                            rail[row][col] = int('0000000000000000', 2)
+                            num_insertions += 1
+                            pass
+
+                else:
+                    possible_transitions, possible_probabilities = zip(*possible_cell_transitions)
+                    possible_probabilities = [p/sum(possible_probabilities) for p in possible_probabilities]
+
+                    rail[row][col] = np.random.choice(possible_transitions,
+                                                      p=possible_probabilities)
+                    num_insertions += 1
+
+            if num_insertions == MAX_INSERTIONS:
+                # Failed to generate a valid level; try again for a number of times
+                attempt_number += 1
+            else:
+                break
+
+        if attempt_number == MAX_ATTEMPTS_FROM_SCRATCH:
+            print('ERROR: failed to generate level')
+
+        # Finally pad the border of the map with dead-ends to avoid border issues;
+        # at most 1 transition in the neigh cell
+        for r in range(height):
+            # Check for transitions coming from [r][1] to WEST
+            max_bit = 0
+            neigh_trans = rail[r][1]
+            if neigh_trans is not None:
+                for k in range(4):
+                    neigh_trans_from_direction = (neigh_trans >> ((3-k) * 4)) \
+                                                 & (2**4-1)
+                    max_bit = max_bit | (neigh_trans_from_direction & 1)
+            if max_bit:
+                rail[r][0] = t_utils.rotate_transition(
+                               int('0010000000000000', 2), 270)
+            else:
+                rail[r][0] = int('0000000000000000', 2)
+
+            # Check for transitions coming from [r][-2] to EAST
+            max_bit = 0
+            neigh_trans = rail[r][-2]
+            if neigh_trans is not None:
+                for k in range(4):
+                    neigh_trans_from_direction = (neigh_trans >> ((3-k) * 4)) \
+                                                 & (2**4-1)
+                    max_bit = max_bit | (neigh_trans_from_direction & (1 << 2))
+            if max_bit:
+                rail[r][-1] = t_utils.rotate_transition(int('0010000000000000', 2),
+                                                        90)
+            else:
+                rail[r][-1] = int('0000000000000000', 2)
+
+        for c in range(width):
+            # Check for transitions coming from [1][c] to NORTH
+            max_bit = 0
+            neigh_trans = rail[1][c]
+            if neigh_trans is not None:
+                for k in range(4):
+                    neigh_trans_from_direction = (neigh_trans >> ((3-k) * 4)) \
+                                                  & (2**4-1)
+                    max_bit = max_bit | (neigh_trans_from_direction & (1 << 3))
+            if max_bit:
+                rail[0][c] = int('0010000000000000', 2)
+            else:
+                rail[0][c] = int('0000000000000000', 2)
+
+            # Check for transitions coming from [-2][c] to SOUTH
+            max_bit = 0
+            neigh_trans = rail[-2][c]
+            if neigh_trans is not None:
+                for k in range(4):
+                    neigh_trans_from_direction = (neigh_trans >> ((3-k) * 4)) \
+                                                 & (2**4-1)
+                    max_bit = max_bit | (neigh_trans_from_direction & (1 << 1))
+            if max_bit:
+                rail[-1][c] = t_utils.rotate_transition(
+                                int('0010000000000000', 2), 180)
+            else:
+                rail[-1][c] = int('0000000000000000', 2)
+
+        # For display only, wrong levels
+        for r in range(height):
+            for c in range(width):
+                if rail[r][c] is None:
+                    rail[r][c] = int('0000000000000000', 2)
+
+        tmp_rail = np.asarray(rail, dtype=np.uint16)
+
+        return_rail = GridTransitionMap(width=width, height=height, transitions=t_utils)
+        return_rail.grid = tmp_rail
+        return return_rail
+
+    return generator
+
+
+class RailEnv(Environment):
+    """
+    RailEnv environment class.
+
+    RailEnv is an environment inspired by a (simplified version of) a rail
+    network, in which agents (trains) have to navigate to their target
+    locations in the shortest time possible, while at the same time cooperating
+    to avoid bottlenecks.
+
+    The valid actions in the environment are:
+        0: do nothing
+        1: turn left and move to the next cell
+        2: move to the next cell in front of the agent
+        3: turn right and move to the next cell
+
+    Moving forward in a dead-end cell makes the agent turn 180 degrees and step
+    to the cell it came from.
+
+    The actions of the agents are executed in order of their handle to prevent
+    deadlocks and to allow them to learn relative priorities.
+
+    TODO: WRITE ABOUT THE REWARD FUNCTION, and possibly allow for alpha and
+    beta to be passed as parameters to __init__().
+    """
+
+    def __init__(self,
+                 width,
+                 height,
+                 rail_generator=random_rail_generator(),
+                 number_of_agents=1,
+                 obs_builder_object=TreeObsForRailEnv(max_depth=2)):
+        """
+        Environment init.
+
+        Parameters
+        -------
+        rail_generator : function
+            The rail_generator function is a function that takes the width and
+            height of a  rail map along with the number of times the env has
+            been reset, and returns a GridTransitionMap object.
+            Implemented functions are:
+                random_rail_generator : generate a random rail of given size
+                rail_from_GridTransitionMap_generator(rail_map) : generate a rail from
+                                        a GridTransitionMap object
+                rail_from_manual_specifications_generator(rail_spec) : generate a rail from
+                                        a rail specifications array
+                TODO: generate_rail_from_saved_list or from list of ndarray bitmaps ---
+        width : int
+            The width of the rail map. Potentially in the future,
+            a range of widths to sample from.
+        height : int
+            The height of the rail map. Potentially in the future,
+            a range of heights to sample from.
+        number_of_agents : int
+            Number of agents to spawn on the map. Potentially in the future,
+            a range of number of agents to sample from.
+        obs_builder_object: ObservationBuilder object
+            ObservationBuilder-derived object that takes builds observation
+            vectors for each agent.
+        """
+
+        self.rail_generator = rail_generator
+        self.rail = None
+        self.width = width
+        self.height = height
+
+        self.number_of_agents = number_of_agents
+
+        self.obs_builder = obs_builder_object
+        self.obs_builder._set_env(self)
+
+        self.actions = [0]*self.number_of_agents
+        self.rewards = [0]*self.number_of_agents
+        self.done = False
+
+        self.dones = {"__all__": False}
+        self.obs_dict = {}
+        self.rewards_dict = {}
+
+        self.agents_handles = list(range(self.number_of_agents))
+
+        # self.agents_position = []
+        # self.agents_target = []
+        # self.agents_direction = []
+        self.num_resets = 0
+        self.reset()
+        self.num_resets = 0
+
+    def get_agent_handles(self):
+        return self.agents_handles
+
+    def reset(self):
+        self.rail = self.rail_generator(self.width, self.height, self.num_resets)
+        self.num_resets += 1
+
+        self.dones = {"__all__": False}
+        for handle in self.agents_handles:
+            self.dones[handle] = False
+
+        re_generate = True
+        while re_generate:
+            valid_positions = []
+            for r in range(self.height):
+                for c in range(self.width):
+                    if self.rail.get_transitions((r, c)) > 0:
+                        valid_positions.append((r, c))
+
+            # self.agents_position = random.sample(valid_positions,
+            #                                     self.number_of_agents)
+            self.agents_position = [
+                valid_positions[i] for i in
+                np.random.choice(len(valid_positions), self.number_of_agents)]
+            self.agents_target = [
+                valid_positions[i] for i in
+                np.random.choice(len(valid_positions), self.number_of_agents)]
+
+            # agents_direction must be a direction for which a solution is
+            # guaranteed.
+            self.agents_direction = [0]*self.number_of_agents
+            re_generate = False
+            for i in range(self.number_of_agents):
+                valid_movements = []
+                for direction in range(4):
+                    position = self.agents_position[i]
+                    moves = self.rail.get_transitions(
+                            (position[0], position[1], direction))
+                    for move_index in range(4):
+                        if moves[move_index]:
+                            valid_movements.append((direction, move_index))
+
+                valid_starting_directions = []
+                for m in valid_movements:
+                    new_position = self._new_position(self.agents_position[i],
+                                                      m[1])
+                    if m[0] not in valid_starting_directions and \
+                       self._path_exists(new_position, m[0],
+                                         self.agents_target[i]):
+                        valid_starting_directions.append(m[0])
+
+                if len(valid_starting_directions) == 0:
+                    re_generate = True
+                else:
+                    self.agents_direction[i] = valid_starting_directions[
+                        np.random.choice(len(valid_starting_directions), 1)[0]]
+
+        # Reset the state of the observation builder with the new environment
+        self.obs_builder.reset()
+
+        # Return the new observation vectors for each agent
+        return self._get_observations()
+
+    def step(self, action_dict):
+        alpha = 1.0
+        beta = 1.0
+
+        invalid_action_penalty = -2
+        step_penalty = -1 * alpha
+        global_reward = 1 * beta
+
+        # Reset the step rewards
+        self.rewards_dict = dict()
+        for handle in self.agents_handles:
+            self.rewards_dict[handle] = 0
+
+        if self.dones["__all__"]:
+            return self._get_observations(), self.rewards_dict, self.dones, {}
+
+        for i in range(len(self.agents_handles)):
+            handle = self.agents_handles[i]
+
+            if handle not in action_dict:
+                continue
+
+            action = action_dict[handle]
+
+            if action < 0 or action > 3:
+                print('ERROR: illegal action=', action,
+                      'for agent with handle=', handle)
+                return
+
+            if action > 0:
+                pos = self.agents_position[i]
+                direction = self.agents_direction[i]
+
+                movement = direction
+                if action == 1:
+                    movement = direction - 1
+                elif action == 3:
+                    movement = direction + 1
+
+                if movement < 0:
+                    movement += 4
+                if movement >= 4:
+                    movement -= 4
+
+                is_deadend = False
+                if action == 2:
+                    # compute number of possible transitions in the current
+                    # cell
+                    nbits = 0
+                    tmp = self.rail.get_transitions((pos[0], pos[1]))
+                    while tmp > 0:
+                        nbits += (tmp & 1)
+                        tmp = tmp >> 1
+                    if nbits == 1:
+                        # dead-end;  assuming the rail network is consistent,
+                        # this should match the direction the agent has come
+                        # from. But it's better to check in any case.
+                        reverse_direction = 0
+                        if direction == 0:
+                            reverse_direction = 2
+                        elif direction == 1:
+                            reverse_direction = 3
+                        elif direction == 2:
+                            reverse_direction = 0
+                        elif direction == 3:
+                            reverse_direction = 1
+
+                        valid_transition = self.rail.get_transition(
+                                            (pos[0], pos[1], direction),
+                                            reverse_direction)
+                        if valid_transition:
+                            direction = reverse_direction
+                            movement = reverse_direction
+                            is_deadend = True
+
+                new_position = self._new_position(pos, movement)
+
+                # Is it a legal move?  1) transition allows the movement in the
+                # cell,  2) the new cell is not empty (case 0),  3) the cell is
+                # free, i.e., no agent is currently in that cell
+                if new_position[1] >= self.width or\
+                   new_position[0] >= self.height or\
+                   new_position[0] < 0 or new_position[1] < 0:
+                    new_cell_isValid = False
+
+                elif self.rail.get_transitions((new_position[0], new_position[1])) > 0:
+                    new_cell_isValid = True
+                else:
+                    new_cell_isValid = False
+
+                transition_isValid = self.rail.get_transition(
+                     (pos[0], pos[1], direction),
+                     movement) or is_deadend
+
+                cell_isFree = True
+                for j in range(self.number_of_agents):
+                    if self.agents_position[j] == new_position:
+                        cell_isFree = False
+                        break
+
+                if new_cell_isValid and transition_isValid and cell_isFree:
+                    # move and change direction to face the movement that was
+                    # performed
+                    self.agents_position[i] = new_position
+                    self.agents_direction[i] = movement
+                else:
+                    # the action was not valid, add penalty
+                    self.rewards_dict[handle] += invalid_action_penalty
+
+            # if agent is not in target position, add step penalty
+            if self.agents_position[i][0] == self.agents_target[i][0] and \
+               self.agents_position[i][1] == self.agents_target[i][1]:
+                self.dones[handle] = True
+            else:
+                self.rewards_dict[handle] += step_penalty
+
+        # Check for end of episode + add global reward to all rewards!
+        num_agents_in_target_position = 0
+        for i in range(self.number_of_agents):
+            if self.agents_position[i][0] == self.agents_target[i][0] and \
+               self.agents_position[i][1] == self.agents_target[i][1]:
+                num_agents_in_target_position += 1
+
+        if num_agents_in_target_position == self.number_of_agents:
+            self.dones["__all__"] = True
+            self.rewards_dict = [r+global_reward for r in self.rewards_dict]
+
+        # Reset the step actions (in case some agent doesn't 'register_action'
+        # on the next step)
+        self.actions = [0]*self.number_of_agents
+
+        return self._get_observations(), self.rewards_dict, self.dones, {}
+
+    def _new_position(self, position, movement):
+        if movement == 0:    # NORTH
+            return (position[0]-1, position[1])
+        elif movement == 1:  # EAST
+            return (position[0], position[1] + 1)
+        elif movement == 2:  # SOUTH
+            return (position[0]+1, position[1])
+        elif movement == 3:  # WEST
+            return (position[0], position[1] - 1)
+
+    def _path_exists(self, start, direction, end):
+        # BFS - Check if a path exists between the 2 nodes
+
+        visited = set()
+        stack = [(start, direction)]
+        while stack:
+            node = stack.pop()
+            if node[0][0] == end[0] and node[0][1] == end[1]:
+                return 1
+            if node not in visited:
+                visited.add(node)
+                moves = self.rail.get_transitions((node[0][0], node[0][1], node[1]))
+                for move_index in range(4):
+                    if moves[move_index]:
+                        stack.append((self._new_position(node[0], move_index),
+                                      move_index))
+
+                # If cell is a dead-end, append previous node with reversed
+                # orientation!
+                nbits = 0
+                tmp = self.rail.get_transitions((node[0][0], node[0][1]))
+                while tmp > 0:
+                    nbits += (tmp & 1)
+                    tmp = tmp >> 1
+                if nbits == 1:
+                    stack.append((node[0], (node[1] + 2) % 4))
+
+        return 0
+
+    def _get_observations(self):
+        self.obs_dict = {}
+        for handle in self.agents_handles:
+            self.obs_dict[handle] = self.obs_builder.get(handle)
+        return self.obs_dict
+
+    def render(self):
+        # TODO:
+        pass
diff --git a/flatland/utils/rail_env_generator.py b/flatland/utils/rail_env_generator.py
deleted file mode 100644
index 69e5b831be67c6a808e22b5413255789acf90f27..0000000000000000000000000000000000000000
--- a/flatland/utils/rail_env_generator.py
+++ /dev/null
@@ -1,307 +0,0 @@
-"""
-The rail_env_generator module defines provides utilities to generate env
-bitmaps for the RailEnv environment.
-"""
-import random
-import numpy as np
-
-from flatland.core.transitions import RailEnvTransitions
-from flatland.core.transition_map import GridTransitionMap
-
-
-def generate_rail_from_manual_specifications(rail_spec):
-    """
-    Utility to convert a rail given by manual specification as a map of tuples
-    (cell_type, rotation), to a transition map with the correct 16-bit
-    transitions specifications.
-
-    Parameters
-    -------
-    rail_spec : list of list of tuples
-        List (rows) of lists (columns) of tuples, each specifying a cell for
-        the RailEnv environment as (cell_type, rotation), with rotation being
-        clock-wise and in [0, 90, 180, 270].
-
-    Returns
-    -------
-    numpy.ndarray of type numpy.uint16
-        The matrix with the correct 16-bit bitmaps for each cell.
-    """
-    t_utils = RailEnvTransitions()
-
-    height = len(rail_spec)
-    width = len(rail_spec[0])
-    rail = GridTransitionMap(width=width, height=height, transitions=t_utils)
-
-    for r in range(height):
-        for c in range(width):
-            cell = rail_spec[r][c]
-            if cell[0] < 0 or cell[0] >= len(t_utils.transitions):
-                print("ERROR - invalid cell type=", cell[0])
-                return []
-            rail.set_transitions((r, c), t_utils.rotate_transition(
-                          t_utils.transitions[cell[0]], cell[1]))
-
-    return rail
-
-
-def generate_random_rail(width, height):
-    """
-    Dummy random level generator:
-    - fill in cells at random in [width-2, height-2]
-    - keep filling cells in among the unfilled ones, such that all transitions
-      are legit;  if no cell can be filled in without violating some
-      transitions, pick one among those that can satisfy most transitions
-      (1,2,3 or 4), and delete (+mark to be re-filled) the cells that were
-      incompatible.
-    - keep trying for a total number of insertions
-      (e.g., (W-2)*(H-2)*MAX_REPETITIONS ); if no solution is found, empty the
-      board and try again from scratch.
-    - finally pad the border of the map with dead-ends to avoid border issues.
-
-    Dead-ends are not allowed inside the grid, only at the border; however, if
-    no cell type can be inserted in a given cell (because of the neighboring
-    transitions), deadends are allowed if they solve the problem. This was
-    found to turn most un-genereatable levels into valid ones.
-
-    Parameters
-    -------
-    width : int
-        The width (number of cells) of the grid to generate.
-    height : int
-        The height (number of cells) of the grid to generate.
-
-    Returns
-    -------
-    numpy.ndarray of type numpy.uint16
-        The matrix with the correct 16-bit bitmaps for each cell.
-    """
-
-    t_utils = RailEnvTransitions()
-
-    transitions_templates_ = []
-    for i in range(len(t_utils.transitions)-1):  # don't include dead-ends
-        all_transitions = 0
-        for dir_ in range(4):
-            trans = t_utils.get_transitions(t_utils.transitions[i], dir_)
-            all_transitions |= (trans[0] << 3) | \
-                               (trans[1] << 2) | \
-                               (trans[2] << 1) | \
-                               (trans[3])
-
-        template = [int(x) for x in bin(all_transitions)[2:]]
-        template = [0]*(4-len(template)) + template
-
-        # add all rotations
-        for rot in [0, 90, 180, 270]:
-            transitions_templates_.append((template,
-                                          t_utils.rotate_transition(
-                                           t_utils.transitions[i],
-                                           rot)))
-            template = [template[-1]]+template[:-1]
-
-    def get_matching_templates(template):
-        ret = []
-        for i in range(len(transitions_templates_)):
-            is_match = True
-            for j in range(4):
-                if template[j] >= 0 and \
-                   template[j] != transitions_templates_[i][0][j]:
-                    is_match = False
-                    break
-            if is_match:
-                ret.append(transitions_templates_[i][1])
-        return ret
-
-    MAX_INSERTIONS = (width-2) * (height-2) * 10
-    MAX_ATTEMPTS_FROM_SCRATCH = 10
-
-    attempt_number = 0
-    while attempt_number < MAX_ATTEMPTS_FROM_SCRATCH:
-        cells_to_fill = []
-        rail = []
-        for r in range(height):
-            rail.append([None]*width)
-            if r > 0 and r < height-1:
-                cells_to_fill = cells_to_fill \
-                                + [(r, c) for c in range(1, width-1)]
-
-        num_insertions = 0
-        while num_insertions < MAX_INSERTIONS and len(cells_to_fill) > 0:
-            cell = random.sample(cells_to_fill, 1)[0]
-            cells_to_fill.remove(cell)
-            row = cell[0]
-            col = cell[1]
-
-            # look at its neighbors and see what are the possible transitions
-            # that can be chosen from, if any.
-            valid_template = [-1, -1, -1, -1]
-
-            for el in [(0, 2, (-1, 0)),
-                       (1, 3, (0, 1)),
-                       (2, 0, (1, 0)),
-                       (3, 1, (0, -1))]:  # N, E, S, W
-                neigh_trans = rail[row+el[2][0]][col+el[2][1]]
-                if neigh_trans is not None:
-                    # select transition coming from facing direction el[1] and
-                    # moving to direction el[1]
-                    max_bit = 0
-                    for k in range(4):
-                        max_bit |= \
-                         t_utils.get_transition(neigh_trans, k, el[1])
-
-                    if max_bit:
-                        valid_template[el[0]] = 1
-                    else:
-                        valid_template[el[0]] = 0
-
-            possible_cell_transitions = get_matching_templates(valid_template)
-
-            if len(possible_cell_transitions) == 0:  # NO VALID TRANSITIONS
-                # no cell can be filled in without violating some transitions
-                # can a dead-end solve the problem?
-                if valid_template.count(1) == 1:
-                    for k in range(4):
-                        if valid_template[k] == 1:
-                            rot = 0
-                            if k == 0:
-                                rot = 180
-                            elif k == 1:
-                                rot = 270
-                            elif k == 2:
-                                rot = 0
-                            elif k == 3:
-                                rot = 90
-
-                            rail[row][col] = t_utils.rotate_transition(
-                                              int('0000000000100000', 2), rot)
-                            num_insertions += 1
-
-                            break
-
-                else:
-                    # can I get valid transitions by removing a single
-                    # neighboring cell?
-                    bestk = -1
-                    besttrans = []
-                    for k in range(4):
-                        tmp_template = valid_template[:]
-                        tmp_template[k] = -1
-                        possible_cell_transitions = get_matching_templates(
-                                                     tmp_template)
-                        if len(possible_cell_transitions) > len(besttrans):
-                            besttrans = possible_cell_transitions
-                            bestk = k
-
-                    if bestk >= 0:
-                        # Replace the corresponding cell with None, append it
-                        # to cells to fill, fill in a transition in the current
-                        # cell.
-                        replace_row = row - 1
-                        replace_col = col
-                        if bestk == 1:
-                            replace_row = row
-                            replace_col = col + 1
-                        elif bestk == 2:
-                            replace_row = row + 1
-                            replace_col = col
-                        elif bestk == 3:
-                            replace_row = row
-                            replace_col = col - 1
-
-                        cells_to_fill.append((replace_row, replace_col))
-                        rail[replace_row][replace_col] = None
-
-                        rail[row][col] = random.sample(
-                                                     besttrans, 1)[0]
-                        num_insertions += 1
-
-                    else:
-                        print('WARNING: still nothing!')
-                        rail[row][col] = int('0000000000000000', 2)
-                        num_insertions += 1
-                        pass
-
-            else:
-                rail[row][col] = random.sample(
-                                             possible_cell_transitions, 1)[0]
-                num_insertions += 1
-
-        if num_insertions == MAX_INSERTIONS:
-            # Failed to generate a valid level; try again for a number of times
-            attempt_number += 1
-        else:
-            break
-
-    if attempt_number == MAX_ATTEMPTS_FROM_SCRATCH:
-        print('ERROR: failed to generate level')
-
-    # Finally pad the border of the map with dead-ends to avoid border issues;
-    # at most 1 transition in the neigh cell
-    for r in range(height):
-        # Check for transitions coming from [r][1] to WEST
-        max_bit = 0
-        neigh_trans = rail[r][1]
-        if neigh_trans is not None:
-            for k in range(4):
-                neigh_trans_from_direction = (neigh_trans >> ((3-k) * 4)) \
-                                             & (2**4-1)
-                max_bit = max_bit | (neigh_trans_from_direction & 1)
-        if max_bit:
-            rail[r][0] = t_utils.rotate_transition(
-                           int('0000000000100000', 2), 270)
-        else:
-            rail[r][0] = int('0000000000000000', 2)
-
-        # Check for transitions coming from [r][-2] to EAST
-        max_bit = 0
-        neigh_trans = rail[r][-2]
-        if neigh_trans is not None:
-            for k in range(4):
-                neigh_trans_from_direction = (neigh_trans >> ((3-k) * 4)) \
-                                             & (2**4-1)
-                max_bit = max_bit | (neigh_trans_from_direction & (1 << 2))
-        if max_bit:
-            rail[r][-1] = t_utils.rotate_transition(int('0000000000100000', 2),
-                                                    90)
-        else:
-            rail[r][-1] = int('0000000000000000', 2)
-
-    for c in range(width):
-        # Check for transitions coming from [1][c] to NORTH
-        max_bit = 0
-        neigh_trans = rail[1][c]
-        if neigh_trans is not None:
-            for k in range(4):
-                neigh_trans_from_direction = (neigh_trans >> ((3-k) * 4)) \
-                                              & (2**4-1)
-                max_bit = max_bit | (neigh_trans_from_direction & (1 << 3))
-        if max_bit:
-            rail[0][c] = int('0000000000100000', 2)
-        else:
-            rail[0][c] = int('0000000000000000', 2)
-
-        # Check for transitions coming from [-2][c] to SOUTH
-        max_bit = 0
-        neigh_trans = rail[-2][c]
-        if neigh_trans is not None:
-            for k in range(4):
-                neigh_trans_from_direction = (neigh_trans >> ((3-k) * 4)) \
-                                             & (2**4-1)
-                max_bit = max_bit | (neigh_trans_from_direction & (1 << 1))
-        if max_bit:
-            rail[-1][c] = t_utils.rotate_transition(
-                            int('0000000000100000', 2), 180)
-        else:
-            rail[-1][c] = int('0000000000000000', 2)
-
-    # For display only, wrong levels
-    for r in range(height):
-        for c in range(width):
-            if rail[r][c] is None:
-                rail[r][c] = int('0000000000000000', 2)
-
-    tmp_rail = np.asarray(rail, dtype=np.uint16)
-    return_rail = GridTransitionMap(width=width, height=height, transitions=t_utils)
-    return_rail.grid = tmp_rail
-    return return_rail
diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py
index d0a78917dcee833eac2d5ac7567436546635d044..9f008f00c261f76e71771732aa02c3c3071f9542 100644
--- a/flatland/utils/rendertools.py
+++ b/flatland/utils/rendertools.py
@@ -6,6 +6,8 @@ import xarray as xr
 import matplotlib.pyplot as plt
 
 
+# TODO: suggested renaming to RailEnvRenderTool, as it will only work with RailEnv!
+
 class RenderTool(object):
     Visit = recordtype("Visit", ["rc", "iDir", "iDepth", "prev"])
 
@@ -401,7 +403,7 @@ class RenderTool(object):
         # cell_size is a bit pointless with matplotlib - it does not relate to pixels,
         # so for now I've changed it to 1 (from 10)
         cell_size = 1
-
+        plt.clf()
         # if oFigure is None:
         #    oFigure = plt.figure()
 
@@ -549,7 +551,9 @@ class RenderTool(object):
         plt.xlim([0, env.width * cell_size])
         plt.ylim([-env.height * cell_size, 0])
         if show:
-            plt.show()
+            plt.show(block=False)
+            plt.pause(0.00001)
+            return
 
     def _draw_square(self, center, size, color):
         x0 = center[0]-size/2
diff --git a/images/basic-env.png b/images/basic-env.png
index 0b21b26887236b4e7d7e82b34bfa074ec9d05c38..850d6ecad2d1adb6d3d4f829116acee67b9441db 100644
Binary files a/images/basic-env.png and b/images/basic-env.png differ
diff --git a/images/env-path.png b/images/env-path.png
index 5f49e744754237889dd7331213d3084cb19b1555..95b9faa9fbe78e36c49216058274dbd18495cc12 100644
Binary files a/images/env-path.png and b/images/env-path.png differ
diff --git a/images/env-tree-graph.png b/images/env-tree-graph.png
index 9b2a2d6a9cc4792c962e0fdbdd87f233aaac7d7d..f33b5f4c8d69ab2c028e6cc1689f4ccbb7600ce3 100644
Binary files a/images/env-tree-graph.png and b/images/env-tree-graph.png differ
diff --git a/images/env-tree-spatial.png b/images/env-tree-spatial.png
index 06f2054027c5da517c8eff0b2c142dd183a7e3fb..54ac9bfc0c2cb0853a319368add4d6ac5514fd28 100644
Binary files a/images/env-tree-spatial.png and b/images/env-tree-spatial.png differ
diff --git a/tests/test_env_observation_builder.py b/tests/test_env_observation_builder.py
index 1a797de858beb064d1d5338b746bc0047adf6ba8..44dbfc6d8f14a7293e148f125b47eed66c9ca08d 100644
--- a/tests/test_env_observation_builder.py
+++ b/tests/test_env_observation_builder.py
@@ -3,7 +3,7 @@
 
 from flatland.core.env_observation_builder import GlobalObsForRailEnv
 from flatland.core.transition_map import GridTransitionMap, Grid4Transitions
-from flatland.core.env import RailEnv
+from flatland.envs.rail_env import RailEnv, rail_from_GridTransitionMap_generator
 from flatland.utils.rendertools import *
 
 """Tests for `flatland` package."""
@@ -57,29 +57,30 @@ def test_global_obs():
     rail = GridTransitionMap(width=rail_map.shape[1],
                              height=rail_map.shape[0], transitions=transitions)
     rail.grid = rail_map
-    env = RailEnv(rail, number_of_agents=1)
+    env = RailEnv(width=rail_map.shape[1],
+                  height=rail_map.shape[0],
+                  rail_generator=rail_from_GridTransitionMap_generator(rail),
+                  number_of_agents=1,
+                  obs_builder_object=GlobalObsForRailEnv())
 
-    env.reset()
+    global_obs = env.reset()
     # env_renderer = RenderTool(env)
     # env_renderer.renderEnv(show=True)
 
-    global_obs = GlobalObsForRailEnv(env)
-    global_obs.reset()
-    assert(global_obs.rail_obs.shape == rail_map.shape + (16,))
+    # global_obs.reset()
+    assert(global_obs[0][0].shape == rail_map.shape + (16,))
 
     rail_map_recons = np.zeros_like(rail_map)
-    for i in range(global_obs.rail_obs.shape[0]):
-        for j in range(global_obs.rail_obs.shape[1]):
-            rail_map_recons[i,j] = int(
-                ''.join(global_obs.rail_obs[i, j].astype(int).astype(str)), 2)
+    for i in range(global_obs[0][0].shape[0]):
+        for j in range(global_obs[0][0].shape[1]):
+            rail_map_recons[i, j] = int(
+                ''.join(global_obs[0][0][i, j].astype(int).astype(str)), 2)
 
     assert(rail_map_recons.all() == rail_map.all())
 
-    obs = global_obs.get(0)
-
     # If this assertion is wrong, it means that the observation returned
     # places the agent on an empty cell
-    assert(np.sum(rail_map * obs[1][0]) > 0)
+    assert(np.sum(rail_map * global_obs[0][1][0]) > 0)
 
 
 
diff --git a/tests/test_environments.py b/tests/test_environments.py
index a89133d25378fec20bb44e43b7c43ce0432f789a..ea8748b8aa4b50a1371a013be98f3b42d0d01228 100644
--- a/tests/test_environments.py
+++ b/tests/test_environments.py
@@ -1,10 +1,11 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
+import numpy as np
 
-from flatland.core.env import RailEnv
+from flatland.envs.rail_env import RailEnv, rail_from_GridTransitionMap_generator
 from flatland.core.transitions import Grid4Transitions
 from flatland.core.transition_map import GridTransitionMap
-import numpy as np
+from flatland.core.env_observation_builder import GlobalObsForRailEnv
 
 """Tests for `flatland` package."""
 
@@ -46,7 +47,12 @@ def test_rail_environment_single_agent():
 
     rail = GridTransitionMap(width=3, height=3, transitions=transitions)
     rail.grid = rail_map
-    rail_env = RailEnv(rail, number_of_agents=1)
+    rail_env = RailEnv(width=3,
+                       height=3,
+                       rail_generator=rail_from_GridTransitionMap_generator(rail),
+                       number_of_agents=1,
+                       obs_builder_object=GlobalObsForRailEnv())
+
     for _ in range(200):
         _ = rail_env.reset()
 
@@ -118,7 +124,11 @@ def test_dead_end():
                              transitions=transitions)
 
     rail.grid = rail_map
-    rail_env = RailEnv(rail, number_of_agents=1)
+    rail_env = RailEnv(width=rail_map.shape[1],
+                       height=rail_map.shape[0],
+                       rail_generator=rail_from_GridTransitionMap_generator(rail),
+                       number_of_agents=1,
+                       obs_builder_object=GlobalObsForRailEnv())
 
     def check_consistency(rail_env):
         # We run step to check that trains do not move anymore
@@ -141,14 +151,14 @@ def test_dead_end():
 
     # We try the configuration in the 4 directions:
     rail_env.reset()
-    rail_env.agents_target[0] = [0, 0]
-    rail_env.agents_position[0] = [0, 2]
+    rail_env.agents_target[0] = (0, 0)
+    rail_env.agents_position[0] = (0, 2)
     rail_env.agents_direction[0] = 1
     check_consistency(rail_env)
 
     rail_env.reset()
-    rail_env.agents_target[0] = [0, 4]
-    rail_env.agents_position[0] = [0, 2]
+    rail_env.agents_target[0] = (0, 4)
+    rail_env.agents_position[0] = (0, 2)
     rail_env.agents_direction[0] = 3
     check_consistency(rail_env)
 
@@ -164,16 +174,20 @@ def test_dead_end():
                              transitions=transitions)
 
     rail.grid = rail_map
-    rail_env = RailEnv(rail, number_of_agents=1)
+    rail_env = RailEnv(width=rail_map.shape[1],
+                       height=rail_map.shape[0],
+                       rail_generator=rail_from_GridTransitionMap_generator(rail),
+                       number_of_agents=1,
+                       obs_builder_object=GlobalObsForRailEnv())
 
     rail_env.reset()
-    rail_env.agents_target[0] = [0, 0]
-    rail_env.agents_position[0] = [2, 0]
+    rail_env.agents_target[0] = (0, 0)
+    rail_env.agents_position[0] = (2, 0)
     rail_env.agents_direction[0] = 2
     check_consistency(rail_env)
 
     rail_env.reset()
-    rail_env.agents_target[0] = [4, 0]
-    rail_env.agents_position[0] = [2, 0]
+    rail_env.agents_target[0] = (4, 0)
+    rail_env.agents_position[0] = (2, 0)
     rail_env.agents_direction[0] = 0
     check_consistency(rail_env)
diff --git a/tests/test_rendertools.py b/tests/test_rendertools.py
index ae9a9e1867a9e9d65877e81e73370550a233d206..e45b7d1815365afda98f699d628a7e6f51c92395 100644
--- a/tests/test_rendertools.py
+++ b/tests/test_rendertools.py
@@ -4,15 +4,14 @@
 Tests for `flatland` package.
 """
 
-from flatland.core.env import RailEnv
+from flatland.envs.rail_env import RailEnv, random_rail_generator
 import numpy as np
-import random
 import os
 
 import matplotlib.pyplot as plt
 
-from flatland.utils import rail_env_generator
 import flatland.utils.rendertools as rt
+from flatland.core.env_observation_builder import GlobalObsForRailEnv
 
 
 def checkFrozenImage(sFileImage):
@@ -36,9 +35,12 @@ def checkFrozenImage(sFileImage):
 
 
 def test_render_env():
-    random.seed(100)
-    oRail = rail_env_generator.generate_random_rail(10, 10)
-    oEnv = RailEnv(oRail, number_of_agents=2)
+    # random.seed(100)
+    np.random.seed(100)
+    oEnv = RailEnv(width=10, height=10,
+                   rail_generator=random_rail_generator(),
+                   number_of_agents=2,
+                   obs_builder_object=GlobalObsForRailEnv())
     oEnv.reset()
     oRT = rt.RenderTool(oEnv)
     plt.figure(figsize=(10, 10))