diff --git a/torch_training/Nets/avoid_checkpoint15000.pth b/torch_training/Nets/avoid_checkpoint15000.pth index d3081e88f97ac75641c0d94c7cf794f34d436581..77c2ad680533b909a245d7d089402520eb55efcb 100644 Binary files a/torch_training/Nets/avoid_checkpoint15000.pth and b/torch_training/Nets/avoid_checkpoint15000.pth differ diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py index 96593864a31cf45864b3cf8f52c29a0d5a241ef0..4b82d637b400acbb5927ac0ee5a683f5f9c04fa4 100644 --- a/torch_training/training_navigation.py +++ b/torch_training/training_navigation.py @@ -1,12 +1,13 @@ import random from collections import deque +import matplotlib.pyplot as plt import numpy as np import torch from dueling_double_dqn import Agent from flatland.envs.generators import complex_rail_generator from flatland.envs.observations import TreeObsForRailEnv -from flatland.envs.predictions import DummyPredictorForRailEnv +from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv from flatland.utils.rendertools import RenderTool @@ -46,10 +47,10 @@ env = RailEnv(width=10, env.load("./railway/complex_scene.pkl") """ -env = RailEnv(width=8, - height=8, - rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=1, min_dist=4, max_dist=99999, seed=0), - obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv()), +env = RailEnv(width=12, + height=12, + rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=5, min_dist=10, max_dist=99999, seed=0), + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), number_of_agents=3) env.reset(True, True) @@ -59,8 +60,8 @@ handle = env.get_agent_handles() state_size = 168 * 2 action_size = 5 -n_trials = 15000 -max_steps = int(1.5 * (env.height + env.width)) +n_trials = 20000 +max_steps = int(3 * (env.height + env.width)) eps = 1. eps_end = 0.005 eps_decay = 0.9995 @@ -75,7 +76,7 @@ action_prob = [0] * action_size agent_obs = [None] * env.get_num_agents() agent_next_obs = [None] * env.get_num_agents() agent = Agent(state_size, action_size, "FC", 0) -# agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint15000.pth')) +agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint20000.pth')) demo = False @@ -220,3 +221,5 @@ for trials in range(1, n_trials + 1): torch.save(agent.qnetwork_local.state_dict(), './Nets/avoid_checkpoint' + str(trials) + '.pth') action_prob = [1] * action_size +plt.plot(scores) +plt.show()