diff --git a/torch_training/Nets/avoid_checkpoint15000.pth b/torch_training/Nets/avoid_checkpoint15000.pth index 14882a37a86085b137f4422b6bba75f387a2d3b5..d3081e88f97ac75641c0d94c7cf794f34d436581 100644 Binary files a/torch_training/Nets/avoid_checkpoint15000.pth and b/torch_training/Nets/avoid_checkpoint15000.pth differ diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py index 1fbe1495c0c5c11684f8a192ed3ced73d6cf2520..0c4ed8547591fe6cdd106438a4be28037baba216 100644 --- a/torch_training/training_navigation.py +++ b/torch_training/training_navigation.py @@ -6,6 +6,8 @@ import numpy as np import torch from flatland.envs.generators import complex_rail_generator +from flatland.envs.observations import TreeObsForRailEnv +from flatland.envs.predictions import DummyPredictorForRailEnv from flatland.envs.rail_env import RailEnv from flatland.utils.rendertools import RenderTool from torch_training.dueling_double_dqn import Agent @@ -47,10 +49,12 @@ env = RailEnv(width=10, height=20) env.load_resource('torch_training.railway', "complex_scene.pkl") """ -env = RailEnv(width=8, - height=8, - rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=5, min_dist=5, max_dist=99999, seed=0), - number_of_agents=1) + +env = RailEnv(width=20, + height=20, + rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0), + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv()), + number_of_agents=10) env.reset(True, True) env_renderer = RenderTool(env, gl="PILSVG") @@ -73,9 +77,9 @@ action_prob = [0] * action_size agent_obs = [None] * env.get_num_agents() agent_next_obs = [None] * env.get_num_agents() agent = Agent(state_size, action_size, "FC", 0) -# agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint1500.pth')) +agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint15000.pth')) -demo = False +demo = True def max_lt(seq, val): @@ -149,7 +153,7 @@ for trials in range(1, n_trials + 1): score = 0 env_done = 0 # Run episode - for step in range(100): + for step in range(env.height * env.width): if demo: env_renderer.renderEnv(show=True, show_observations=False) # print(step) @@ -157,6 +161,7 @@ for trials in range(1, n_trials + 1): for a in range(env.get_num_agents()): if demo: eps = 1 + # action = agent.act(np.array(obs[a]), eps=eps) action = agent.act(agent_obs[a], eps=eps) action_prob[action] += 1 action_dict.update({a: action})