Commit a9440cfb authored by Erik Nygren

Updated training protocol and agent for navigation task

parent 360ad3ec
from flatland.envs.rail_env import *
from flatland.core.env_observation_builder import TreeObsForRailEnv
from flatland.utils.rendertools import *
from flatland.baselines.dueling_double_dqn import Agent
from collections import deque
import torch

random.seed(1)
np.random.seed(1)
@@ -36,6 +39,16 @@ handle = env.get_agent_handles()
state_size = 105
action_size = 4
n_trials = 5000
eps = 1.
eps_end = 0.005
eps_decay = 0.998
action_dict = dict()
scores_window = deque(maxlen=100)
done_window = deque(maxlen=100)
scores = []
dones_list = []
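# Note (added for this write-up, not in the original commit): with eps multiplied
# by eps_decay once per episode and clamped at eps_end, eps after n episodes is
# max(eps_end, 0.998 ** n): roughly 0.82 after 100 episodes, 0.135 after 1000,
# and it reaches the 0.005 floor around episode 2650, so the last ~2350 of the
# 5000 trials are played almost greedily.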
agent = Agent(state_size, action_size, "FC", 0)
# Example generate a rail given a manual specification,
@@ -49,27 +62,69 @@ env = RailEnv(width=6,
              number_of_agents=1,
              obs_builder_object=TreeObsForRailEnv(max_depth=2))
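# Note (added for this write-up, not in the original commit): state_size = 105 above
# presumably comes from this depth-2 observation tree: 1 + 4 + 4**2 = 21 nodes,
# each described by 5 features (cf. num_elements_per_node=5 in the commented-out
# print below), and 21 * 5 = 105.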
env.agents_position[0] = [1, 4]
env.agents_target[0] = [1, 1]
env.agents_direction[0] = 1
# TODO: watch out: if these variables are overridden, the obs_builder object has to be reset, too!
env.obs_builder.reset()
# TODO: delete next line
#for i in range(4):
# print(env.obs_builder.distance_map[0, :, :, i])
for trials in range(1, n_trials + 1):

    # Reset environment
    obs, all_rewards, done, _ = env.step({0: 0})
    # env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5)
    score = 0
    env_done = 0

    # Run episode
    for step in range(100):

        # Action
        for a in range(env.number_of_agents):
            action = agent.act(np.array(obs[a]), eps=eps)
            action_dict.update({a: action})

        # Environment step
        next_obs, all_rewards, done, _ = env.step(action_dict)

        # Update replay buffer and train agent
        for a in range(env.number_of_agents):
            agent.step(obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a])
            score += all_rewards[a]

        obs = next_obs.copy()
        if all(done):
            env_done = 1
            break

    # Epsilon decay
    eps = max(eps_end, eps_decay * eps)  # decrease epsilon

    done_window.append(env_done)
    scores_window.append(score)  # save most recent score
    scores.append(np.mean(scores_window))
    dones_list.append(np.mean(done_window))

    print('\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%\tEpsilon: {:.2f}'.format(
        env.number_of_agents, trials, np.mean(scores_window), 100 * np.mean(done_window), eps), end=" ")

    if trials % 100 == 0:
        print('\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%\tEpsilon: {:.2f}'.format(
            env.number_of_agents, trials, np.mean(scores_window), 100 * np.mean(done_window), eps))
        torch.save(agent.qnetwork_local.state_dict(),
                   '../flatland/baselines/Nets/avoid_checkpoint' + str(trials) + '.pth')
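
The checkpoints written above can be restored later for a greedy evaluation run. A minimal sketch, assuming the same Agent instance and environment as in the script; the checkpoint path and episode number are placeholders, and eps=0. simply disables exploration in agent.act:

# Evaluation sketch (not part of this commit)
checkpoint = torch.load('../flatland/baselines/Nets/avoid_checkpoint5000.pth')  # placeholder path
agent.qnetwork_local.load_state_dict(checkpoint)

obs, all_rewards, done, _ = env.step({0: 0})
for step in range(100):
    for a in range(env.number_of_agents):
        action_dict.update({a: agent.act(np.array(obs[a]), eps=0.)})
    obs, all_rewards, done, _ = env.step(action_dict)
    if all(done):
        break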
@@ -2,7 +2,7 @@ import numpy as np
import random
from collections import namedtuple, deque
import os
from flatland.baselines.model import QNetwork, QNetwork2
import torch
import torch.nn.functional as F
import torch.optim as optim
...
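
The second file touched by this commit is the dueling double DQN agent itself; the visible change is its import of the Q-network models moving from flatland.agents to the new flatland.baselines package. For orientation, here is a generic sketch of the double-DQN target such an agent typically computes when learning from replayed transitions; the function and tensor names are illustrative only, and the actual implementation lives in dueling_double_dqn.py:

import torch

def double_dqn_targets(qnetwork_local, qnetwork_target, rewards, next_states, dones, gamma=0.99):
    # Double DQN: the online (local) network selects the best next action,
    # the target network evaluates it; decoupling selection from evaluation
    # reduces Q-value overestimation.
    with torch.no_grad():
        best_actions = qnetwork_local(next_states).argmax(dim=1, keepdim=True)
        q_next = qnetwork_target(next_states).gather(1, best_actions)
    # Terminal transitions (dones == 1) receive no bootstrapped value.
    return rewards + gamma * q_next * (1 - dones)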