From a9440cfba443a45b3a309c910187bd58469f56ea Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Tue, 23 Apr 2019 08:42:20 +0200
Subject: [PATCH] Updated training protocol and agent for navigation task

---
 examples/training_navigation.py          | 83 ++++++++++++++++++++----
 flatland/baselines/dueling_double_dqn.py |  2 +-
 2 files changed, 70 insertions(+), 15 deletions(-)

diff --git a/examples/training_navigation.py b/examples/training_navigation.py
index 33fe287d..3797d1cf 100644
--- a/examples/training_navigation.py
+++ b/examples/training_navigation.py
@@ -1,7 +1,10 @@
 from flatland.envs.rail_env import *
 from flatland.core.env_observation_builder import TreeObsForRailEnv
 from flatland.utils.rendertools import *
-from flatland.agents.dueling_double_dqn import Agent
+from flatland.baselines.dueling_double_dqn import Agent
+from collections import deque
+import torch
+
 random.seed(1)
 np.random.seed(1)
 
@@ -36,6 +39,16 @@ handle = env.get_agent_handles()
 
 state_size = 105
 action_size = 4
+n_trials = 5000
+eps = 1.
+eps_end = 0.005
+eps_decay = 0.998
+action_dict = dict()
+scores_window = deque(maxlen=100)
+done_window = deque(maxlen=100)
+scores = []
+dones_list = []
+
 agent = Agent(state_size, action_size, "FC", 0)
 
 # Example generate a rail given a manual specification,
@@ -49,27 +62,69 @@ env = RailEnv(width=6,
               number_of_agents=1,
               obs_builder_object=TreeObsForRailEnv(max_depth=2))
 
-
 env.agents_position[0] = [1, 4]
 env.agents_target[0] = [1, 1]
 env.agents_direction[0] = 1
 # TODO: watch out: if these variables are overridden, the obs_builder object has to be reset, too!
 env.obs_builder.reset()
 
-# TODO: delete next line
-#for i in range(4):
-#    print(env.obs_builder.distance_map[0, :, :, i])
-obs, all_rewards, done, _ = env.step({0:0})
-#env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5)
+for trials in range(1, n_trials + 1):
 
-env_renderer = RenderTool(env)
-action_dict = {0: 0}
+    # Reset environment
+    obs, all_rewards, done, _ = env.step({0: 0})
+    # env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5)
+
+    score = 0
+    env_done = 0
+
+    # Run episode
+    for step in range(100):
+
+        # Action
+        for a in range(env.number_of_agents):
+            action = agent.act(np.array(obs[a]), eps=eps)
+            action_dict.update({a: action})
+
+        # Environment step
+        next_obs, all_rewards, done, _ = env.step(action_dict)
+
+        # Update replay buffer and train agent
+        for a in range(env.number_of_agents):
+            agent.step(obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a])
+            score += all_rewards[a]
+
+        obs = next_obs.copy()
 
-for step in range(100):
-    obs, all_rewards, done, _ = env.step(action_dict)
-    action = agent.act(np.array(obs[0]),eps=1)
+        if all(done):
+            env_done = 1
+            break
+    # Epsilon decay
+    eps = max(eps_end, eps_decay * eps)  # decrease epsilon
 
-    action_dict = {0 :action}
-    print("Rewards: ", all_rewards, " [done=", done, "]")
+    done_window.append(env_done)
+    scores_window.append(score)  # save most recent score
+    scores.append(np.mean(scores_window))
+    dones_list.append((np.mean(done_window)))
+    print('\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%\tEpsilon: {:.2f}'.format(
+        env.number_of_agents, trials, np.mean(scores_window), 100 * np.mean(done_window), eps), end=" ")
+    if trials % 100 == 0:
+        print('\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%\tEpsilon: {:.2f}'.format(
+            env.number_of_agents, trials, np.mean(scores_window), 100 * np.mean(done_window), eps))
+        torch.save(agent.qnetwork_local.state_dict(), '../flatland/baselines/Nets/avoid_checkpoint' + str(trials) + '.pth')
diff --git a/flatland/baselines/dueling_double_dqn.py b/flatland/baselines/dueling_double_dqn.py
index 3eacf4c9..084d0a22 100644
--- a/flatland/baselines/dueling_double_dqn.py
+++ b/flatland/baselines/dueling_double_dqn.py
@@ -2,7 +2,7 @@ import numpy as np
 import random
 from collections import namedtuple, deque
 import os
-from flatland.agents.model import QNetwork, QNetwork2
+from flatland.baselines.model import QNetwork, QNetwork2
 import torch
 import torch.nn.functional as F
 import torch.optim as optim
-- 
GitLab
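
Note (not part of the patch): the sketch below is a standalone illustration of the epsilon-greedy schedule and sliding-window bookkeeping that the new training loop introduces (n_trials, eps_decay, scores_window, done_window). The per-episode score and success flag are random placeholders standing in for the Flatland environment/agent interaction, so the printed numbers are meaningless; only the decay rule and the reporting logic mirror the patch.

# Standalone sketch of the patch's epsilon schedule and rolling metrics.
# score and env_done are random placeholders (assumption), not output
# from the Flatland env or the DQN agent.
from collections import deque
import random

import numpy as np

n_trials = 5000
eps, eps_end, eps_decay = 1.0, 0.005, 0.998
scores_window = deque(maxlen=100)  # last 100 episode scores
done_window = deque(maxlen=100)    # last 100 "episode solved" flags

for trials in range(1, n_trials + 1):
    score = random.uniform(-100.0, 0.0)     # placeholder episode return
    env_done = int(random.random() < 0.1)   # placeholder success flag

    scores_window.append(score)
    done_window.append(env_done)
    eps = max(eps_end, eps_decay * eps)     # same decay rule as in the patch

    if trials % 100 == 0:
        print('Episode {}\tAverage Score: {:.0f}\tDones: {:.2f}%\tEpsilon: {:.2f}'.format(
            trials, np.mean(scores_window), 100 * np.mean(done_window), eps))

With eps_decay = 0.998, epsilon hits the eps_end floor of 0.005 after roughly ln(0.005)/ln(0.998) ≈ 2650 episodes, so about the second half of the 5000-trial run is played at the minimum exploration rate.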