diff --git a/torch_training/multi_agent_training.py b/torch_training/multi_agent_training.py
index 0dacc9c25b9bc3e01a6b3e530fa6b081cb8d91df..2e20c63b293b355afec2be33cbd9acca209039d4 100644
--- a/torch_training/multi_agent_training.py
+++ b/torch_training/multi_agent_training.py
@@ -13,15 +13,13 @@
 import numpy as np
 import torch
 from torch_training.dueling_double_dqn import Agent
-from flatland.envs.observations import TreeObsForRailEnv
-from flatland.envs.predictions import ShortestPathPredictorForRailEnv
-
 from flatland.envs.rail_env import RailEnv
 from flatland.envs.rail_generators import sparse_rail_generator
 from flatland.envs.schedule_generators import sparse_schedule_generator
 from flatland.utils.rendertools import RenderTool
 from utils.observation_utils import normalize_observation
-
+from flatland.envs.observations import TreeObsForRailEnv
+from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 
 def main(argv):
     try:
@@ -37,26 +35,26 @@ def main(argv):
 
     np.random.seed(1)
 
     # Parameters for the Environment
-    x_dim = 40
-    y_dim = 40
-    n_agents = 4
+    x_dim = 35
+    y_dim = 35
+    n_agents = 5
 
     # Use a the malfunction generator to break agents from time to time
-    stochastic_data = {'prop_malfunction': 0.05,  # Percentage of defective agents
-                       'malfunction_rate': 50,  # Rate of malfunction occurence
+    stochastic_data = {'prop_malfunction': 0.0,  # Percentage of defective agents
+                       'malfunction_rate': 30,  # Rate of malfunction occurence
                        'min_duration': 3,  # Minimal duration of malfunction
                        'max_duration': 20  # Max duration of malfunction
                        }
 
     # Custom observation builder
-    TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
+    TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv(30))
 
     # Different agent types (trains) with different speeds.
-    speed_ration_map = {1.: 0.25,  # Fast passenger train
-                        1. / 2.: 0.25,  # Fast freight train
-                        1. / 3.: 0.25,  # Slow commuter train
-                        1. / 4.: 0.25}  # Slow freight train
+    speed_ration_map = {1.: 0.,  # Fast passenger train
+                        1. / 2.: 1.0,  # Fast freight train
+                        1. / 3.: 0.0,  # Slow commuter train
+                        1. / 4.: 0.0}  # Slow freight train
 
     env = RailEnv(width=x_dim,
                   height=y_dim,
@@ -120,8 +118,9 @@ def main(argv):
         env_renderer.reset()
         # Build agent specific observations
         for a in range(env.get_num_agents()):
-            agent_obs[a] = normalize_observation(obs[a], tree_depth, observation_radius=10)
-            agent_obs_buffer[a] = agent_obs[a].copy()
+            if obs[a]:
+                agent_obs[a] = normalize_observation(obs[a], tree_depth, observation_radius=10)
+                agent_obs_buffer[a] = agent_obs[a].copy()
 
         # Reset score and done
         score = 0
@@ -153,7 +152,8 @@ def main(argv):
 
                     agent_obs_buffer[a] = agent_obs[a].copy()
                     agent_action_buffer[a] = action_dict[a]
-                agent_obs[a] = normalize_observation(next_obs[a], tree_depth, observation_radius=10)
+                if next_obs[a]:
+                    agent_obs[a] = normalize_observation(next_obs[a], tree_depth, observation_radius=10)
 
                 score += all_rewards[a] / env.get_num_agents()
 
@@ -192,7 +192,7 @@ def main(argv):
                     100 * np.mean(done_window),
                     eps, action_prob / np.sum(action_prob)))
             torch.save(agent.qnetwork_local.state_dict(),
-                       './Nets/avoider_checkpoint' + str(trials) + '.pth')
+                       './Nets/navigator_checkpoint' + str(trials) + '.pth')
             action_prob = [1] * action_size
 
     # Plot overall training progress at the end
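
The `if obs[a]:` and `if next_obs[a]:` guards introduced above matter because Flatland's tree observation builder can hand back None for an agent (for example once that agent has already reached its target), and normalize_observation would fail on None. Below is a minimal sketch of the resulting update pattern; update_agent_obs and the bare normalize callback are hypothetical stand-ins for illustration, not names from the patched script:

    from typing import Callable, Dict, List, Optional

    def update_agent_obs(next_obs: Dict[int, Optional[object]],
                         agent_obs: Dict[int, List[float]],
                         normalize: Callable[[object], List[float]]) -> None:
        # Overwrite an agent's normalized observation only when the env
        # actually produced one; on None, keep the previous valid entry.
        for handle, obs in next_obs.items():
            if obs:
                agent_obs[handle] = normalize(obs)

In this script the callback would be something like lambda o: normalize_observation(o, tree_depth, observation_radius=10), so agents with no fresh observation simply keep their last normalized state for the next action selection.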