diff --git a/torch_training/Nets/avoid_checkpoint15000.pth b/torch_training/Nets/avoid_checkpoint15000.pth
index 14882a37a86085b137f4422b6bba75f387a2d3b5..d3081e88f97ac75641c0d94c7cf794f34d436581 100644
Binary files a/torch_training/Nets/avoid_checkpoint15000.pth and b/torch_training/Nets/avoid_checkpoint15000.pth differ
diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py
index 8c30d72010d95389dd22af06971170aa9c4b4480..f16a2c43fe9bee0a4f75d5a7ae4ad0931e1bbf7e 100644
--- a/torch_training/training_navigation.py
+++ b/torch_training/training_navigation.py
@@ -5,6 +5,8 @@ import numpy as np
 import torch
 from dueling_double_dqn import Agent
 from flatland.envs.generators import complex_rail_generator
+from flatland.envs.observations import TreeObsForRailEnv
+from flatland.envs.predictions import DummyPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.utils.rendertools import RenderTool
 
@@ -43,10 +45,12 @@ env = RailEnv(width=10,
               height=20)
 env.load("./railway/complex_scene.pkl")
 """
-env = RailEnv(width=8,
-              height=8,
-              rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=5, min_dist=5, max_dist=99999, seed=0),
-              number_of_agents=1)
+
+env = RailEnv(width=20,
+              height=20,
+              rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0),
+              obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv()),
+              number_of_agents=10)
 env.reset(True, True)
 env_renderer = RenderTool(env, gl="PILSVG")
 
@@ -69,9 +73,9 @@ action_prob = [0] * action_size
 agent_obs = [None] * env.get_num_agents()
 agent_next_obs = [None] * env.get_num_agents()
 agent = Agent(state_size, action_size, "FC", 0)
-#agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint1500.pth'))
+agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint15000.pth'))
 
-demo = False
+demo = True
 
 def max_lt(seq, val):
     """
@@ -143,14 +147,14 @@ for trials in range(1, n_trials + 1):
     score = 0
     env_done = 0
     # Run episode
-    for step in range(100):
+    for step in range(env.height * env.width):
         if demo:
             env_renderer.renderEnv(show=True, show_observations=False)
         # print(step)
         # Action
         for a in range(env.get_num_agents()):
             if demo:
-                eps = 1
+                eps = 0
             # action = agent.act(np.array(obs[a]), eps=eps)
             action = agent.act(agent_obs[a], eps=eps)
             action_prob[action] += 1
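
Taken together, this change switches training_navigation.py from a small single-agent training setup to a demo of the trained policy on a larger multi-agent scene: a 20x20 grid with 10 agents, depth-2 tree observations backed by DummyPredictorForRailEnv, the avoid_checkpoint15000.pth weights restored, a greedy policy (eps = 0 instead of 1), and an episode budget that scales with map area instead of a fixed 100 steps. The sketch below collects the changed pieces into one place; it is a minimal sketch, not the script itself. The helper names (make_demo_env, load_trained_agent, demo_episode) and the preprocess parameter are illustrative only: preprocess stands in for the tree-observation normalization the full script performs before calling agent.act(), state_size and action_size are computed elsewhere in training_navigation.py, and the flatland calls assume the library's API as of this change.

import torch
from dueling_double_dqn import Agent
from flatland.envs.generators import complex_rail_generator
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import DummyPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.utils.rendertools import RenderTool


def make_demo_env():
    # The environment as configured by this change: 20x20 grid, 10 agents,
    # depth-2 tree observations augmented with a dummy predictor that feeds
    # short-horizon position forecasts into the observation builder.
    return RailEnv(width=20,
                   height=20,
                   rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1,
                                                         min_dist=8, max_dist=99999, seed=0),
                   obs_builder_object=TreeObsForRailEnv(max_depth=2,
                                                        predictor=DummyPredictorForRailEnv()),
                   number_of_agents=10)


def load_trained_agent(state_size, action_size,
                       checkpoint='./Nets/avoid_checkpoint15000.pth'):
    # Restore the trained network rather than training from scratch;
    # state_size and action_size come from the surrounding script.
    agent = Agent(state_size, action_size, "FC", 0)
    agent.qnetwork_local.load_state_dict(torch.load(checkpoint))
    return agent


def demo_episode(env, agent, preprocess):
    # One greedy demo episode. `preprocess` is a hypothetical stand-in for the
    # tree-observation normalization the full script applies before agent.act().
    env_renderer = RenderTool(env, gl="PILSVG")
    obs = env.reset(True, True)
    # Episode budget scales with map area instead of a fixed 100 steps.
    for step in range(env.height * env.width):
        env_renderer.renderEnv(show=True, show_observations=False)
        # eps=0 makes the policy greedy; the old demo value of 1 acted randomly.
        action_dict = {a: agent.act(preprocess(obs[a]), eps=0)
                       for a in range(env.get_num_agents())}
        obs, all_rewards, done, _ = env.step(action_dict)
        if done['__all__']:
            break

Usage mirrors the changed script: build the environment with make_demo_env(), restore the agent with load_trained_agent(state_size, action_size), and run demo_episode(env, agent, preprocess) with the script's own observation preprocessing passed in.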