From 44844ffa3d34c93520cae3632edf2f6126254a1d Mon Sep 17 00:00:00 2001 From: MLErik <baerenjesus@gmail.com> Date: Mon, 7 Oct 2019 16:27:24 -0400 Subject: [PATCH] added shortest path predictor to multi agent training and inference --- torch_training/multi_agent_inference.py | 14 +++++++------- torch_training/multi_agent_training.py | 4 +++- torch_training/render_agent_behavior.py | 2 -- torch_training/training_navigation.py | 4 ++-- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/torch_training/multi_agent_inference.py b/torch_training/multi_agent_inference.py index f2458c2..1866443 100644 --- a/torch_training/multi_agent_inference.py +++ b/torch_training/multi_agent_inference.py @@ -30,7 +30,7 @@ y_dim = env.height # Parameters for the Environment x_dim = 25 y_dim = 25 -n_agents = 1 +n_agents = 10 # We are training an Agent using the Tree Observation with depth 2 observation_builder = TreeObsForRailEnv(max_depth=2) @@ -43,13 +43,13 @@ stochastic_data = {'prop_malfunction': 0.0, # Percentage of defective agents } # Custom observation builder -TreeObservation = TreeObsForRailEnv(max_depth=2) +TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()) # Different agent types (trains) with different speeds. -speed_ration_map = {1.: 1., # Fast passenger train - 1. / 2.: 0.0, # Fast freight train - 1. / 3.: 0.0, # Slow commuter train - 1. / 4.: 0.0} # Slow freight train +speed_ration_map = {1.: 0.25, # Fast passenger train + 1. / 2.: 0.25, # Fast freight train + 1. / 3.: 0.25, # Slow commuter train + 1. / 4.: 0.25} # Slow freight train env = RailEnv(width=x_dim, height=y_dim, @@ -93,7 +93,7 @@ action_prob = [0] * action_size agent_obs = [None] * env.get_num_agents() agent_next_obs = [None] * env.get_num_agents() agent = Agent(state_size, action_size) -with path(torch_training.Nets, "avoider_checkpoint1000.pth") as file_in: +with path(torch_training.Nets, "avoider_checkpoint100.pth") as file_in: agent.qnetwork_local.load_state_dict(torch.load(file_in)) record_images = False diff --git a/torch_training/multi_agent_training.py b/torch_training/multi_agent_training.py index ed20ea6..0dacc9c 100644 --- a/torch_training/multi_agent_training.py +++ b/torch_training/multi_agent_training.py @@ -14,6 +14,8 @@ import torch from torch_training.dueling_double_dqn import Agent from flatland.envs.observations import TreeObsForRailEnv +from flatland.envs.predictions import ShortestPathPredictorForRailEnv + from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import sparse_rail_generator from flatland.envs.schedule_generators import sparse_schedule_generator @@ -48,7 +50,7 @@ def main(argv): } # Custom observation builder - TreeObservation = TreeObsForRailEnv(max_depth=2) + TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()) # Different agent types (trains) with different speeds. speed_ration_map = {1.: 0.25, # Fast passenger train diff --git a/torch_training/render_agent_behavior.py b/torch_training/render_agent_behavior.py index d599bcf..62cfa49 100644 --- a/torch_training/render_agent_behavior.py +++ b/torch_training/render_agent_behavior.py @@ -4,7 +4,6 @@ from collections import deque import numpy as np import torch from flatland.envs.observations import TreeObsForRailEnv -from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import sparse_rail_generator from flatland.envs.schedule_generators import sparse_schedule_generator @@ -67,7 +66,6 @@ env = RailEnv(width=x_dim, obs_builder_object=TreeObservation) env.reset(True, True) -observation_helper = TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()) env_renderer = RenderTool(env, gl="PILSVG", ) num_features_per_node = env.obs_builder.observation_dim diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py index 8e46796..b0942ee 100644 --- a/torch_training/training_navigation.py +++ b/torch_training/training_navigation.py @@ -13,8 +13,8 @@ import numpy as np import torch from torch_training.dueling_double_dqn import Agent -from flatland.envs.observations import TreeObsForRailEnv -from flatland.envs.rail_env import RailEnv +flatland.envs.rail_env +import RailEnv from flatland.envs.rail_generators import sparse_rail_generator from flatland.envs.schedule_generators import sparse_schedule_generator from flatland.utils.rendertools import RenderTool -- GitLab