Commit f159c489 authored by Erik Nygren

Merge branch 'master' of gitlab.aicrowd.com:flatland/baselines

# Conflicts:
#	torch_training/multi_agent_training.py
parents e1ce4008 e3ead28c
@@ -44,7 +44,7 @@ stochastic_data = {'malfunction_rate': 8000,  # Rate of malfunction occurence of
 # Custom observation builder
-TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
+TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv(30))
 # Different agent types (trains) with different speeds.
 speed_ration_map = {1.: 0.25,  # Fast passenger train
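The 30 passed to ShortestPathPredictorForRailEnv above is, to the best of my reading of the Flatland API, the predictor's maximum prediction depth, i.e. how many future time steps the shortest-path predictor looks ahead when the tree observation is built. A minimal sketch with the keyword spelled out, assuming the parameter is indeed named max_depth:

# Sketch only; assumes ShortestPathPredictorForRailEnv accepts a max_depth argument.
TreeObservation = TreeObsForRailEnv(
    max_depth=2,                                              # depth of the observation tree
    predictor=ShortestPathPredictorForRailEnv(max_depth=30))  # look up to 30 steps ahead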
@@ -80,7 +80,7 @@ action_size = 5
 # We set the number of episodes we would like to train on
 if 'n_trials' not in locals():
     n_trials = 60000
-max_steps = int(3 * (env.height + env.width))
+max_steps = int(4 * 2 * (20 + env.height + env.width))
 eps = 1.
 eps_end = 0.005
 eps_decay = 0.9995
@@ -94,7 +94,7 @@ action_prob = [0] * action_size
 agent_obs = [None] * env.get_num_agents()
 agent_next_obs = [None] * env.get_num_agents()
 agent = Agent(state_size, action_size)
-with path(torch_training.Nets, "avoider_checkpoint100.pth") as file_in:
+with path(torch_training.Nets, "navigator_checkpoint1200.pth") as file_in:
     agent.qnetwork_local.load_state_dict(torch.load(file_in))
 record_images = False
@@ -119,7 +119,6 @@ for trials in range(1, n_trials + 1):
         for a in range(env.get_num_agents()):
             if info['action_required'][a]:
                 action = agent.act(agent_obs[a], eps=0.)
             else:
                 action = 0
@@ -130,7 +129,8 @@ for trials in range(1, n_trials + 1):
         env_renderer.render_env(show=True, show_predictions=True, show_observations=False)
         # Build agent specific observations and normalize
         for a in range(env.get_num_agents()):
-            agent_obs[a] = normalize_observation(obs[a], tree_depth, observation_radius=10)
+            if obs[a]:
+                agent_obs[a] = normalize_observation(obs[a], tree_depth, observation_radius=10)
         if done['__all__']:
...
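The last hunk above guards normalize_observation with if obs[a]:, presumably because agents that have already reached their target, or are not yet active on the grid, can come back without a fresh tree observation; the previously normalized observation is then reused instead of normalizing None. A minimal sketch of that pattern, assuming obs is the observation dict returned by the environment:

for a in range(env.get_num_agents()):
    if obs[a]:  # only normalize when the env returned a fresh tree observation
        agent_obs[a] = normalize_observation(obs[a], tree_depth, observation_radius=10)
    # else: keep the previously stored agent_obs[a]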
@@ -22,6 +22,7 @@ from flatland.utils.rendertools import RenderTool
 from utils.observation_utils import normalize_observation
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
+from flatland.envs.agent_utils import RailAgentStatus

 def main(argv):
     try:
@@ -39,7 +40,7 @@ def main(argv):
     # Parameters for the Environment
     x_dim = 35
     y_dim = 35
-    n_agents = 5
+    n_agents = 10

     # Use a the malfunction generator to break agents from time to time
@@ -52,10 +53,10 @@ def main(argv):
     TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv(30))

     # Different agent types (trains) with different speeds.
-    speed_ration_map = {1.: 0.,  # Fast passenger train
-                        1. / 2.: 1.0,  # Fast freight train
-                        1. / 3.: 0.0,  # Slow commuter train
-                        1. / 4.: 0.0}  # Slow freight train
+    speed_ration_map = {1.: 0.25,  # Fast passenger train
+                        1. / 2.: 0.25,  # Fast freight train
+                        1. / 3.: 0.25,  # Slow commuter train
+                        1. / 4.: 0.25}  # Slow freight train

     env = RailEnv(width=x_dim,
                   height=y_dim,
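The old map generated only speed-1/2 freight trains (all probability mass on one entry); the new one draws each of the four train types with probability 0.25. The values are a probability distribution over agent speeds, so they should sum to 1.0. A quick sanity check, not part of the diff:

speed_ration_map = {1.: 0.25, 1. / 2.: 0.25, 1. / 3.: 0.25, 1. / 4.: 0.25}
assert abs(sum(speed_ration_map.values()) - 1.0) < 1e-9  # must form a valid distribution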
@@ -88,7 +89,7 @@ def main(argv):
     n_trials = 15000

     # And the max number of steps we want to take per episode
-    max_steps = int(3 * (env.height + env.width))
+    max_steps = int(4 * 2 * (20 + env.height + env.width))

     # Define training parameters
     eps = 1.
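For the 35x35 environment configured above, the new bound is considerably more generous, which gives the slower 1/3- and 1/4-speed trains a realistic chance of reaching their targets. A quick check of the arithmetic:

x_dim = y_dim = 35
old_max_steps = int(3 * (y_dim + x_dim))           # 210 steps per episode
new_max_steps = int(4 * 2 * (20 + y_dim + x_dim))  # 720 steps per episode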
@@ -108,7 +109,7 @@ def main(argv):
     agent_obs_buffer = [None] * env.get_num_agents()
     agent_action_buffer = [2] * env.get_num_agents()
     cummulated_reward = np.zeros(env.get_num_agents())
-    update_values = False
+    update_values = [False] * env.get_num_agents()

     # Now we load a Double dueling DQN agent
     agent = Agent(state_size, action_size)
@@ -128,16 +129,16 @@ def main(argv):
         env_done = 0

         # Run episode
-        for step in range(max_steps):
+        while True:
             # Action
             for a in range(env.get_num_agents()):
                 if info['action_required'][a]:
                     # If an action is require, we want to store the obs a that step as well as the action
-                    update_values = True
+                    update_values[a] = True
                     action = agent.act(agent_obs[a], eps=eps)
                     action_prob[action] += 1
                 else:
-                    update_values = False
+                    update_values[a] = False
                     action = 0
                 action_dict.update({a: action})
@@ -146,7 +147,7 @@ def main(argv):
             # Update replay buffer and train agent
             for a in range(env.get_num_agents()):
                 # Only update the values when we are done or when an action was taken and thus relevant information is present
-                if update_values or done[a]:
+                if update_values[a] or done[a]:
                     agent.step(agent_obs_buffer[a], agent_action_buffer[a], all_rewards[a],
                                agent_obs[a], done[a])
                     cummulated_reward[a] = 0.
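Taken together, the two hunks above replace the single scalar update_values flag with one flag per agent. With a shared flag, the last agent visited in the action loop decided for everyone whether a transition was pushed to the replay buffer; with mixed speeds only some agents require an action on a given step, so transitions could be stored for agents that never acted, or dropped for agents that did. A condensed sketch of the per-agent pattern, using the names from the diff and omitting the surrounding episode loop:

update_values = [False] * env.get_num_agents()

# action loop: record per agent whether a real decision was made this step
for a in range(env.get_num_agents()):
    update_values[a] = bool(info['action_required'][a])
    action_dict[a] = agent.act(agent_obs[a], eps=eps) if update_values[a] else 0

# replay-buffer loop: train only on transitions where agent a actually decided (or finished)
for a in range(env.get_num_agents()):
    if update_values[a] or done[a]:
        agent.step(agent_obs_buffer[a], agent_action_buffer[a], all_rewards[a],
                   agent_obs[a], done[a])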
@@ -168,8 +169,8 @@ def main(argv):

         # Collection information about training
         tasks_finished = 0
-        for _idx in range(env.get_num_agents()):
-            if done[_idx] == 1:
+        for current_agent in env.agents:
+            if current_agent.status == RailAgentStatus.DONE_REMOVED:
                 tasks_finished += 1
         done_window.append(tasks_finished / max(1, env.get_num_agents()))
         scores_window.append(score / max_steps)  # save most recent score
...
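The last hunk counts finished agents via RailAgentStatus.DONE_REMOVED (the newly imported enum) rather than the done dict, so only agents that reached their target and were removed from the grid contribute to the completion rate. An equivalent, more compact formulation of the same metric, assuming each entry of env.agents exposes a status field:

tasks_finished = sum(agent.status == RailAgentStatus.DONE_REMOVED for agent in env.agents)
completion_rate = tasks_finished / max(1, env.get_num_agents())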