diff --git a/torch_training/multi_agent_inference.py b/torch_training/multi_agent_inference.py index f2458c20c2e47ea56e577f229f4221b1bfe4e195..18664437ebef0dbf4261cfdb3ba692dd5fab7505 100644 --- a/torch_training/multi_agent_inference.py +++ b/torch_training/multi_agent_inference.py @@ -30,7 +30,7 @@ y_dim = env.height # Parameters for the Environment x_dim = 25 y_dim = 25 -n_agents = 1 +n_agents = 10 # We are training an Agent using the Tree Observation with depth 2 observation_builder = TreeObsForRailEnv(max_depth=2) @@ -43,13 +43,13 @@ stochastic_data = {'prop_malfunction': 0.0, # Percentage of defective agents } # Custom observation builder -TreeObservation = TreeObsForRailEnv(max_depth=2) +TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()) # Different agent types (trains) with different speeds. -speed_ration_map = {1.: 1., # Fast passenger train - 1. / 2.: 0.0, # Fast freight train - 1. / 3.: 0.0, # Slow commuter train - 1. / 4.: 0.0} # Slow freight train +speed_ration_map = {1.: 0.25, # Fast passenger train + 1. / 2.: 0.25, # Fast freight train + 1. / 3.: 0.25, # Slow commuter train + 1. / 4.: 0.25} # Slow freight train env = RailEnv(width=x_dim, height=y_dim, @@ -93,7 +93,7 @@ action_prob = [0] * action_size agent_obs = [None] * env.get_num_agents() agent_next_obs = [None] * env.get_num_agents() agent = Agent(state_size, action_size) -with path(torch_training.Nets, "avoider_checkpoint1000.pth") as file_in: +with path(torch_training.Nets, "avoider_checkpoint100.pth") as file_in: agent.qnetwork_local.load_state_dict(torch.load(file_in)) record_images = False diff --git a/torch_training/multi_agent_training.py b/torch_training/multi_agent_training.py index ed20ea69c2042a7ab51722df6ba553aae741d2c5..0dacc9c25b9bc3e01a6b3e530fa6b081cb8d91df 100644 --- a/torch_training/multi_agent_training.py +++ b/torch_training/multi_agent_training.py @@ -14,6 +14,8 @@ import torch from torch_training.dueling_double_dqn import Agent from flatland.envs.observations import TreeObsForRailEnv +from flatland.envs.predictions import ShortestPathPredictorForRailEnv + from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import sparse_rail_generator from flatland.envs.schedule_generators import sparse_schedule_generator @@ -48,7 +50,7 @@ def main(argv): } # Custom observation builder - TreeObservation = TreeObsForRailEnv(max_depth=2) + TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()) # Different agent types (trains) with different speeds. speed_ration_map = {1.: 0.25, # Fast passenger train diff --git a/torch_training/render_agent_behavior.py b/torch_training/render_agent_behavior.py index d599bcf5ad5306ee70670bc672f20be45dbb40dd..62cfa494072d2ac3e8d164beb54f06ef202544d0 100644 --- a/torch_training/render_agent_behavior.py +++ b/torch_training/render_agent_behavior.py @@ -4,7 +4,6 @@ from collections import deque import numpy as np import torch from flatland.envs.observations import TreeObsForRailEnv -from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import sparse_rail_generator from flatland.envs.schedule_generators import sparse_schedule_generator @@ -67,7 +66,6 @@ env = RailEnv(width=x_dim, obs_builder_object=TreeObservation) env.reset(True, True) -observation_helper = TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()) env_renderer = RenderTool(env, gl="PILSVG", ) num_features_per_node = env.obs_builder.observation_dim diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py index 8e467968bcdad4a7b6ce39b2d70d096e32d332dd..b0942ee84a63f6b0d97212fb53c59377f8e7d285 100644 --- a/torch_training/training_navigation.py +++ b/torch_training/training_navigation.py @@ -13,8 +13,8 @@ import numpy as np import torch from torch_training.dueling_double_dqn import Agent -from flatland.envs.observations import TreeObsForRailEnv -from flatland.envs.rail_env import RailEnv +flatland.envs.rail_env +import RailEnv from flatland.envs.rail_generators import sparse_rail_generator from flatland.envs.schedule_generators import sparse_schedule_generator from flatland.utils.rendertools import RenderTool