diff --git a/src/agent/dueling_double_dqn.py b/src/agent/dueling_double_dqn.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d82fded3552e4b9d72c22dc4f85fa1fde574e48
--- /dev/null
+++ b/src/agent/dueling_double_dqn.py
@@ -0,0 +1,512 @@
+import copy
+import os
+import random
+from collections import namedtuple, deque
+from collections.abc import Iterable
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.optim as optim
+
+from src.agent.model import QNetwork2, QNetwork
+
+BUFFER_SIZE = int(1e5)  # replay buffer size
+BATCH_SIZE = 512  # minibatch size
+GAMMA = 0.95  # discount factor (0.99 also used)
+TAU = 0.5e-4  # for soft update of target parameters
+LR = 0.5e-3  # learning rate (0.5e-4 also works)
+
+# how often to update the network
+UPDATE_EVERY = 40
+UPDATE_EVERY_FINAL = 1000
+UPDATE_EVERY_AGENT_CANT_CHOOSE = 200
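+# The Agent below keeps three separate replay buffers: one for regular steps, one for the
+# per-agent "final" steps collected via add_final_step()/make_final_step(), and one for steps
+# on which the agent could not choose an action; each buffer triggers learning at its own
+# cadence defined by the UPDATE_EVERY_* constants above.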
+
+double_dqn = True  # whether to use the double DQN update
+input_channels = 5  # number of input channels for the convolutional network
+
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+device = torch.device("cpu")
+print(device)
+
+USE_OPTIMIZER = optim.Adam
+# USE_OPTIMIZER = optim.RMSprop
+print(USE_OPTIMIZER)
+
+
+class Agent:
+    """Interacts with and learns from the environment."""
+
+    def __init__(self, state_size, action_size, net_type, seed, double_dqn=True, input_channels=5):
+        """Initialize an Agent object.
+
+        Params
+        ======
+            state_size (int): dimension of each state
+            action_size (int): dimension of each action
+            net_type (str): "Conv" selects the convolutional QNetwork2, any other value the fully connected QNetwork
+            seed (int): random seed
+            double_dqn (bool): use the double DQN update rule instead of vanilla DQN
+            input_channels (int): number of input channels (only used by the convolutional network)
+        """
+        self.state_size = state_size
+        self.action_size = action_size
+        random.seed(seed)
+        self.seed = seed
+        self.version = net_type
+        self.double_dqn = double_dqn
+        # Q-Network
+        if self.version == "Conv":
+            self.qnetwork_local = QNetwork2(state_size, action_size, seed, input_channels).to(device)
+            self.qnetwork_target = copy.deepcopy(self.qnetwork_local)
+        else:
+            self.qnetwork_local = QNetwork(state_size, action_size, seed).to(device)
+            self.qnetwork_target = copy.deepcopy(self.qnetwork_local)
+
+        self.optimizer = USE_OPTIMIZER(self.qnetwork_local.parameters(), lr=LR)
+
+        # Replay memory
+        self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed)
+        self.memory_final = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed)
+        self.memory_agent_can_not_choose = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed)
+
+        self.final_step = {}
+
+        # Initialize time step (for updating every UPDATE_EVERY steps)
+        self.t_step = 0
+        self.t_step_final = 0
+        self.t_step_agent_can_not_choose = 0
+
+    def save(self, filename):
+        torch.save(self.qnetwork_local.state_dict(), filename + ".local")
+        torch.save(self.qnetwork_target.state_dict(), filename + ".target")
+
+    def load(self, filename):
+        print("try to load: " + filename)
+        if os.path.exists(filename + ".local"):
+            self.qnetwork_local.load_state_dict(torch.load(filename + ".local", map_location=device))
+            print(filename + ".local -> ok")
+        if os.path.exists(filename + ".target"):
+            self.qnetwork_target.load_state_dict(torch.load(filename + ".target", map_location=device))
+            print(filename + ".target -> ok")
+        self.optimizer = USE_OPTIMIZER(self.qnetwork_local.parameters(), lr=LR)
+
+    def _update_model(self, switch=0):
+        # Learn every UPDATE_EVERY time steps.
+        # If enough samples are available in memory, get random subset and learn
+        if switch == 0:
+            self.t_step = (self.t_step + 1) % UPDATE_EVERY
+            if self.t_step == 0:
+                if len(self.memory) > BATCH_SIZE:
+                    experiences = self.memory.sample()
+                    self.learn(experiences, GAMMA)
+        elif switch == 1:
+            self.t_step_final = (self.t_step_final + 1) % UPDATE_EVERY_FINAL
+            if self.t_step_final == 0:
+                if len(self.memory_final) > BATCH_SIZE:
+                    experiences = self.memory_final.sample()
+                    self.learn(experiences, GAMMA)
+        else:
+            # If enough samples are available in memory_agent_can_not_choose, get random subset and learn
+            self.t_step_agent_can_not_choose = (self.t_step_agent_can_not_choose + 1) % UPDATE_EVERY_AGENT_CANT_CHOOSE
+            if self.t_step_agent_can_not_choose == 0:
+                if len(self.memory_agent_can_not_choose) > BATCH_SIZE:
+                    experiences = self.memory_agent_can_not_choose.sample()
+                    self.learn(experiences, GAMMA)
+
+    def step(self, state, action, reward, next_state, done):
+        # Save experience in replay memory
+        self.memory.add(state, action, reward, next_state, done)
+        self._update_model(0)
+
+    def step_agent_can_not_choose(self, state, action, reward, next_state, done):
+        # Save experience in replay memory_agent_can_not_choose
+        self.memory_agent_can_not_choose.add(state, action, reward, next_state, done)
+        self._update_model(2)
+
+    def add_final_step(self, agent_handle, state, action, reward, next_state, done):
+        if self.final_step.get(agent_handle) is None:
+            self.final_step.update({agent_handle: [state, action, reward, next_state, done]})
+            return True
+        else:
+            return False
+
+    def make_final_step(self, additional_reward=0):
+        for _, item in self.final_step.items():
+            state = item[0]
+            action = item[1]
+            reward = item[2] + additional_reward
+            next_state = item[3]
+            done = item[4]
+            self.memory_final.add(state, action, reward, next_state, done)
+            self._update_model(1)
+        self._reset_final_step()
+
+    def _reset_final_step(self):
+        self.final_step = {}
+
+    def act(self, state, eps=0.):
+        """Returns the action for the given state under the current epsilon-greedy policy.
+
+        Params
+        ======
+            state (array_like): current state
+            eps (float): epsilon, for epsilon-greedy action selection
+
+        Returns
+        =======
+            (action, was_random): the selected action and a flag that is True when the action
+            was sampled uniformly at random (exploration) rather than chosen greedily.
+        """
+        state = torch.from_numpy(state).float().unsqueeze(0).to(device)
+        self.qnetwork_local.eval()
+        with torch.no_grad():
+            action_values = self.qnetwork_local(state)
+        self.qnetwork_local.train()
+
+        # Epsilon-greedy action selection
+        if random.random() > eps:
+            return np.argmax(action_values.cpu().data.numpy()), False
+        else:
+            return random.choice(np.arange(self.action_size)), True
+
+    def learn(self, experiences, gamma):
+        """Update value parameters using given batch of experience tuples.
+
+        Params
+        ======
+            experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tuples
+            gamma (float): discount factor
+        """
+        states, actions, rewards, next_states, dones = experiences
+
+        # Get expected Q values from local model
+        Q_expected = self.qnetwork_local(states).gather(1, actions)
+
+        if self.double_dqn:
+            # Double DQN: the local (online) network selects the best next action,
+            # the target network evaluates it.
+            q_best_action = self.qnetwork_local(next_states).max(1)[1]
+            Q_targets_next = self.qnetwork_target(next_states).gather(1, q_best_action.unsqueeze(-1)).detach()
+        else:
+            # Vanilla DQN: the target network both selects and evaluates the next action.
+            Q_targets_next = self.qnetwork_target(next_states).detach().max(1)[0].unsqueeze(-1)
+
+        # Compute Q targets for current states
+        Q_targets = rewards + (gamma * Q_targets_next * (1 - dones))
+
+        # Compute loss
+        loss = F.mse_loss(Q_expected, Q_targets)
+        # Minimize the loss
+        self.optimizer.zero_grad()
+        loss.backward()
+        self.optimizer.step()
+
+        # ------------------- update target network ------------------- #
+        self.soft_update(self.qnetwork_local, self.qnetwork_target, TAU)
+
+    def soft_update(self, local_model, target_model, tau):
+        """Soft update model parameters.
+        θ_target = τ*θ_local + (1 - τ)*θ_target
+
+        Params
+        ======
+            local_model (PyTorch model): weights will be copied from
+            target_model (PyTorch model): weights will be copied to
+            tau (float): interpolation parameter
+        """
+        for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
+            target_param.data.copy_(tau * local_param.data + (1.0 - tau) * target_param.data)
+
+
+class ReplayBuffer:
+    """Fixed-size buffer to store experience tuples."""
+
+    def __init__(self, action_size, buffer_size, batch_size, seed):
+        """Initialize a ReplayBuffer object.
+
+        Params
+        ======
+            action_size (int): dimension of each action
+            buffer_size (int): maximum size of buffer
+            batch_size (int): size of each training batch
+            seed (int): random seed
+        """
+        self.action_size = action_size
+        self.memory = deque(maxlen=buffer_size)
+        self.batch_size = batch_size
+        self.experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"])
+        random.seed(seed)
+        self.seed = seed
+
+    def add(self, state, action, reward, next_state, done):
+        """Add a new experience to memory."""
+        e = self.experience(np.expand_dims(state, 0), action, reward, np.expand_dims(next_state, 0), done)
+        self.memory.append(e)
+
+    def sample(self):
+        """Randomly sample a batch of experiences from memory."""
+        experiences = random.sample(self.memory, k=self.batch_size)
+
+        states = torch.from_numpy(self.__v_stack_impr([e.state for e in experiences if e is not None])) \
+            .float().to(device)
+        actions = torch.from_numpy(self.__v_stack_impr([e.action for e in experiences if e is not None])) \
+            .long().to(device)
+        rewards = torch.from_numpy(self.__v_stack_impr([e.reward for e in experiences if e is not None])) \
+            .float().to(device)
+        next_states = torch.from_numpy(self.__v_stack_impr([e.next_state for e in experiences if e is not None])) \
+            .float().to(device)
+        dones = torch.from_numpy(self.__v_stack_impr([e.done for e in experiences if e is not None]).astype(np.uint8)) \
+            .float().to(device)
+
+        return (states, actions, rewards, next_states, dones)
+
+    def __len__(self):
+        """Return the current size of internal memory."""
+        return len(self.memory)
+
+    def __v_stack_impr(self, states):
+        sub_dim = len(states[0][0]) if isinstance(states[0], Iterable) else 1
+        np_states = np.reshape(np.array(states), (len(states), sub_dim))
+        return np_states
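+
+
+# Minimal usage sketch (illustrative only; `env`, `obs_dim` and the epsilon schedule are
+# placeholders for whatever environment wrapper drives this agent, not part of this module).
+# Any net_type other than "Conv" selects the fully connected dueling network.
+#
+#   agent = Agent(state_size=obs_dim, action_size=5, net_type="FC", seed=42)
+#   for episode in range(n_episodes):
+#       obs = env.reset()
+#       done = False
+#       while not done:
+#           action, was_random = agent.act(obs, eps=0.1)
+#           next_obs, reward, done, _ = env.step(action)
+#           agent.step(obs, action, reward, next_obs, done)
+#           obs = next_obs
+#   agent.save("checkpoints/dddqn")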
diff --git a/src/agent/model.py b/src/agent/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..70952e0e1e708620ecae294f327becd167321c62
--- /dev/null
+++ b/src/agent/model.py
@@ -0,0 +1,61 @@
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class QNetwork(nn.Module):
+    def __init__(self, state_size, action_size, seed, hidsize1=64, hidsize2=128):
+        super(QNetwork, self).__init__()
+
+        self.fc1_val = nn.Linear(state_size, hidsize1)
+        self.fc2_val = nn.Linear(hidsize1, hidsize2)
+        self.fc3_val = nn.Linear(hidsize2, 1)
+
+        self.fc1_adv = nn.Linear(state_size, hidsize1)
+        self.fc2_adv = nn.Linear(hidsize1, hidsize2)
+        self.fc3_adv = nn.Linear(hidsize2, action_size)
+
+    def forward(self, x):
+        val = F.relu(self.fc1_val(x))
+        val = F.relu(self.fc2_val(val))
+        val = self.fc3_val(val)
+
+        # advantage calculation
+        adv = F.relu(self.fc1_adv(x))
+        adv = F.relu(self.fc2_adv(adv))
+        adv = self.fc3_adv(adv)
+        # dueling aggregation: Q(s, a) = V(s) + A(s, a) - mean_a A(s, a),
+        # with the advantage mean taken per sample over the action dimension
+        return val + adv - adv.mean(dim=1, keepdim=True)
+
+
+class QNetwork2(nn.Module):
+    def __init__(self, state_size, action_size, seed, input_channels, hidsize1=128, hidsize2=64):
+        super(QNetwork2, self).__init__()
+        self.conv1 = nn.Conv2d(input_channels, 16, kernel_size=3, stride=1)
+        self.bn1 = nn.BatchNorm2d(16)
+        self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=3)
+        self.bn2 = nn.BatchNorm2d(32)
+        self.conv3 = nn.Conv2d(32, 64, kernel_size=5, stride=3)
+        self.bn3 = nn.BatchNorm2d(64)
+
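+        # NOTE: the flattened convolutional output is hard-coded to 6400 features
+        # (64 channels * 10 * 10); with these kernel sizes and strides this assumes an
+        # input observation of roughly 100-108 cells per side.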
+        self.fc1_val = nn.Linear(6400, hidsize1)
+        self.fc2_val = nn.Linear(hidsize1, hidsize2)
+        self.fc3_val = nn.Linear(hidsize2, 1)
+
+        self.fc1_adv = nn.Linear(6400, hidsize1)
+        self.fc2_adv = nn.Linear(hidsize1, hidsize2)
+        self.fc3_adv = nn.Linear(hidsize2, action_size)
+
+    def forward(self, x):
+        # convolutional feature extractor with batch normalisation after each convolution
+        x = F.relu(self.bn1(self.conv1(x)))
+        x = F.relu(self.bn2(self.conv2(x)))
+        x = F.relu(self.bn3(self.conv3(x)))
+
+        # value function approximation
+        val = F.relu(self.fc1_val(x.view(x.size(0), -1)))
+        val = F.relu(self.fc2_val(val))
+        val = self.fc3_val(val)
+
+        # advantage calculation
+        adv = F.relu(self.fc1_adv(x.view(x.size(0), -1)))
+        adv = F.relu(self.fc2_adv(adv))
+        adv = self.fc3_adv(adv)
+        return val + adv - adv.mean(dim=1, keepdim=True)  # per-sample mean over actions, as in QNetwork.forward
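+
+
+# Shape sketch (illustrative only, assuming `import torch`; the sizes are arbitrary examples):
+#
+#   net = QNetwork(state_size=46, action_size=5, seed=0)
+#   q = net(torch.rand(32, 46))                # -> shape (32, 5), one Q-value per action
+#
+#   conv_net = QNetwork2(state_size=None, action_size=5, seed=0, input_channels=5)
+#   q = conv_net(torch.rand(32, 5, 102, 102))  # a 102x102 input yields the expected 6400 flat features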
diff --git a/src/observations.py b/src/observations.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc89198440dd932f1ee474deff63a2450242106e
--- /dev/null
+++ b/src/observations.py
@@ -0,0 +1,731 @@
+"""
+Collection of environment-specific ObservationBuilder.
+"""
+import collections
+from typing import Optional, List, Dict, Tuple
+
+import numpy as np
+from flatland.core.env import Environment
+from flatland.core.env_observation_builder import ObservationBuilder
+from flatland.core.env_prediction_builder import PredictionBuilder
+from flatland.core.grid.grid4_utils import get_new_position
+from flatland.core.grid.grid_utils import coordinate_to_position
+from flatland.envs.agent_utils import RailAgentStatus, EnvAgent
+from flatland.utils.ordered_set import OrderedSet
+
+
+class MyTreeObsForRailEnv(ObservationBuilder):
+    """
+    TreeObsForRailEnv object.
+
+    This object returns observation vectors for agents in the RailEnv environment.
+    The information is local to each agent and exploits the graph structure of the rail
+    network to simplify the representation of the state of the environment for each agent.
+
+    For details about the features in the tree observation see the get() function.
+    """
+    Node = collections.namedtuple('Node', 'dist_min_to_target '
+                                          'target_encountered '
+                                          'num_agents_same_direction '
+                                          'num_agents_opposite_direction '
+                                          'childs')
+
+    tree_explored_actions_char = ['L', 'F', 'R', 'B']
+
+    def __init__(self, max_depth: int, predictor: PredictionBuilder = None):
+        super().__init__()
+        self.max_depth = max_depth
+        self.observation_dim = 2
+        self.location_has_agent = {}
+        self.predictor = predictor
+        self.location_has_target = None
+
+        self.switches_list = {}
+        self.switches_neighbours_list = []
+        self.check_agent_descision = None
+
+    def reset(self):
+        self.location_has_target = {tuple(agent.target): 1 for agent in self.env.agents}
+
+    def set_switch_and_pre_switch(self, switch_list, pre_switch_list, check_agent_descision):
+        self.switches_list = switch_list
+        self.switches_neighbours_list = pre_switch_list
+        self.check_agent_descision = check_agent_descision
+
+    def get_many(self, handles: Optional[List[int]] = None) -> Dict[int, Node]:
+        """
+        Called whenever an observation has to be computed for the `env` environment, for each agent with handle
+        in the `handles` list.
+        """
+
+        if handles is None:
+            handles = []
+        if self.predictor:
+            self.max_prediction_depth = 0
+            self.predicted_pos = {}
+            self.predicted_dir = {}
+            self.predictions = self.predictor.get()
+            if self.predictions:
+                for t in range(self.predictor.max_depth + 1):
+                    pos_list = []
+                    dir_list = []
+                    for a in handles:
+                        if self.predictions[a] is None:
+                            continue
+                        pos_list.append(self.predictions[a][t][1:3])
+                        dir_list.append(self.predictions[a][t][3])
+                    self.predicted_pos.update({t: coordinate_to_position(self.env.width, pos_list)})
+                    self.predicted_dir.update({t: dir_list})
+                self.max_prediction_depth = len(self.predicted_pos)
+        # Update local lookup table for all agents' positions
+        # ignore other agents not in the grid (only status active and done)
+
+        self.location_has_agent = {}
+        self.location_has_agent_direction = {}
+        self.location_has_agent_speed = {}
+        self.location_has_agent_malfunction = {}
+        self.location_has_agent_ready_to_depart = {}
+
+        for _agent in self.env.agents:
+            if _agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE] and \
+                    _agent.position:
+                self.location_has_agent[tuple(_agent.position)] = 1
+                self.location_has_agent_direction[tuple(_agent.position)] = _agent.direction
+                self.location_has_agent_speed[tuple(_agent.position)] = _agent.speed_data['speed']
+                self.location_has_agent_malfunction[tuple(_agent.position)] = _agent.malfunction_data[
+                    'malfunction']
+
+            if _agent.status in [RailAgentStatus.READY_TO_DEPART] and \
+                    _agent.initial_position:
+                self.location_has_agent_ready_to_depart[tuple(_agent.initial_position)] = \
+                    self.location_has_agent_ready_to_depart.get(tuple(_agent.initial_position), 0) + 1
+
+        observations = super().get_many(handles)
+
+        return observations
+
+    def get(self, handle: int = 0) -> Node:
+        """
+        Computes the current observation for agent `handle` in env
+
+        The observation vector is composed of 4 sequential parts, corresponding to data from the up to 4 possible
+        movements in a RailEnv (up to because only a subset of possible transitions are allowed in RailEnv).
+        The possible movements are sorted relative to the current orientation of the agent, rather than NESW as for
+        the transitions. The order is::
+
+            [data from 'left'] + [data from 'forward'] + [data from 'right'] + [data from 'back']
+
+        Each branch data is organized as::
+
+            [root node information] +
+            [recursive branch data from 'left'] +
+            [... from 'forward'] +
+            [... from 'right'] +
+            [... from 'back']
+
+        Each node of the full tree observation is described by the 12 features listed below; this
+        simplified builder only populates dist_min_to_target, target_encountered,
+        num_agents_same_direction and num_agents_opposite_direction:
+
+        #1:
+            if own target lies on the explored branch the current distance from the agent in number of cells is stored.
+
+        #2:
+            if another agent's target is detected, the distance in number of cells from the agent's
+            current location is stored
+
+        #3:
+            if another agent is detected the distance in number of cells from current agent position is stored.
+
+        #4:
+            possible conflict detected
+            tot_dist = another agent is predicted to pass through this cell at the same time as this
+            agent; we store the distance in number of cells from the current agent position
+
+            0 = no other agent reserves the same cell at a similar time
+
+        #5:
+            if a switch that is not usable by the agent is detected, we store the distance.
+
+        #6:
+            This feature stores the distance in number of cells to the next branching point (current node)
+
+        #7:
+            minimum distance from node to the agent's target given the direction of the agent if this path is chosen
+
+        #8:
+            agent in the same direction
+            n = number of agents present in the same direction \
+                (possible future use: number of other agents in the same direction in this branch)
+            0 = no agent present in the same direction
+
+        #9:
+            agent in the opposite direction
+            n = number of agents present in the opposite direction to mine (i.e. a potential conflict) \
+                (possible future use: number of other agents in the opposite direction in this branch)
+            0 = no agent present in the opposite direction to mine
+
+        #10:
+            malfunctioning/blocking agents
+            n = number of time steps the observed agent remains blocked
+
+        #11:
+            slowest observed speed of an agent in the same direction
+            1 if no agent is observed, min_fractional speed otherwise
+
+        #12:
+            number of agents ready to depart but not yet active
+
+        Missing/padding nodes are filled in with -inf (truncated).
+        Missing values in present node are filled in with +inf (truncated).
+
+
+        For the root node, dist_min_to_target is the agent's current distance to its target and the other
+        features are 0. On explored branches, target_encountered is set to 1 as soon as the agent's own
+        target lies on the branch.
+        """
+
+        if handle >= len(self.env.agents):
+            print("ERROR: obs _get - handle ", handle, " len(agents)", len(self.env.agents))
+        agent = self.env.agents[handle]  # TODO: handle being treated as index
+
+        if agent.status == RailAgentStatus.READY_TO_DEPART:
+            agent_virtual_position = agent.initial_position
+        elif agent.status == RailAgentStatus.ACTIVE:
+            agent_virtual_position = agent.position
+        elif agent.status == RailAgentStatus.DONE:
+            agent_virtual_position = agent.target
+        else:
+            return None
+
+        possible_transitions = self.env.rail.get_transitions(*agent_virtual_position, agent.direction)
+        num_transitions = np.count_nonzero(possible_transitions)
+
+        # Here information about the agent itself is stored
+        distance_map = self.env.distance_map.get()
+
+        root_node_observation = MyTreeObsForRailEnv.Node(dist_min_to_target=distance_map[
+            (handle, *agent_virtual_position,
+             agent.direction)],
+                                                         target_encountered=0,
+                                                         num_agents_same_direction=0,
+                                                         num_agents_opposite_direction=0,
+                                                         childs={})
+
+        visited = OrderedSet()
+
+        # Start from the current orientation, and see which transitions are available;
+        # organize them as [left, forward, right, back], relative to the current orientation
+        # If only one transition is possible, the tree is oriented with this transition as the forward branch.
+        orientation = agent.direction
+
+        if num_transitions == 1:
+            orientation = np.argmax(possible_transitions)
+
+        for i, branch_direction in enumerate([(orientation + i) % 4 for i in range(-1, 3)]):
+            if possible_transitions[branch_direction]:
+                new_cell = get_new_position(agent_virtual_position, branch_direction)
+
+                branch_observation, branch_visited = \
+                    self._explore_branch(handle, new_cell, branch_direction, 1, 1)
+                root_node_observation.childs[self.tree_explored_actions_char[i]] = branch_observation
+
+                visited |= branch_visited
+            else:
+                # add cells filled with infinity if no transition is possible
+                root_node_observation.childs[self.tree_explored_actions_char[i]] = -np.inf
+        self.env.dev_obs_dict[handle] = visited
+
+        return root_node_observation
+
+    def _explore_branch(self, handle, position, direction, tot_dist, depth):
+        """
+        Utility function to compute tree-based observations.
+        We walk along the branch and collect the information documented in the get() function.
+        If there is a branching point a new node is created and each possible branch is explored.
+        """
+
+        # [Recursive branch opened]
+        if depth >= self.max_depth + 1:
+            return [], []
+
+        # Continue along direction until next switch or
+        # until no transitions are possible along the current direction (i.e., dead-ends)
+        # We treat dead-ends as nodes, instead of going back, to avoid loops
+        exploring = True
+
+        visited = OrderedSet()
+        agent = self.env.agents[handle]
+
+        other_agent_opposite_direction = 0
+        other_agent_same_direction = 0
+
+        dist_min_to_target = self.env.distance_map.get()[handle, position[0], position[1], direction]
+
+        last_is_dead_end = False
+        last_is_a_decision_cell = False
+        target_encountered = 0
+
+        while exploring:
+
+            dist_min_to_target = min(dist_min_to_target, self.env.distance_map.get()[handle, position[0], position[1],
+                                                                                     direction])
+
+            if agent.target == position:
+                target_encountered = 1
+
+            new_direction_me = direction
+            new_cell_me = position
+            a = self.env.agent_positions[new_cell_me]
+            if a != -1 and a != handle:
+                opp_agent = self.env.agents[a]
+                # look one step forward
+                # opp_possible_transitions = self.env.rail.get_transitions(*opp_agent.position, opp_agent.direction)
+                if opp_agent.direction != new_direction_me:  # opp_possible_transitions[new_direction_me] == 0:
+                    other_agent_opposite_direction += 1
+                else:
+                    other_agent_same_direction += 1
+
+            # #############################
+            # #############################
+            if (position[0], position[1], direction) in visited:
+                break
+            visited.add((position[0], position[1], direction))
+
+            # If the target node is encountered, pick that as node. Also, no further branching is possible.
+            if np.array_equal(position, self.env.agents[handle].target):
+                last_is_target = True
+                break
+
+            exploring = False
+
+            # Check number of possible transitions for agent and total number of transitions in cell (type)
+            possible_transitions = self.env.rail.get_transitions(*position, direction)
+            num_transitions = np.count_nonzero(possible_transitions)
+            # cell_transitions = self.env.rail.get_transitions(*position, direction)
+            transition_bit = bin(self.env.rail.get_full_transitions(*position))
+            total_transitions = transition_bit.count("1")
+
+            if num_transitions == 1:
+                # Check if dead-end, or if we can go forward along direction
+                nbits = total_transitions
+                if nbits == 1:
+                    # Dead-end!
+                    last_is_dead_end = True
+
+            if self.check_agent_descision is not None:
+                ret_agents_on_switch, ret_agents_near_to_switch, agents_near_to_switch_all = \
+                    self.check_agent_descision(position,
+                                               direction,
+                                               self.switches_list,
+                                               self.switches_neighbours_list)
+                if ret_agents_on_switch:
+                    last_is_a_decision_cell = True
+                    break
+
+            exploring = True
+            # convert one-hot encoding to 0,1,2,3
+            cell_transitions = self.env.rail.get_transitions(*position, direction)
+            direction = np.argmax(cell_transitions)
+            position = get_new_position(position, direction)
+
+        # #############################
+        # #############################
+        # Modify here to append new / different features for each visited cell!
+
+        node = MyTreeObsForRailEnv.Node(dist_min_to_target=dist_min_to_target,
+                                        target_encountered=target_encountered,
+                                        num_agents_opposite_direction=other_agent_opposite_direction,
+                                        num_agents_same_direction=other_agent_same_direction,
+                                        childs={})
+
+        # #############################
+        # #############################
+        # Start from the current orientation, and see which transitions are available;
+        # organize them as [left, forward, right, back], relative to the current orientation
+        # Get the possible transitions
+        possible_transitions = self.env.rail.get_transitions(*position, direction)
+
+        for i, branch_direction in enumerate([(direction + 4 + i) % 4 for i in range(-1, 3)]):
+            if last_is_dead_end and self.env.rail.get_transition((*position, direction),
+                                                                 (branch_direction + 2) % 4):
+                # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes
+                # it back
+                new_cell = get_new_position(position, (branch_direction + 2) % 4)
+                branch_observation, branch_visited = self._explore_branch(handle,
+                                                                          new_cell,
+                                                                          (branch_direction + 2) % 4,
+                                                                          tot_dist + 1,
+                                                                          depth + 1)
+                node.childs[self.tree_explored_actions_char[i]] = branch_observation
+                if len(branch_visited) != 0:
+                    visited |= branch_visited
+            elif last_is_a_decision_cell and possible_transitions[branch_direction]:
+                new_cell = get_new_position(position, branch_direction)
+                branch_observation, branch_visited = self._explore_branch(handle,
+                                                                          new_cell,
+                                                                          branch_direction,
+                                                                          tot_dist + 1,
+                                                                          depth + 1)
+                node.childs[self.tree_explored_actions_char[i]] = branch_observation
+                if len(branch_visited) != 0:
+                    visited |= branch_visited
+            else:
+                # no exploring possible, add just cells with infinity
+                node.childs[self.tree_explored_actions_char[i]] = -np.inf
+
+        if depth == self.max_depth:
+            node.childs.clear()
+        return node, visited
+
+    def util_print_obs_subtree(self, tree: Node):
+        """
+        Utility function to print tree observations returned by this object.
+        """
+        self.print_node_features(tree, "root", "")
+        for direction in self.tree_explored_actions_char:
+            self.print_subtree(tree.childs[direction], direction, "\t")
+
+    @staticmethod
+    def print_node_features(node: Node, label, indent):
+        print(indent, "Direction ", label, ": ", node.num_agents_same_direction,
+              ", ", node.num_agents_opposite_direction)
+
+    def print_subtree(self, node, label, indent):
+        if node == -np.inf or not node:
+            print(indent, "Direction ", label, ": -np.inf")
+            return
+
+        self.print_node_features(node, label, indent)
+
+        if not node.childs:
+            return
+
+        for direction in self.tree_explored_actions_char:
+            self.print_subtree(node.childs[direction], direction, indent + "\t")
+
+    def set_env(self, env: Environment):
+        super().set_env(env)
+        if self.predictor:
+            self.predictor.set_env(self.env)
+
+    def _reverse_dir(self, direction):
+        return int((direction + 2) % 4)
+
+
+class GlobalObsForRailEnv(ObservationBuilder):
+    """
+    Gives a global observation of the entire rail environment.
+    The observation is composed of the following elements:
+
+        - transition map array with dimensions (env.height, env.width, 16),\
+          assuming 16 bits encoding of transitions.
+
+        - obs_agents_state: A 3D array (map_height, map_width, 5) with
+            - first channel containing the agents position and direction
+            - second channel containing the other agents positions and direction
+            - third channel containing agent/other agent malfunctions
+            - fourth channel containing agent/other agent fractional speeds
+            - fifth channel containing number of other agents ready to depart
+
+        - obs_targets: Two 2D arrays (map_height, map_width, 2) containing respectively the position of the given agent\
+         target and the positions of the other agents targets (flag only, no counter!).
+    """
+
+    def __init__(self):
+        super(GlobalObsForRailEnv, self).__init__()
+
+    def set_env(self, env: Environment):
+        super().set_env(env)
+
+    def reset(self):
+        self.rail_obs = np.zeros((self.env.height, self.env.width, 16))
+        for i in range(self.rail_obs.shape[0]):
+            for j in range(self.rail_obs.shape[1]):
+                bitlist = [int(digit) for digit in bin(self.env.rail.get_full_transitions(i, j))[2:]]
+                bitlist = [0] * (16 - len(bitlist)) + bitlist
+                self.rail_obs[i, j] = np.array(bitlist)
+
+    def get(self, handle: int = 0) -> (np.ndarray, np.ndarray, np.ndarray):
+
+        agent = self.env.agents[handle]
+        if agent.status == RailAgentStatus.READY_TO_DEPART:
+            agent_virtual_position = agent.initial_position
+        elif agent.status == RailAgentStatus.ACTIVE:
+            agent_virtual_position = agent.position
+        elif agent.status == RailAgentStatus.DONE:
+            agent_virtual_position = agent.target
+        else:
+            return None
+
+        obs_targets = np.zeros((self.env.height, self.env.width, 2))
+        obs_agents_state = np.zeros((self.env.height, self.env.width, 5)) - 1
+
+        # TODO can we do this more elegantly?
+        # for r in range(self.env.height):
+        #     for c in range(self.env.width):
+        #         obs_agents_state[(r, c)][4] = 0
+        obs_agents_state[:, :, 4] = 0
+
+        obs_agents_state[agent_virtual_position][0] = agent.direction
+        obs_targets[agent.target][0] = 1
+
+        for i in range(len(self.env.agents)):
+            other_agent: EnvAgent = self.env.agents[i]
+
+            # ignore other agents not in the grid any more
+            if other_agent.status == RailAgentStatus.DONE_REMOVED:
+                continue
+
+            obs_targets[other_agent.target][1] = 1
+
+            # second to fourth channel only if in the grid
+            if other_agent.position is not None:
+                # second channel only for other agents
+                if i != handle:
+                    obs_agents_state[other_agent.position][1] = other_agent.direction
+                obs_agents_state[other_agent.position][2] = other_agent.malfunction_data['malfunction']
+                obs_agents_state[other_agent.position][3] = other_agent.speed_data['speed']
+            # fifth channel: all ready to depart on this position
+            if other_agent.status == RailAgentStatus.READY_TO_DEPART:
+                obs_agents_state[other_agent.initial_position][4] += 1
+        return self.rail_obs, obs_agents_state, obs_targets
+
+
+class LocalObsForRailEnv(ObservationBuilder):
+    """
+    !!!!!!WARNING!!! THIS IS DEPRECATED AND NOT UPDATED TO FLATLAND 2.0!!!!!
+    Gives a local observation of the rail environment around the agent.
+    The observation is composed of the following elements:
+
+        - transition map array of the local environment around the given agent, \
+          with dimensions (view_height,2*view_width+1, 16), \
+          assuming 16 bits encoding of transitions.
+
+        - Two 2D arrays (view_height, 2*view_width+1, 2) containing respectively, \
+          if they are in the agent's vision range, the agent's own target position and the positions of the other agents' targets.
+
+        - A 2D array (view_height,2*view_width+1, 4) containing the one hot encoding of directions \
+          of the other agents at their position coordinates, if they are in the agent's vision range.
+
+        - A 4 elements array with one hot encoding of the direction.
+
+    Use the parameters view_width and view_height to define the rectangular view of the agent.
+    The center parameter moves the agent along the height axis of this rectangle. If it is 0, the agent only has
+    observations in front of it.
+
+    .. deprecated:: 2.0.0
+    """
+
+    def __init__(self, view_width, view_height, center):
+
+        super(LocalObsForRailEnv, self).__init__()
+        self.view_width = view_width
+        self.view_height = view_height
+        self.center = center
+        self.max_padding = max(self.view_width, self.view_height - self.center)
+
+    def reset(self):
+        # We build the transition map with a view_radius empty cells expansion on each side.
+        # This helps to collect the local transition map view when the agent is close to a border.
+        self.max_padding = max(self.view_width, self.view_height)
+        self.rail_obs = np.zeros((self.env.height,
+                                  self.env.width, 16))
+        for i in range(self.env.height):
+            for j in range(self.env.width):
+                bitlist = [int(digit) for digit in bin(self.env.rail.get_full_transitions(i, j))[2:]]
+                bitlist = [0] * (16 - len(bitlist)) + bitlist
+                self.rail_obs[i, j] = np.array(bitlist)
+
+    def get(self, handle: int = 0) -> (np.ndarray, np.ndarray, np.ndarray, np.ndarray):
+        agents = self.env.agents
+        agent = agents[handle]
+
+        # Correct agents position for padding
+        # agent_rel_pos[0] = agent.position[0] + self.max_padding
+        # agent_rel_pos[1] = agent.position[1] + self.max_padding
+
+        # Collect visible cells as set to be plotted
+        visited, rel_coords = self.field_of_view(agent.position, agent.direction)
+        local_rail_obs = None
+
+        # Add the visible cells to the observed cells
+        self.env.dev_obs_dict[handle] = set(visited)
+
+        # Locate observed agents and their corresponding targets
+        local_rail_obs = np.zeros((self.view_height, 2 * self.view_width + 1, 16))
+        obs_map_state = np.zeros((self.view_height, 2 * self.view_width + 1, 2))
+        obs_other_agents_state = np.zeros((self.view_height, 2 * self.view_width + 1, 4))
+        _idx = 0
+        for pos in visited:
+            curr_rel_coord = rel_coords[_idx]
+            local_rail_obs[curr_rel_coord[0], curr_rel_coord[1], :] = self.rail_obs[pos[0], pos[1], :]
+            if pos == agent.target:
+                obs_map_state[curr_rel_coord[0], curr_rel_coord[1], 0] = 1
+            else:
+                for tmp_agent in agents:
+                    if pos == tmp_agent.target:
+                        obs_map_state[curr_rel_coord[0], curr_rel_coord[1], 1] = 1
+            if pos != agent.position:
+                for tmp_agent in agents:
+                    if pos == tmp_agent.position:
+                        obs_other_agents_state[curr_rel_coord[0], curr_rel_coord[1], :] = np.identity(4)[
+                            tmp_agent.direction]
+
+            _idx += 1
+
+        direction = np.identity(4)[agent.direction]
+        return local_rail_obs, obs_map_state, obs_other_agents_state, direction
+
+    def get_many(self, handles: Optional[List[int]] = None) -> Dict[
+        int, Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]]:
+        """
+        Called whenever an observation has to be computed for the `env` environment, for each agent with handle
+        in the `handles` list.
+        """
+
+        return super().get_many(handles)
+
+    def field_of_view(self, position, direction, state=None):
+        # Compute the local field of view for an agent in the environment
+        data_collection = False
+        if state is not None:
+            temp_visible_data = np.zeros(shape=(self.view_height, 2 * self.view_width + 1, 16))
+            data_collection = True
+        if direction == 0:
+            origin = (position[0] + self.center, position[1] - self.view_width)
+        elif direction == 1:
+            origin = (position[0] - self.view_width, position[1] - self.center)
+        elif direction == 2:
+            origin = (position[0] - self.center, position[1] + self.view_width)
+        else:
+            origin = (position[0] + self.view_width, position[1] + self.center)
+        visible = list()
+        rel_coords = list()
+        for h in range(self.view_height):
+            for w in range(2 * self.view_width + 1):
+                if direction == 0:
+                    if 0 <= origin[0] - h < self.env.height and 0 <= origin[1] + w < self.env.width:
+                        visible.append((origin[0] - h, origin[1] + w))
+                        rel_coords.append((h, w))
+                    # if data_collection:
+                    #    temp_visible_data[h, w, :] = state[origin[0] - h, origin[1] + w, :]
+                elif direction == 1:
+                    if 0 <= origin[0] + w < self.env.height and 0 <= origin[1] + h < self.env.width:
+                        visible.append((origin[0] + w, origin[1] + h))
+                        rel_coords.append((h, w))
+                    # if data_collection:
+                    #    temp_visible_data[h, w, :] = state[origin[0] + w, origin[1] + h, :]
+                elif direction == 2:
+                    if 0 <= origin[0] + h < self.env.height and 0 <= origin[1] - w < self.env.width:
+                        visible.append((origin[0] + h, origin[1] - w))
+                        rel_coords.append((h, w))
+                    # if data_collection:
+                    #    temp_visible_data[h, w, :] = state[origin[0] + h, origin[1] - w, :]
+                else:
+                    if 0 <= origin[0] - w < self.env.height and 0 <= origin[1] - h < self.env.width:
+                        visible.append((origin[0] - w, origin[1] - h))
+                        rel_coords.append((h, w))
+                    # if data_collection:
+                    #    temp_visible_data[h, w, :] = state[origin[0] - w, origin[1] - h, :]
+        if data_collection:
+            return temp_visible_data
+        else:
+            return visible, rel_coords
+
+
+def _split_node_into_feature_groups(node: MyTreeObsForRailEnv.Node, dist_min_to_target: int) -> np.ndarray:
+    data = np.zeros(2)
+
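+    # two features per node, each encoded as -1/+1:
+    #   data[0]: is at least one agent heading in the opposite direction seen on this branch?
+    #   data[1]: is the agent's own target encountered on this branch?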
+    data[0] = 2.0 * int(node.num_agents_opposite_direction > 0) - 1.0
+    # data[1] = 2.0 * int(node.num_agents_same_direction > 0) - 1.0
+    data[1] = 2.0 * int(node.target_encountered > 0) - 1.0
+
+    return data
+
+
+def _split_subtree_into_feature_groups(node: MyTreeObsForRailEnv.Node, dist_min_to_target: int,
+                                       current_tree_depth: int,
+                                       max_tree_depth: int) -> np.ndarray:
+    if node == -np.inf:
+        remaining_depth = max_tree_depth - current_tree_depth
+        # reference: https://stackoverflow.com/questions/515214/total-number-of-nodes-in-a-tree-data-structure
+        num_remaining_nodes = int((4 ** (remaining_depth + 1) - 1) / (4 - 1))
+        return [0] * num_remaining_nodes * 2
+
+    data = _split_node_into_feature_groups(node, dist_min_to_target)
+
+    if not node.childs:
+        return data
+
+    for direction in MyTreeObsForRailEnv.tree_explored_actions_char:
+        sub_data = _split_subtree_into_feature_groups(node.childs[direction],
+                                                      node.dist_min_to_target,
+                                                      current_tree_depth + 1,
+                                                      max_tree_depth)
+        data = np.concatenate((data, sub_data))
+    return data
+
+
+def split_tree_into_feature_groups(tree: MyTreeObsForRailEnv.Node, max_tree_depth: int) -> np.ndarray:
+    """
+    This function flattens the tree observation into a single, fixed-length feature array.
+    """
+    data = _split_node_into_feature_groups(tree, 1000000.0)
+
+    for direction in MyTreeObsForRailEnv.tree_explored_actions_char:
+        sub_data = _split_subtree_into_feature_groups(tree.childs[direction],
+                                                      1000000.0,
+                                                      1,
+                                                      max_tree_depth)
+        data = np.concatenate((data, sub_data))
+
+    return data
+
+
+def normalize_observation(observation: MyTreeObsForRailEnv.Node, tree_depth: int):
+    """
+    This function normalizes the observation used by the RL algorithm
+    """
+    data = split_tree_into_feature_groups(observation, tree_depth)
+    normalized_obs = data
+
+    # navigate_info
+    navigate_info = np.zeros(4)
+    action_info = np.zeros(4)
+    np.seterr(all='raise')  # let invalid floating point operations raise so they are caught below
+    try:
+        dm = observation.dist_min_to_target
+        if observation.childs['L'] != -np.inf:
+            navigate_info[0] = dm - observation.childs['L'].dist_min_to_target
+            action_info[0] = 1
+        if observation.childs['F'] != -np.inf:
+            navigate_info[1] = dm - observation.childs['F'].dist_min_to_target
+            action_info[1] = 1
+        if observation.childs['R'] != -np.inf:
+            navigate_info[2] = dm - observation.childs['R'].dist_min_to_target
+            action_info[2] = 1
+        if observation.childs['B'] != -np.inf:
+            navigate_info[3] = dm - observation.childs['B'].dist_min_to_target
+            action_info[3] = 1
+    except Exception:
+        # fall back to a neutral observation when a branch distance cannot be compared
+        # (e.g. infinite distances raise because of np.seterr(all='raise') above)
+        navigate_info = np.ones(4)
+        normalized_obs = np.zeros(len(normalized_obs))
+
+    # navigate_info_2 = np.copy(navigate_info)
+    # max_v = np.max(navigate_info_2)
+    # navigate_info_2 = navigate_info_2 / max_v
+    # navigate_info_2[navigate_info_2 < 1] = -1
+
+    max_v = np.max(navigate_info)
+    if max_v != 0:  # avoid a division-by-zero error (np.seterr(all='raise') is active)
+        navigate_info = navigate_info / max_v
+    navigate_info[navigate_info < 0] = -1
+    # navigate_info[abs(navigate_info) < 1] = 0
+    # normalized_obs = navigate_info
+
+    # navigate_info = np.concatenate((navigate_info, action_info))
+    normalized_obs = np.concatenate((navigate_info, normalized_obs))
+    # normalized_obs = np.concatenate((navigate_info, navigate_info_2))
+    # print(normalized_obs)
+    return normalized_obs
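+
+
+# Size of the normalized observation: a tree of depth D has (4 ** (D + 1) - 1) // 3 nodes and
+# every node contributes 2 features, so normalize_observation returns
+# 4 + 2 * (4 ** (D + 1) - 1) // 3 values, e.g. 4 + 2 * 21 = 46 for tree_depth = 2.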