Commit f3194eb3 authored by Erik Nygren

Updated the training file's max-time step limit and added multi-speed agents and stochastic malfunctions

parent 68b09076
@@ -43,7 +43,7 @@ stochastic_data = {'prop_malfunction': 0.0,  # Percentage of defective agents
                    }

 # Custom observation builder
-TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
+TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv(30))

 # Different agent types (trains) with different speeds.
 speed_ration_map = {1.: 0.25,  # Fast passenger train
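The only functional change in this hunk is giving the shortest-path predictor an explicit horizon. A minimal sketch of the difference, assuming the Flatland signature of this era where the first argument is the prediction depth in time steps (with a default of 20):

    from flatland.envs.observations import TreeObsForRailEnv
    from flatland.envs.predictions import ShortestPathPredictorForRailEnv

    # Before: default horizon (20 steps in contemporary releases)
    TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())

    # After: predict 30 steps ahead, so even 1/4-speed trains get useful
    # conflict predictions inside the tree observation
    TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv(30))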
@@ -79,7 +79,7 @@ action_size = 5
 # We set the number of episodes we would like to train on
 if 'n_trials' not in locals():
     n_trials = 60000
-max_steps = int(3 * (env.height + env.width))
+max_steps = int(4 * 2 * (20 + env.height + env.width))
 eps = 1.
 eps_end = 0.005
 eps_decay = 0.9995
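For scale, a quick check of the new step budget, assuming the same 35x35 grid this commit configures in the second file below:

    height, width = 35, 35
    old_limit = int(3 * (height + width))           # 210 steps
    new_limit = int(4 * 2 * (20 + height + width))  # 720 steps

A plausible reading of the factors (the commit does not spell them out): the 4 matches the slowest speed class (1/4), the 2 leaves room for detours, and the +20 pads for malfunction downtime.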
@@ -93,7 +93,7 @@ action_prob = [0] * action_size
 agent_obs = [None] * env.get_num_agents()
 agent_next_obs = [None] * env.get_num_agents()
 agent = Agent(state_size, action_size)
-with path(torch_training.Nets, "avoider_checkpoint100.pth") as file_in:
+with path(torch_training.Nets, "navigator_checkpoint1200.pth") as file_in:
     agent.qnetwork_local.load_state_dict(torch.load(file_in))

 record_images = False
@@ -118,7 +118,6 @@ for trials in range(1, n_trials + 1):

     for a in range(env.get_num_agents()):
         if info['action_required'][a]:
             action = agent.act(agent_obs[a], eps=0.)
-
         else:
             action = 0
@@ -129,7 +128,8 @@ for trials in range(1, n_trials + 1):
     env_renderer.render_env(show=True, show_predictions=True, show_observations=False)

     # Build agent specific observations and normalize
     for a in range(env.get_num_agents()):
-        agent_obs[a] = normalize_observation(obs[a], tree_depth, observation_radius=10)
+        if obs[a]:
+            agent_obs[a] = normalize_observation(obs[a], tree_depth, observation_radius=10)

     if done['__all__']:
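The new guard matters because obs[a] can be None for agents that produce no observation on a given step, for example agents that have already reached their target; passing None to normalize_observation would raise. Restated as a fragment, with names from the diff:

    for a in range(env.get_num_agents()):
        if obs[a]:  # skip agents without a fresh observation this step
            agent_obs[a] = normalize_observation(obs[a], tree_depth, observation_radius=10)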
...
@@ -20,6 +20,7 @@ from flatland.utils.rendertools import RenderTool
 from utils.observation_utils import normalize_observation
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
+from flatland.envs.agent_utils import RailAgentStatus

 def main(argv):
     try:
@@ -37,24 +38,24 @@ def main(argv):
     # Parameters for the Environment
     x_dim = 35
     y_dim = 35
-    n_agents = 5
+    n_agents = 10

     # Use the malfunction generator to break agents from time to time
-    stochastic_data = {'prop_malfunction': 0.0,  # Percentage of defective agents
-                       'malfunction_rate': 30,  # Rate of malfunction occurrence
-                       'min_duration': 3,  # Minimal duration of malfunction
-                       'max_duration': 20  # Max duration of malfunction
+    stochastic_data = {'prop_malfunction': 0.05,  # Percentage of defective agents
+                       'malfunction_rate': 100,  # Rate of malfunction occurrence
+                       'min_duration': 20,  # Minimal duration of malfunction
+                       'max_duration': 50  # Max duration of malfunction
                        }

     # Custom observation builder
     TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv(30))

     # Different agent types (trains) with different speeds.
-    speed_ration_map = {1.: 0.,  # Fast passenger train
-                        1. / 2.: 1.0,  # Fast freight train
-                        1. / 3.: 0.0,  # Slow commuter train
-                        1. / 4.: 0.0}  # Slow freight train
+    speed_ration_map = {1.: 0.25,  # Fast passenger train
+                        1. / 2.: 0.25,  # Fast freight train
+                        1. / 3.: 0.25,  # Slow commuter train
+                        1. / 4.: 0.25}  # Slow freight train

     env = RailEnv(width=x_dim,
                   height=y_dim,
@@ -87,7 +88,7 @@ def main(argv):
     n_trials = 15000

     # And the max number of steps we want to take per episode
-    max_steps = int(3 * (env.height + env.width))
+    max_steps = int(4 * 2 * (20 + env.height + env.width))

     # Define training parameters
     eps = 1.
@@ -107,7 +108,7 @@ def main(argv):
     agent_obs_buffer = [None] * env.get_num_agents()
     agent_action_buffer = [2] * env.get_num_agents()
     cummulated_reward = np.zeros(env.get_num_agents())
-    update_values = False
+    update_values = [False] * env.get_num_agents()

     # Now we load a Double dueling DQN agent
     agent = Agent(state_size, action_size)
@@ -127,16 +128,16 @@ def main(argv):
         env_done = 0

         # Run episode
-        for step in range(max_steps):
+        while True:
             # Action
             for a in range(env.get_num_agents()):
                 if info['action_required'][a]:
                     # If an action is required, we want to store the obs at that step as well as the action
-                    update_values = True
+                    update_values[a] = True
                     action = agent.act(agent_obs[a], eps=eps)
                     action_prob[action] += 1
                 else:
-                    update_values = False
+                    update_values[a] = False
                     action = 0
                 action_dict.update({a: action})

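Two changes interact in this hunk. The episode loop no longer pre-commits to a fixed number of iterations; termination presumably happens further down, e.g. on done['__all__'] or a step counter outside this hunk. And update_values becomes per-agent, which matters once agents have different speeds: a 1/4-speed train only requires an action every fourth step, so a single shared boolean would let the last agent polled decide whether every agent's transition gets stored. The per-agent bookkeeping, restated as a fragment:

    for a in range(env.get_num_agents()):
        if info['action_required'][a]:
            update_values[a] = True   # fresh decision: store this transition later
            action = agent.act(agent_obs[a], eps=eps)
            action_prob[action] += 1
        else:
            update_values[a] = False  # still executing a previous action
            action = 0                # 0 is RailEnvActions.DO_NOTHING
        action_dict.update({a: action})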
@@ -145,7 +146,7 @@ def main(argv):
             # Update replay buffer and train agent
             for a in range(env.get_num_agents()):
                 # Only update the values when we are done or when an action was taken and thus relevant information is present
-                if update_values or done[a]:
+                if update_values[a] or done[a]:
                     agent.step(agent_obs_buffer[a], agent_action_buffer[a], all_rewards[a],
                                agent_obs[a], done[a])
                     cummulated_reward[a] = 0.
@@ -167,8 +168,8 @@ def main(argv):

         # Collect information about training
         tasks_finished = 0
-        for _idx in range(env.get_num_agents()):
-            if done[_idx] == 1:
+        for agent in env.agents:
+            if agent.status == RailAgentStatus.DONE_REMOVED:
                 tasks_finished += 1
         done_window.append(tasks_finished / max(1, env.get_num_agents()))
         scores_window.append(score / max_steps)  # save most recent score
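Counting DONE_REMOVED agents is stricter than counting done[a] entries: done[a] typically also flips to True when the episode simply runs out of time, whereas the status only reaches DONE_REMOVED for agents that actually arrived at their target and were removed from the grid. An equivalent formulation, assuming the RailAgentStatus enum imported above:

    tasks_finished = sum(a.status == RailAgentStatus.DONE_REMOVED for a in env.agents)
    done_window.append(tasks_finished / max(1, env.get_num_agents()))

Note that the loop variable agent shadows the DQN agent defined earlier; it is harmless here because training for the episode is already finished, but renaming it would be safer.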
...