diff --git a/torch_training/Nets/avoid_checkpoint15000.pth b/torch_training/Nets/avoid_checkpoint15000.pth
index 833de82752968a507b8e5397e76f55c26558d946..b82afe2e4c26bffa98cb8c35c769987033a6fa46 100644
Binary files a/torch_training/Nets/avoid_checkpoint15000.pth and b/torch_training/Nets/avoid_checkpoint15000.pth differ
diff --git a/torch_training/Nets/avoid_checkpoint30000.pth b/torch_training/Nets/avoid_checkpoint30000.pth
index a818af97cd8ed0eb5599c12a13748e1b2245a8cc..f1fd31ad74c61afbb3088fda64cb6e049f6ec480 100644
Binary files a/torch_training/Nets/avoid_checkpoint30000.pth and b/torch_training/Nets/avoid_checkpoint30000.pth differ
diff --git a/torch_training/multi_agent_training.py b/torch_training/multi_agent_training.py
index 4f823be331850668e18bfdf66d35a915e3f6ccdc..c9fe73308e20a113e5d393b4a16410f7611edb2b 100644
--- a/torch_training/multi_agent_training.py
+++ b/torch_training/multi_agent_training.py
@@ -6,14 +6,14 @@ from collections import deque
 import matplotlib.pyplot as plt
 import numpy as np
 import torch
-from importlib_resources import path
-
-import torch_training.Nets
 from flatland.envs.generators import complex_rail_generator
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.utils.rendertools import RenderTool
+from importlib_resources import path
+
+import torch_training.Nets
 from torch_training.dueling_double_dqn import Agent
 from utils.observation_utils import norm_obs_clip, split_tree
 
@@ -32,12 +32,15 @@ def main(argv):
     print("main1")
     random.seed(1)
     np.random.seed(1)
-
     """
+
+    file_name = "./railway/complex_scene.pkl"
     env = RailEnv(width=10,
-                  height=20, obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()))
-    env.load("./railway/complex_scene.pkl")
-    file_load = True
+                  height=20,
+                  rail_generator=rail_from_data(file_name),
+                  obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()))
+    x_dim = env.width
+    y_dim = env.height
     """
 
     x_dim = np.random.randint(8, 20)
@@ -55,7 +58,6 @@ def main(argv):
                   obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()),
                   number_of_agents=n_agents)
     env.reset(True, True)
-    file_load = False
 
     observation_helper = TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv())
     env_renderer = RenderTool(env, gl="PILSVG", )
@@ -68,7 +70,7 @@ def main(argv):
 
     # We set the number of episodes we would like to train on
     if 'n_trials' not in locals():
-        n_trials = 30000
+        n_trials = 60000
     max_steps = int(3 * (env.height + env.width))
     eps = 1.
     eps_end = 0.005
@@ -87,7 +89,7 @@ def main(argv):
     with path(torch_training.Nets, "avoid_checkpoint30000.pth") as file_in:
         agent.qnetwork_local.load_state_dict(torch.load(file_in))
 
-    demo = False
+    demo = True
     record_images = False
     frame_step = 0
 
@@ -113,10 +115,7 @@ def main(argv):
             agent_obs = [None] * env.get_num_agents()
             agent_next_obs = [None] * env.get_num_agents()
         # Reset environment
-        if file_load:
-            obs = env.reset(False, False)
-        else:
-            obs = env.reset(True, True)
+        obs = env.reset(True, True)
         if demo:
             env_renderer.set_new_rail()
         obs_original = obs.copy()