diff --git a/torch_training/multi_agent_training.py b/torch_training/multi_agent_training.py
index eb11fb1b183d7041e35186f40693b378c80eec68..4524d784b0bd8b9e6a67970be36192b94984a090 100644
--- a/torch_training/multi_agent_training.py
+++ b/torch_training/multi_agent_training.py
@@ -1,16 +1,17 @@
 import random
 from collections import deque
 
+from importlib_resources import path
 import matplotlib.pyplot as plt
 import numpy as np
 import torch
 from dueling_double_dqn import Agent
+
 from flatland.envs.generators import complex_rail_generator
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.utils.rendertools import RenderTool
-
 from utils.observation_utils import norm_obs_clip, split_tree
 
 random.seed(1)
@@ -62,7 +63,9 @@ action_prob = [0] * action_size
 agent_obs = [None] * env.get_num_agents()
 agent_next_obs = [None] * env.get_num_agents()
 agent = Agent(state_size, action_size, "FC", 0)
-agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint30000.pth'))
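+# Resolve the checkpoint bundled in the torch_training.Nets package so loading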
+with path("torch_training/Nets", "avoid_checkpoint30000.pth") as file_in:
+    agent.qnetwork_local.load_state_dict(torch.load(file_in))
 
 demo = True
 record_images = False
diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py
index 3007ce59bed2e75adc0672511e397c39eea1e2cb..987c7d645ae9d574d1427fc0c686e740abdb449c 100644
--- a/torch_training/training_navigation.py
+++ b/torch_training/training_navigation.py
@@ -1,3 +1,5 @@
 import random
 from collections import deque
 
+from importlib_resources import path
+
@@ -87,7 +89,9 @@ action_prob = [0] * action_size
 agent_obs = [None] * env.get_num_agents()
 agent_next_obs = [None] * env.get_num_agents()
 agent = Agent(state_size, action_size, "FC", 0)
-agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint30000.pth'))
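+# Resolve the checkpoint bundled in the torch_training.Nets package so loading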
+with path("torch_training/Nets", "avoid_checkpoint30000.pth") as file_in:
+    agent.qnetwork_local.load_state_dict(torch.load(file_in))
+
 
 demo = True
 record_images = False