Commit 68b09076 authored by Erik Nygren

updated multi agent training for testing

parent 653126fe
@@ -13,15 +13,13 @@ import numpy as np
 import torch
 from torch_training.dueling_double_dqn import Agent
+from flatland.envs.observations import TreeObsForRailEnv
+from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.envs.rail_generators import sparse_rail_generator
 from flatland.envs.schedule_generators import sparse_schedule_generator
 from flatland.utils.rendertools import RenderTool
 from utils.observation_utils import normalize_observation
-from flatland.envs.observations import TreeObsForRailEnv
-from flatland.envs.predictions import ShortestPathPredictorForRailEnv


 def main(argv):
     try:
@@ -37,26 +35,26 @@ def main(argv):
     np.random.seed(1)

     # Parameters for the Environment
-    x_dim = 40
-    y_dim = 40
-    n_agents = 4
+    x_dim = 35
+    y_dim = 35
+    n_agents = 5

     # Use the malfunction generator to break agents from time to time
-    stochastic_data = {'prop_malfunction': 0.05,  # Percentage of defective agents
-                       'malfunction_rate': 50,  # Rate of malfunction occurrence
+    stochastic_data = {'prop_malfunction': 0.0,  # Percentage of defective agents
+                       'malfunction_rate': 30,  # Rate of malfunction occurrence
                        'min_duration': 3,  # Minimal duration of malfunction
                        'max_duration': 20  # Max duration of malfunction
                        }

     # Custom observation builder
-    TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
+    TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv(30))

     # Different agent types (trains) with different speeds.
-    speed_ration_map = {1.: 0.25,  # Fast passenger train
-                        1. / 2.: 0.25,  # Fast freight train
-                        1. / 3.: 0.25,  # Slow commuter train
-                        1. / 4.: 0.25}  # Slow freight train
+    speed_ration_map = {1.: 0.,  # Fast passenger train
+                        1. / 2.: 1.0,  # Fast freight train
+                        1. / 3.: 0.0,  # Slow commuter train
+                        1. / 4.: 0.0}  # Slow freight train

     env = RailEnv(width=x_dim,
                   height=y_dim,
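Two effects of this hunk are worth noting: with prop_malfunction set to 0.0 the malfunction generator is effectively disabled, and with all probability mass in speed_ration_map moved to 1./2. every generated train is a fast freight train, which keeps test runs uniform in agent speed. The argument in ShortestPathPredictorForRailEnv(30) is the predictor's maximum prediction depth, so shortest-path predictions now look 30 time steps ahead. The diff collapses the rest of the RailEnv constructor; as orientation only, a minimal sketch of how the values configured above are typically wired together, assuming flatland 2.x signatures of this period (the sparse_rail_generator arguments are assumptions, not taken from this commit):

    env = RailEnv(width=x_dim,
                  height=y_dim,
                  rail_generator=sparse_rail_generator(seed=1),  # layout arguments collapsed in the diff
                  schedule_generator=sparse_schedule_generator(speed_ration_map),
                  number_of_agents=n_agents,
                  stochastic_data=stochastic_data,  # malfunction parameters from above
                  obs_builder_object=TreeObservation)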
@@ -120,8 +118,9 @@ def main(argv):
         env_renderer.reset()

         # Build agent specific observations
         for a in range(env.get_num_agents()):
-            agent_obs[a] = normalize_observation(obs[a], tree_depth, observation_radius=10)
-            agent_obs_buffer[a] = agent_obs[a].copy()
+            if obs[a]:
+                agent_obs[a] = normalize_observation(obs[a], tree_depth, observation_radius=10)
+                agent_obs_buffer[a] = agent_obs[a].copy()

         # Reset score and done
         score = 0
@@ -153,7 +152,8 @@ def main(argv):
                 agent_obs_buffer[a] = agent_obs[a].copy()
                 agent_action_buffer[a] = action_dict[a]
-            agent_obs[a] = normalize_observation(next_obs[a], tree_depth, observation_radius=10)
+            if next_obs[a]:
+                agent_obs[a] = normalize_observation(next_obs[a], tree_depth, observation_radius=10)

             score += all_rewards[a] / env.get_num_agents()
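Both new guards address the same failure mode: the tree observation builder can return None for an agent (for example once it has reached its target and is removed from the grid), and normalize_observation cannot process None, so the update is skipped and agent_obs[a] keeps the last valid normalized observation. The same pattern could be factored into a small helper; a hypothetical sketch (safe_normalize is not part of this commit):

    def safe_normalize(raw_obs, previous, tree_depth):
        # Keep the previous normalized observation when the builder
        # returned None for this agent; otherwise normalize as usual.
        if raw_obs is None:
            return previous
        return normalize_observation(raw_obs, tree_depth, observation_radius=10)

    agent_obs[a] = safe_normalize(next_obs[a], agent_obs[a], tree_depth)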
@@ -192,7 +192,7 @@ def main(argv):
                       100 * np.mean(done_window),
                       eps, action_prob / np.sum(action_prob)))
            torch.save(agent.qnetwork_local.state_dict(),
-                      './Nets/avoider_checkpoint' + str(trials) + '.pth')
+                      './Nets/navigator_checkpoint' + str(trials) + '.pth')
            action_prob = [1] * action_size

    # Plot overall training progress at the end
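The renamed checkpoints can be restored with the standard PyTorch pattern; a minimal sketch, assuming state_size and action_size as defined earlier in this script (the exact Agent constructor signature varies between versions of this repository):

    agent = Agent(state_size, action_size)
    agent.qnetwork_local.load_state_dict(
        torch.load('./Nets/navigator_checkpoint100.pth'))  # e.g. the trial-100 checkpoint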