diff --git a/torch_training/Nets/avoid_checkpoint15000.pth b/torch_training/Nets/avoid_checkpoint15000.pth index 77c2ad680533b909a245d7d089402520eb55efcb..a8e352a1b89046928ac6f5474cb08b3af792961c 100644 Binary files a/torch_training/Nets/avoid_checkpoint15000.pth and b/torch_training/Nets/avoid_checkpoint15000.pth differ diff --git a/torch_training/railway/complex_scene.pkl b/torch_training/railway/complex_scene.pkl index 3225c39cd4675572ffca75776b37736ec4de7f31..b5c272477f53794d78a896c33d7c91e5b8cb0ea3 100644 Binary files a/torch_training/railway/complex_scene.pkl and b/torch_training/railway/complex_scene.pkl differ diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py index 4b82d637b400acbb5927ac0ee5a683f5f9c04fa4..c0433720823d2b9f7ecc6851bd0df9eaab5f8691 100644 --- a/torch_training/training_navigation.py +++ b/torch_training/training_navigation.py @@ -41,9 +41,8 @@ env = RailEnv(width=15, number_of_agents=1) - env = RailEnv(width=10, - height=20) + height=20, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())) env.load("./railway/complex_scene.pkl") """ @@ -51,7 +50,7 @@ env = RailEnv(width=12, height=12, rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=5, min_dist=10, max_dist=99999, seed=0), obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), - number_of_agents=3) + number_of_agents=5) env.reset(True, True) @@ -60,7 +59,7 @@ handle = env.get_agent_handles() state_size = 168 * 2 action_size = 5 -n_trials = 20000 +n_trials = 15000 max_steps = int(3 * (env.height + env.width)) eps = 1. eps_end = 0.005 @@ -76,9 +75,9 @@ action_prob = [0] * action_size agent_obs = [None] * env.get_num_agents() agent_next_obs = [None] * env.get_num_agents() agent = Agent(state_size, action_size, "FC", 0) -agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint20000.pth')) +agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint15000.pth')) -demo = False +demo = True def max_lt(seq, val): """