diff --git a/torch_training/Nets/avoid_checkpoint15000.pth b/torch_training/Nets/avoid_checkpoint15000.pth
index b82afe2e4c26bffa98cb8c35c769987033a6fa46..9cc03d3c2ac88d946ddd33fa1009cb3ab56a7b59 100644
Binary files a/torch_training/Nets/avoid_checkpoint15000.pth and b/torch_training/Nets/avoid_checkpoint15000.pth differ
diff --git a/torch_training/Nets/avoid_checkpoint30000.pth b/torch_training/Nets/avoid_checkpoint30000.pth
index f1fd31ad74c61afbb3088fda64cb6e049f6ec480..2625b7648ec3ff8e3efba2ed33eebe516654c252 100644
Binary files a/torch_training/Nets/avoid_checkpoint30000.pth and b/torch_training/Nets/avoid_checkpoint30000.pth differ
diff --git a/torch_training/render_agent_behavior.py b/torch_training/render_agent_behavior.py
new file mode 100644
index 0000000000000000000000000000000000000000..f98318e704d776d47937b50e7e452eab51355ee9
--- /dev/null
+++ b/torch_training/render_agent_behavior.py
@@ -0,0 +1,131 @@
+import random
+from collections import deque
+
+import numpy as np
+import torch
+from flatland.envs.generators import rail_from_file
+from flatland.envs.observations import TreeObsForRailEnv
+from flatland.envs.predictions import ShortestPathPredictorForRailEnv
+from flatland.envs.rail_env import RailEnv
+from flatland.utils.rendertools import RenderTool
+from importlib_resources import path
+
+import torch_training.Nets
+from torch_training.dueling_double_dqn import Agent
+from utils.observation_utils import norm_obs_clip, split_tree
+
+random.seed(1)
+np.random.seed(1)
+
+file_name = "./railway/complex_scene.pkl"
+env = RailEnv(width=10,
+              height=20,
+              rail_generator=rail_from_file(file_name),
+              obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()))
+x_dim = env.width
+y_dim = env.height
+"""
+
+x_dim = np.random.randint(8, 20)
+y_dim = np.random.randint(8, 20)
+n_agents = np.random.randint(3, 8)
+n_goals = n_agents + np.random.randint(0, 3)
+min_dist = int(0.75 * min(x_dim, y_dim))
+
+env = RailEnv(width=x_dim,
+              height=y_dim,
+              rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist,
+                                                    max_dist=99999,
+                                                    seed=0),
+              obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()),
+              number_of_agents=n_agents)
+env.reset(True, True)
+"""
+observation_helper = TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv())
+env_renderer = RenderTool(env, gl="PILSVG", )
+handle = env.get_agent_handles()
+features_per_node = 9
+state_size = features_per_node * 85 * 2
+action_size = 5
+
+# We set the number of episodes we would like to train on
+if 'n_trials' not in locals():
+    n_trials = 60000
+max_steps = int(3 * (env.height + env.width))
+eps = 1.
+eps_end = 0.005
+eps_decay = 0.9995
+action_dict = dict()
+final_action_dict = dict()
+scores_window = deque(maxlen=100)
+done_window = deque(maxlen=100)
+time_obs = deque(maxlen=2)
+scores = []
+dones_list = []
+action_prob = [0] * action_size
+agent_obs = [None] * env.get_num_agents()
+agent_next_obs = [None] * env.get_num_agents()
+agent = Agent(state_size, action_size, "FC", 0)
+with path(torch_training.Nets, "avoid_checkpoint60000.pth") as file_in:
+    agent.qnetwork_local.load_state_dict(torch.load(file_in))
+
+record_images = False
+frame_step = 0
+
+for trials in range(1, n_trials + 1):
+
+    # Reset environment
+    obs = env.reset(True, True)
+
+    env_renderer.set_new_rail()
+    obs_original = obs.copy()
+    final_obs = obs.copy()
+    final_obs_next = obs.copy()
+    for a in range(env.get_num_agents()):
+        data, distance, agent_data = split_tree(tree=np.array(obs[a]),
+                                                current_depth=0)
+        data = norm_obs_clip(data)
+        distance = norm_obs_clip(distance)
+        agent_data = np.clip(agent_data, -1, 1)
+        obs[a] = np.concatenate((np.concatenate((data, distance)), agent_data))
+        agent_data = env.agents[a]
+        speed = 1  # np.random.randint(1,5)
+        agent_data.speed_data['speed'] = 1. / speed
+
+    for i in range(2):
+        time_obs.append(obs)
+    # env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5)
+    for a in range(env.get_num_agents()):
+        agent_obs[a] = np.concatenate((time_obs[0][a], time_obs[1][a]))
+
+    # Run episode
+    for step in range(max_steps):
+        env_renderer.render_env(show=True, show_observations=False, show_predictions=True)
+
+        if record_images:
+            env_renderer.gl.saveImage("./Images/flatland_frame_{:04d}.bmp".format(frame_step))
+            frame_step += 1
+
+        # Action
+        for a in range(env.get_num_agents()):
+            # action = agent.act(np.array(obs[a]), eps=eps)
+            action = agent.act(agent_obs[a], eps=0)
+            action_dict.update({a: action})
+        # Environment step
+
+        next_obs, all_rewards, done, _ = env.step(action_dict)
+        # print(all_rewards,action)
+        obs_original = next_obs.copy()
+        for a in range(env.get_num_agents()):
+            data, distance, agent_data = split_tree(tree=np.array(next_obs[a]),
+                                                    current_depth=0)
+            data = norm_obs_clip(data)
+            distance = norm_obs_clip(distance)
+            agent_data = np.clip(agent_data, -1, 1)
+            next_obs[a] = np.concatenate((np.concatenate((data, distance)), agent_data))
+        time_obs.append(next_obs)
+        for a in range(env.get_num_agents()):
+            agent_next_obs[a] = np.concatenate((time_obs[0][a], time_obs[1][a]))
+        agent_obs = agent_next_obs.copy()
+        if done['__all__']:
+            break
diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py
index f575510f6e0e32eeaee2d29b7f5da0ced852fb81..6ea6b5672a6939c290d0395d7d8795d47b14b508 100644
--- a/torch_training/training_navigation.py
+++ b/torch_training/training_navigation.py
@@ -7,11 +7,11 @@ import matplotlib.pyplot as plt
 import numpy as np
 import torch
 from dueling_double_dqn import Agent
-
 from flatland.envs.generators import complex_rail_generator
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.utils.rendertools import RenderTool
+
 from utils.observation_utils import norm_obs_clip, split_tree
 
 
@@ -52,7 +52,7 @@ def main(argv):
     env_renderer = RenderTool(env, gl="PILSVG", )
 
     # Given the depth of the tree observation and the number of features per node we get the following state_size
-    features_per_node = 9
+    features_per_node = env.obs_builder.observation_dim
     tree_depth = 2
     nr_nodes = 0
     for i in range(tree_depth + 1):
diff --git a/utils/observation_utils.py b/utils/observation_utils.py
index 4c4efa2405a01499d067e68cd1e305f40a6e11a7..121c6eb593043d81dfed61ed5b37f65eaef9af4d 100644
--- a/utils/observation_utils.py
+++ b/utils/observation_utils.py
@@ -71,7 +71,6 @@ def split_tree(tree, current_depth=0):
     :return: Returns the three different groups of distance and binary values.
     """
     num_features_per_node = TreeObsForRailEnv.observation_dim
-
     if len(tree) < num_features_per_node:
         return [], [], []