diff --git a/checkpoints/201014015722-1500.pth b/checkpoints/201014015722-1500.pth
deleted file mode 100644
index cbef409414195189caa8e536c185c7059e42cefe..0000000000000000000000000000000000000000
Binary files a/checkpoints/201014015722-1500.pth and /dev/null differ
diff --git a/reinforcement_learning/dddqn_policy.py b/reinforcement_learning/dddqn_policy.py
index c1177b91bc621e517a8706729370f502566b80e6..6e54f3d815caad71e18d4533cd53504fc4d31602 100644
--- a/reinforcement_learning/dddqn_policy.py
+++ b/reinforcement_learning/dddqn_policy.py
@@ -22,7 +22,7 @@ class DDDQNPolicy(Policy):
         self.state_size = state_size
         self.action_size = action_size
         self.double_dqn = True
-        self.hidsize = 1
+        self.hidsize = 128
 
         if not evaluation_mode:
             self.hidsize = parameters.hidden_size
@@ -34,7 +34,7 @@ class DDDQNPolicy(Policy):
             self.gamma = parameters.gamma
             self.buffer_min_size = parameters.buffer_min_size
 
-        # Device
+            # Device
         if parameters.use_gpu and torch.cuda.is_available():
             self.device = torch.device("cuda:0")
             # print("🐇 Using GPU")
@@ -43,7 +43,8 @@ class DDDQNPolicy(Policy):
             # print("🐢 Using CPU")
 
         # Q-Network
-        self.qnetwork_local = DuelingQNetwork(state_size, action_size, hidsize1=self.hidsize, hidsize2=self.hidsize).to(self.device)
+        self.qnetwork_local = DuelingQNetwork(state_size, action_size, hidsize1=self.hidsize, hidsize2=self.hidsize).to(
+            self.device)
 
         if not evaluation_mode:
             self.qnetwork_target = copy.deepcopy(self.qnetwork_local)
@@ -119,15 +120,22 @@ class DDDQNPolicy(Policy):
         torch.save(self.qnetwork_target.state_dict(), filename + ".target")
 
     def load(self, filename):
-        if os.path.exists(filename + ".local") and os.path.exists(filename + ".target"):
-            self.qnetwork_local.load_state_dict(torch.load(filename + ".local"))
-            self.qnetwork_target.load_state_dict(torch.load(filename + ".target"))
-        else:
-            if os.path.exists(filename):
-                self.qnetwork_local.load_state_dict(torch.load(filename))
-                self.qnetwork_target.load_state_dict(torch.load(filename))
+        try:
+            if os.path.exists(filename + ".local") and os.path.exists(filename + ".target"):
+                self.qnetwork_local.load_state_dict(torch.load(filename + ".local"))
+                print("qnetwork_local loaded ('{}')".format(filename + ".local"))
+                if self.evaluation_mode:
+                    self.qnetwork_target = copy.deepcopy(self.qnetwork_local)
+                else:
+                    self.qnetwork_target.load_state_dict(torch.load(filename + ".target"))
+                    print("qnetwork_target loaded ('{}' )".format(filename + ".target"))
             else:
-                raise FileNotFoundError("Couldn't load policy from: '{}', '{}'".format(filename + ".local", filename + ".target"))
+                print(">> Checkpoint not found, using untrained policy! ('{}', '{}')".format(filename + ".local",
+                                                                                             filename + ".target"))
+        except Exception as exc:
+            print(exc)
+            print("Couldn't load policy from, using untrained policy! ('{}', '{}')".format(filename + ".local",
+                                                                                           filename + ".target"))
 
     def save_replay_buffer(self, filename):
         memory = self.memory.memory
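
Note on the reworked load() above: a missing or unreadable checkpoint no longer raises FileNotFoundError; the policy keeps its freshly initialized weights and prints what went wrong, and in evaluation mode the target network is simply a copy of the loaded local network. A minimal usage sketch, assuming this repository's layout and PyTorch installed (the state size of 26 is only an illustrative placeholder; run.py derives it from the observation builder):

    from argparse import Namespace

    from reinforcement_learning.dddqn_policy import DDDQNPolicy

    # Evaluation-mode policy, mirroring the call in run.py below.
    policy = DDDQNPolicy(26, 5, Namespace(**{'use_gpu': False}), evaluation_mode=True)

    # Missing files: prints a warning and keeps the untrained network, no exception.
    policy.load("checkpoints/does-not-exist.pth")

    # Existing checkpoint: loads "<file>.local"; in evaluation mode "<file>.target"
    # is not read, the target network is copied from the local one.
    policy.load("./checkpoints/201103160541-1800.pth")
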
diff --git a/run.py b/run.py
index 918d78cbc0b7c76a2414c0b431e83cda1d1ef043..f11c94168e9ba1856665273e716361d1a2077b50 100644
--- a/run.py
+++ b/run.py
@@ -1,25 +1,21 @@
-import os
 import sys
+import time
 from argparse import Namespace
 from pathlib import Path
 
 import numpy as np
-import time
-
-import torch
 from flatland.core.env_observation_builder import DummyObservationBuilder
-from flatland.envs.observations import TreeObsForRailEnv
-from flatland.evaluators.client import FlatlandRemoteClient
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
+from flatland.evaluators.client import FlatlandRemoteClient
 from flatland.evaluators.client import TimeoutException
 
 from utils.deadlock_check import check_if_all_blocked
+from utils.fast_tree_obs import FastTreeObs
 
 base_dir = Path(__file__).resolve().parent.parent
 sys.path.append(str(base_dir))
 
 from reinforcement_learning.dddqn_policy import DDDQNPolicy
-from utils.observation_utils import normalize_observation
 
 ####################################################
 # EVALUATION PARAMETERS
@@ -28,7 +24,7 @@ from utils.observation_utils import normalize_observation
 VERBOSE = True
 
 # Checkpoint to use (remember to push it!)
-checkpoint = "checkpoints/201103150429-2500.pth"
+checkpoint = "./checkpoints/201103160541-1800.pth"
 
 # Use last action cache
 USE_ACTION_CACHE = True
@@ -44,20 +40,15 @@ remote_client = FlatlandRemoteClient()
 
 # Observation builder
 predictor = ShortestPathPredictorForRailEnv(observation_max_path_depth)
-tree_observation = TreeObsForRailEnv(max_depth=observation_tree_depth, predictor=predictor)
+tree_observation = FastTreeObs(max_depth=observation_tree_depth)
 
 # Calculates state and action sizes
-n_nodes = sum([np.power(4, i) for i in range(observation_tree_depth + 1)])
-state_size = tree_observation.observation_dim * n_nodes
+state_size = tree_observation.observation_dim
 action_size = 5
 
 # Creates the policy. No GPU on evaluation server.
 policy = DDDQNPolicy(state_size, action_size, Namespace(**{'use_gpu': False}), evaluation_mode=True)
-
-if os.path.isfile(checkpoint):
-    policy.load(checkpoint)
-else:
-    print("Checkpoint not found, using untrained policy! (path: {})".format(checkpoint))
+policy.load(checkpoint)
 
 #####################################################################
 # Main evaluation loop
@@ -124,15 +115,13 @@ while True:
                 time_start = time.time()
                 action_dict = {}
                 for agent in range(nb_agents):
-                    if observation[agent] and info['action_required'][agent]:
+                    if info['action_required'][agent]:
                         if agent in agent_last_obs and np.all(agent_last_obs[agent] == observation[agent]):
                             # cache hit
                             action = agent_last_action[agent]
                             nb_hit += 1
                         else:
-                            # otherwise, run normalization and inference
-                            norm_obs = normalize_observation(observation[agent], tree_depth=observation_tree_depth, observation_radius=observation_radius)
-                            action = policy.act(norm_obs, eps=0.0)
+                            action = policy.act(observation[agent], eps=0.0)
 
                         action_dict[agent] = action
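
On the branch above: FastTreeObs already hands each agent a flat feature vector, which is why the normalize_observation step is gone and the observation is passed to policy.act() directly; the np.all(...) check reuses the previous action only when that vector is unchanged (run.py additionally gates the cache behind USE_ACTION_CACHE). A self-contained sketch of the caching pattern, with policy_act standing in for policy.act(obs, eps=0.0):

    import numpy as np

    agent_last_obs = {}
    agent_last_action = {}

    def cached_act(agent, obs, policy_act):
        # Reuse the last action when this agent's observation vector is identical.
        if agent in agent_last_obs and np.all(agent_last_obs[agent] == obs):
            return agent_last_action[agent]
        action = policy_act(obs)
        agent_last_obs[agent] = obs
        agent_last_action[agent] = action
        return action
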
 
@@ -163,16 +152,17 @@ while True:
             nb_agents_done = sum(done[idx] for idx in local_env.get_agent_handles())
 
             if VERBOSE or done['__all__']:
-                print("Step {}/{}\tAgents done: {}\t Obs time {:.3f}s\t Inference time {:.5f}s\t Step time {:.3f}s\t Cache hits {}\t No-ops? {}".format(
-                    str(steps).zfill(4),
-                    max_nb_steps,
-                    nb_agents_done,
-                    obs_time,
-                    agent_time,
-                    step_time,
-                    nb_hit,
-                    no_ops_mode
-                ), end="\r")
+                print(
+                    "Step {}/{}\tAgents done: {}\t Obs time {:.3f}s\t Inference time {:.5f}s\t Step time {:.3f}s\t Cache hits {}\t No-ops? {}".format(
+                        str(steps).zfill(4),
+                        max_nb_steps,
+                        nb_agents_done,
+                        obs_time,
+                        agent_time,
+                        step_time,
+                        nb_hit,
+                        no_ops_mode
+                    ), end="\r")
 
             if done['__all__']:
                 # When done['__all__'] == True, then the evaluation of this
@@ -190,7 +180,8 @@ while True:
 
     np_time_taken_by_controller = np.array(time_taken_by_controller)
     np_time_taken_per_step = np.array(time_taken_per_step)
-    print("Mean/Std of Time taken by Controller : ", np_time_taken_by_controller.mean(), np_time_taken_by_controller.std())
+    print("Mean/Std of Time taken by Controller : ", np_time_taken_by_controller.mean(),
+          np_time_taken_by_controller.std())
     print("Mean/Std of Time per Step : ", np_time_taken_per_step.mean(), np_time_taken_per_step.std())
     print("=" * 100)