diff --git a/torch_training/Nets/avoid_checkpoint15000.pth b/torch_training/Nets/avoid_checkpoint15000.pth index 833de82752968a507b8e5397e76f55c26558d946..b82afe2e4c26bffa98cb8c35c769987033a6fa46 100644 Binary files a/torch_training/Nets/avoid_checkpoint15000.pth and b/torch_training/Nets/avoid_checkpoint15000.pth differ diff --git a/torch_training/Nets/avoid_checkpoint30000.pth b/torch_training/Nets/avoid_checkpoint30000.pth index a818af97cd8ed0eb5599c12a13748e1b2245a8cc..f1fd31ad74c61afbb3088fda64cb6e049f6ec480 100644 Binary files a/torch_training/Nets/avoid_checkpoint30000.pth and b/torch_training/Nets/avoid_checkpoint30000.pth differ diff --git a/torch_training/multi_agent_training.py b/torch_training/multi_agent_training.py index 4f823be331850668e18bfdf66d35a915e3f6ccdc..c9fe73308e20a113e5d393b4a16410f7611edb2b 100644 --- a/torch_training/multi_agent_training.py +++ b/torch_training/multi_agent_training.py @@ -6,14 +6,14 @@ from collections import deque import matplotlib.pyplot as plt import numpy as np import torch -from importlib_resources import path - -import torch_training.Nets from flatland.envs.generators import complex_rail_generator from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv from flatland.utils.rendertools import RenderTool +from importlib_resources import path + +import torch_training.Nets from torch_training.dueling_double_dqn import Agent from utils.observation_utils import norm_obs_clip, split_tree @@ -32,12 +32,15 @@ def main(argv): print("main1") random.seed(1) np.random.seed(1) - """ + + file_name = "./railway/complex_scene.pkl" env = RailEnv(width=10, - height=20, obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv())) - env.load("./railway/complex_scene.pkl") - file_load = True + height=20, + rail_generator=rail_from_data(file_name), + obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv())) + x_dim = env.width + y_dim = env.height """ x_dim = np.random.randint(8, 20) @@ -55,7 +58,6 @@ def main(argv): obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()), number_of_agents=n_agents) env.reset(True, True) - file_load = False observation_helper = TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()) env_renderer = RenderTool(env, gl="PILSVG", ) @@ -68,7 +70,7 @@ def main(argv): # We set the number of episodes we would like to train on if 'n_trials' not in locals(): - n_trials = 30000 + n_trials = 60000 max_steps = int(3 * (env.height + env.width)) eps = 1. eps_end = 0.005 @@ -87,7 +89,7 @@ def main(argv): with path(torch_training.Nets, "avoid_checkpoint30000.pth") as file_in: agent.qnetwork_local.load_state_dict(torch.load(file_in)) - demo = False + demo = True record_images = False frame_step = 0 @@ -113,10 +115,7 @@ def main(argv): agent_obs = [None] * env.get_num_agents() agent_next_obs = [None] * env.get_num_agents() # Reset environment - if file_load: - obs = env.reset(False, False) - else: - obs = env.reset(True, True) + obs = env.reset(True, True) if demo: env_renderer.set_new_rail() obs_original = obs.copy()