diff --git a/torch_training/multi_agent_inference.py b/torch_training/multi_agent_inference.py
index b376623e7ecde3bbfcf01a5eeb11e8a76c132cc7..f2458c20c2e47ea56e577f229f4221b1bfe4e195 100644
--- a/torch_training/multi_agent_inference.py
+++ b/torch_training/multi_agent_inference.py
@@ -15,63 +15,75 @@ import torch_training.Nets
 from torch_training.dueling_double_dqn import Agent
 from utils.observation_utils import normalize_observation
 
-random.seed(3)
-np.random.seed(2)
+random.seed(1)
+np.random.seed(1)
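+# Alternative setup, kept commented out for reference: load a fixed scene from
+# file instead of generating one with sparse_rail_generator below.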
+"""
+file_name = "./railway/complex_scene.pkl"
+env = RailEnv(width=10,
+              height=20,
+              rail_generator=rail_from_file(file_name),
+              obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()))
+x_dim = env.width
+y_dim = env.height
+"""
+
 # Parameters for the Environment
-x_dim = 20
-y_dim = 20
-n_agents = 5
-tree_depth = 2
+x_dim = 25
+y_dim = 25
+n_agents = 1
 
 # Use the malfunction generator to break agents from time to time
-stochastic_data = {'prop_malfunction': 0.1,  # Percentage of defective agents
+stochastic_data = {'prop_malfunction': 0.0,  # Percentage of defective agents
                    'malfunction_rate': 30,  # Rate of malfunction occurrence
                    'min_duration': 3,  # Minimal duration of malfunction
                    'max_duration': 20  # Max duration of malfunction
                    }
 
 # Custom observation builder
-predictor = ShortestPathPredictorForRailEnv()
-observation_helper = TreeObsForRailEnv(max_depth=tree_depth, predictor=predictor)
+TreeObservation = TreeObsForRailEnv(max_depth=2)
 
 # Different agent types (trains) with different speeds.
-speed_ration_map = {1.: 0.25,  # Fast passenger train
-                    1. / 2.: 0.25,  # Fast freight train
-                    1. / 3.: 0.25,  # Slow commuter train
-                    1. / 4.: 0.25}  # Slow freight train
+speed_ration_map = {1.: 1.,  # Fast passenger train
+                    1. / 2.: 0.0,  # Fast freight train
+                    1. / 3.: 0.0,  # Slow commuter train
+                    1. / 4.: 0.0}  # Slow freight train
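+# Each value is the fraction of agents assigned that speed; with this map every
+# agent is a fast passenger train moving at full speed.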
 
 env = RailEnv(width=x_dim,
               height=y_dim,
-              rail_generator=sparse_rail_generator(num_cities=5,
+              rail_generator=sparse_rail_generator(max_num_cities=3,
                                                    # Number of cities in map (where train stations are)
-                                                   num_intersections=4,
-                                                   # Number of intersections (no start / target)
-                                                   num_trainstations=10,  # Number of possible start/targets on map
-                                                   min_node_dist=3,  # Minimal distance of nodes
-                                                   node_radius=2,  # Proximity of stations to city center
-                                                   num_neighb=3,
-                                                   # Number of connections to other cities/intersections
-                                                   seed=15,  # Random seed
-                                                   grid_mode=True,
-                                                   enhance_intersection=False
-                                                   ),
+                                                   seed=1,  # Random seed
+                                                   grid_mode=False,
+                                                   max_rails_between_cities=2,
+                                                   max_rails_in_city=2),
               schedule_generator=sparse_schedule_generator(speed_ration_map),
               number_of_agents=n_agents,
               stochastic_data=stochastic_data,  # Malfunction data generator
-              obs_builder_object=observation_helper)
+              obs_builder_object=TreeObservation)
 env.reset(True, True)
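+# The two positional flags ask reset() to regenerate both the rail network and
+# the agent schedule, so each call produces a freshly generated scene.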
 
 env_renderer = RenderTool(env, gl="PILSVG", )
-handle = env.get_agent_handles()
 num_features_per_node = env.obs_builder.observation_dim
+
+tree_depth = 2
 nr_nodes = 0
 for i in range(tree_depth + 1):
     nr_nodes += np.power(4, i)
 state_size = num_features_per_node * nr_nodes
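+# With tree_depth = 2 the 4-ary observation tree has 1 + 4 + 16 = 21 nodes, so
+# the network input is 21 * observation_dim features.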
 action_size = 5
 
-n_trials = 10
-observation_radius = 10
+# Number of episodes to run the trained agent for (keep any externally preset value)
+if 'n_trials' not in locals():
+    n_trials = 60000
 max_steps = int(3 * (env.height + env.width))
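+# For the 25x25 grid above this caps an episode at 3 * (25 + 25) = 150 steps.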
 eps = 1.
 eps_end = 0.005
@@ -80,14 +92,15 @@ action_dict = dict()
 final_action_dict = dict()
 scores_window = deque(maxlen=100)
 done_window = deque(maxlen=100)
-time_obs = deque(maxlen=2)
 scores = []
 dones_list = []
 action_prob = [0] * action_size
 agent_obs = [None] * env.get_num_agents()
 agent_next_obs = [None] * env.get_num_agents()
 agent = Agent(state_size, action_size)
-with path(torch_training.Nets, "avoid_checkpoint500.pth") as file_in:
+with path(torch_training.Nets, "avoider_checkpoint1000.pth") as file_in:
     agent.qnetwork_local.load_state_dict(torch.load(file_in))
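+# Note: if the checkpoint was saved from a GPU run, loading it on a CPU-only
+# machine may require torch.load(file_in, map_location='cpu').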
 
 record_images = False
@@ -97,29 +110,33 @@ for trials in range(1, n_trials + 1):
 
     # Reset environment
     obs, info = env.reset(True, True)
-
     env_renderer.reset()
-
+    # Build agent specific observations
     for a in range(env.get_num_agents()):
-        agent_obs[a] = normalize_observation(obs[a], observation_radius=10)
+        agent_obs[a] = normalize_observation(obs[a], tree_depth, observation_radius=10)
+    # Reset score and done
+    score = 0
+    env_done = 0
 
     # Run episode
     for step in range(max_steps):
-        env_renderer.render_env(show=True, show_observations=False, show_predictions=True)
 
-        if record_images:
-            env_renderer.gl.save_image("./Images/Avoiding/flatland_frame_{:04d}.bmp".format(frame_step))
-            frame_step += 1
-        # time.sleep(1.5)
         # Action
         for a in range(env.get_num_agents()):
-            action = agent.act(agent_obs[a], eps=0)
+            if info['action_required'][a]:
+                action = agent.act(agent_obs[a], eps=0.)
+            else:
+                action = 0  # RailEnvActions.DO_NOTHING
+
+            action_prob[action] += 1
             action_dict.update({a: action})
         # Environment step
-        next_obs, all_rewards, done, _ = env.step(action_dict)
-
+        obs, all_rewards, done, info = env.step(action_dict)  # keep info fresh for 'action_required'
+        env_renderer.render_env(show=True, show_predictions=True, show_observations=False)
+        # Build agent specific observations and normalize
         for a in range(env.get_num_agents()):
-            agent_obs[a] = normalize_observation(next_obs[a], observation_radius=10)
+            agent_obs[a] = normalize_observation(obs[a], tree_depth, observation_radius=10)
 
         if done['__all__']:
             break
+