diff --git a/torch_training/Nets/avoid_checkpoint15000.pth b/torch_training/Nets/avoid_checkpoint15000.pth index bb681d11151a13c54c78c22ac7dd421eea45ed32..e1daf228b7f1f6b108329715c3cdbd67805e28ae 100644 Binary files a/torch_training/Nets/avoid_checkpoint15000.pth and b/torch_training/Nets/avoid_checkpoint15000.pth differ diff --git a/torch_training/Nets/avoid_checkpoint30000.pth b/torch_training/Nets/avoid_checkpoint30000.pth index b6a0782cc1899a1e799011d19b3a9afb5906467c..0e2c1b28c1655bc16c9339066b8d105282f14418 100644 Binary files a/torch_training/Nets/avoid_checkpoint30000.pth and b/torch_training/Nets/avoid_checkpoint30000.pth differ diff --git a/torch_training/Nets/avoid_checkpoint60000.pth b/torch_training/Nets/avoid_checkpoint60000.pth index 1a35def33802ce6ac4b4f5d35c0c11c7095b2927..b6f15348130b09ae8bee0adad454031fc013fabf 100644 Binary files a/torch_training/Nets/avoid_checkpoint60000.pth and b/torch_training/Nets/avoid_checkpoint60000.pth differ diff --git a/torch_training/dueling_double_dqn.py b/torch_training/dueling_double_dqn.py index 6c54e4ef0aed7d833c25f3eb516d2abcb3589eee..dd67b4f0d73ffe1b3f4ad3e947debf18508e78b0 100644 --- a/torch_training/dueling_double_dqn.py +++ b/torch_training/dueling_double_dqn.py @@ -14,7 +14,7 @@ BUFFER_SIZE = int(1e5) # replay buffer size BATCH_SIZE = 512 # minibatch size GAMMA = 0.99 # discount factor 0.99 TAU = 1e-3 # for soft update of target parameters -LR = 0.5e-4 # learning rate 5 +LR = 0.5e-4 # learning rate 0.5e-4 works UPDATE_EVERY = 10 # how often to update the network double_dqn = True # If using double dqn algorithm input_channels = 5 # Number of Input channels diff --git a/torch_training/multi_agent_inference.py b/torch_training/multi_agent_inference.py index c14672819b1c0fed58705725da6dfb1feb1b9872..2b541219e688a8b55e20b412fd91f9e8cc22b9cb 100644 --- a/torch_training/multi_agent_inference.py +++ b/torch_training/multi_agent_inference.py @@ -9,15 +9,15 @@ from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv from flatland.utils.rendertools import RenderTool from importlib_resources import path - +import time import torch_training.Nets from torch_training.dueling_double_dqn import Agent from utils.observation_utils import normalize_observation random.seed(3) np.random.seed(2) -""" -file_name = "./railway/complex_scene.pkl" + +file_name = "./railway/simple_avoid.pkl" env = RailEnv(width=10, height=20, rail_generator=rail_from_file(file_name), @@ -27,21 +27,21 @@ y_dim = env.height """ -x_dim = 18 # np.random.randint(8, 20) -y_dim = 14 # np.random.randint(8, 20) -n_agents = 7 # np.random.randint(3, 8) +x_dim = 10 # np.random.randint(8, 20) +y_dim = 10 # np.random.randint(8, 20) +n_agents = 5 # np.random.randint(3, 8) n_goals = n_agents + np.random.randint(0, 3) min_dist = int(0.75 * min(x_dim, y_dim)) env = RailEnv(width=x_dim, height=y_dim, - rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist, + rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=2, min_dist=min_dist, max_dist=99999, seed=0), obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()), number_of_agents=n_agents) env.reset(True, True) - +""" tree_depth = 3 observation_helper = TreeObsForRailEnv(max_depth=tree_depth, predictor=ShortestPathPredictorForRailEnv()) env_renderer = RenderTool(env, gl="PILSVG", ) @@ -70,7 +70,7 @@ action_prob = [0] * action_size agent_obs = [None] * env.get_num_agents() agent_next_obs = [None] * env.get_num_agents() agent = Agent(state_size, action_size, "FC", 0) -with path(torch_training.Nets, "avoid_checkpoint46200.pth") as file_in: +with path(torch_training.Nets, "avoid_checkpoint49000.pth") as file_in: agent.qnetwork_local.load_state_dict(torch.load(file_in)) record_images = False @@ -93,14 +93,14 @@ for trials in range(1, n_trials + 1): if record_images: env_renderer.gl.save_image("./Images/Avoiding/flatland_frame_{:04d}.bmp".format(frame_step)) frame_step += 1 - + # time.sleep(5) # Action for a in range(env.get_num_agents()): action = agent.act(agent_obs[a], eps=0) action_dict.update({a: action}) # Environment step - next_obs, all_rewards, done, _ = env.step(action_dict) + for a in range(env.get_num_agents()): agent_obs[a] = normalize_observation(next_obs[a], observation_radius=10) diff --git a/torch_training/multi_agent_training.py b/torch_training/multi_agent_training.py index 7659b2da5b7ddaedbda39357971b6cb58efb2eb9..5e194f51b09bbe4e81e257caedf3d57d17e04bca 100644 --- a/torch_training/multi_agent_training.py +++ b/torch_training/multi_agent_training.py @@ -35,8 +35,8 @@ def main(argv): np.random.seed(1) # Initialize a random map with a random number of agents - x_dim = np.random.randint(8, 20) - y_dim = np.random.randint(8, 20) + x_dim = np.random.randint(8, 15) + y_dim = np.random.randint(8, 15) n_agents = np.random.randint(3, 8) n_goals = n_agents + np.random.randint(0, 3) min_dist = int(0.75 * min(x_dim, y_dim)) @@ -53,7 +53,7 @@ def main(argv): env = RailEnv(width=x_dim, height=y_dim, - rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist, + rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=2, min_dist=min_dist, max_dist=99999, seed=0), obs_builder_object=observation_helper, @@ -92,7 +92,7 @@ def main(argv): agent = Agent(state_size, action_size, "FC", 0) # Here you can pre-load an agent - if True: + if False: with path(torch_training.Nets, "avoid_checkpoint2400.pth") as file_in: agent.qnetwork_local.load_state_dict(torch.load(file_in)) @@ -103,14 +103,14 @@ def main(argv): and the size of the levels every 50 episodes. """ if episodes % 50 == 0: - x_dim = np.random.randint(8, 20) - y_dim = np.random.randint(8, 20) + x_dim = np.random.randint(8, 15) + y_dim = np.random.randint(8, 15) n_agents = np.random.randint(3, 8) n_goals = n_agents + np.random.randint(0, 3) min_dist = int(0.75 * min(x_dim, y_dim)) env = RailEnv(width=x_dim, height=y_dim, - rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist, + rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=2, min_dist=min_dist, max_dist=99999, seed=0), obs_builder_object=observation_helper, diff --git a/torch_training/railway/very_complex.pkl b/torch_training/railway/very_complex.pkl new file mode 100644 index 0000000000000000000000000000000000000000..0abbbe9674f08de518f99df46cc17b70b23b464b Binary files /dev/null and b/torch_training/railway/very_complex.pkl differ