Commit dc4b8bad authored by MasterScrat

Cleanup, unique ids for training, slightly better checkpoint

parent f426762a
#checkpoints/*
!checkpoints/.gitkeep
replay_buffers/*
!replay_buffers/.gitkeep
@@ -22,7 +22,7 @@ from flatland.utils.rendertools import RenderTool
base_dir = Path(__file__).resolve().parent.parent
sys.path.append(str(base_dir))
from reinforcement_learning.timer import Timer
from utils.timer import Timer
from utils.observation_utils import normalize_observation
from reinforcement_learning.dddqn_policy import DDDQNPolicy
@@ -162,7 +162,7 @@ def eval_policy(env_params, checkpoint, n_eval_episodes, max_steps, action_size,
for step in range(max_steps - 1):
if allow_skipping and check_if_all_blocked(env):
# why -1? bug where all agents are "done" after max_steps!
# FIXME why -1? bug where all agents are "done" after max_steps!
skipped = max_steps - step - 1
final_step = max_steps - 2
n_unfinished_agents = sum(not done[idx] for idx in env.get_agent_handles())
@@ -176,12 +176,11 @@ def eval_policy(env_params, checkpoint, n_eval_episodes, max_steps, action_size,
agent_obs[agent] = normalize_observation(obs[agent], tree_depth=observation_tree_depth, observation_radius=observation_radius)
preproc_timer.end()
action = 0
if info['action_required'][agent]:
inference_timer.start()
action = policy.act(agent_obs[agent], eps=0.0)
inference_timer.end()
action_dict.update({agent: action})
if info['action_required'][agent]:
inference_timer.start()
action_dict.update({agent: policy.act(agent_obs[agent], eps=0.0)})
inference_timer.end()
agent_timer.end()
step_timer.start()
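For readers skimming the diff: the rewritten block above only queries the policy for agents flagged with action_required, instead of computing an action for every agent on every step. A minimal, repo-independent sketch of that pattern (function and argument names here are illustrative, not taken from this codebase):

def build_action_dict(policy, agent_obs, info, agent_handles, no_op=0):
    # Only run inference where the env says an action is required;
    # other agents fall back to a no-op (the new code simply omits them from the dict).
    action_dict = {}
    for agent in agent_handles:
        if info['action_required'][agent]:
            action_dict[agent] = policy.act(agent_obs[agent], eps=0.0)
        else:
            action_dict[agent] = no_op
    return action_dict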
@@ -261,60 +260,62 @@ def evaluate_agents(file, n_evaluation_episodes, use_gpu, render, allow_skipping
if total_nb_eval != n_evaluation_episodes:
print("(Rounding up from {} to fill all cores)".format(n_evaluation_episodes))
env_params_dict = {
# Observation parameters need to match the ones used during training!
# Test_0
test0_params = {
# sample configuration
"n_agents": 2,
"x_dim": 35,
"y_dim": 35,
"n_cities": 4,
"n_agents": 5,
"x_dim": 25,
"y_dim": 25,
"n_cities": 2,
"max_rails_between_cities": 2,
"max_rails_in_city": 3,
# observations
"observation_tree_depth": 1,
"observation_radius": 10,
"observation_max_path_depth": 20
}
# Test_1
test1_params = {
# environment
"n_agents": 10,
"x_dim": 30,
"y_dim": 30,
"n_cities": 2,
"max_rails_between_cities": 2,
"max_rails_in_city": 3,
"seed": 42,
"observation_tree_depth": 2,
# observations
"observation_tree_depth": 1,
"observation_radius": 10,
"observation_max_path_depth": 30
"observation_max_path_depth": 20
}
# env_params_dict = {
# # environment
# "n_agents": 15,
# "x_dim": 60,
# "y_dim": 60,
# "n_cities": 7,
# "max_rails_between_cities": 2,
# "max_rails_in_city": 4,
#
# # observations
# "observation_tree_depth": 2,
# "observation_radius": 10,
# "observation_max_path_depth": 30
# }
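These dicts are only turned into an actual environment inside eval_policy, whose body is outside this hunk. For context, a hedged sketch of how such parameters typically map onto a RailEnv in the flatland-rl 2.x API of that period; argument names may differ in other versions:

from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.schedule_generators import sparse_schedule_generator
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv

p = test1_params  # the Test_1 dict defined above
env = RailEnv(
    width=p["x_dim"], height=p["y_dim"],
    rail_generator=sparse_rail_generator(
        max_num_cities=p["n_cities"],
        seed=p["seed"],
        grid_mode=False,
        max_rails_between_cities=p["max_rails_between_cities"],
        max_rails_in_city=p["max_rails_in_city"],
    ),
    schedule_generator=sparse_schedule_generator(),
    number_of_agents=p["n_agents"],
    obs_builder_object=TreeObsForRailEnv(
        max_depth=p["observation_tree_depth"],
        predictor=ShortestPathPredictorForRailEnv(p["observation_max_path_depth"]),
    ),
)
obs, info = env.reset()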
env_params = Namespace(**env_params_dict)
env_params = Namespace(**test0_params)
print("Environment parameters:")
pprint(env_params_dict)
pprint(test1_params)
# Calculate space dimensions and max steps
max_steps = int(4 * 2 * (env_params.x_dim + env_params.y_dim + (env_params.n_agents / env_params.n_cities)))
action_size = 5
tree_observation = TreeObsForRailEnv(max_depth=env_params.observation_tree_depth)
tree_depth = env_params.observation_tree_depth
num_features_per_node = tree_observation.observation_dim
tree_depth = 2
nr_nodes = 0
for i in range(tree_depth + 1):
nr_nodes += np.power(4, i)
state_size = num_features_per_node * nr_nodes
n_nodes = sum([np.power(4, i) for i in range(tree_depth + 1)])
state_size = num_features_per_node * n_nodes
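A worked instance of the two formulas above, using the Test_1 values (tree depth 1, 30x30 grid, 10 agents, 2 cities). The per-node feature count comes from tree_observation.observation_dim at runtime; 11 is only an assumed example value:

import numpy as np

observation_tree_depth = 1
num_features_per_node = 11   # assumed; read from tree_observation.observation_dim in the real code
n_nodes = sum(np.power(4, i) for i in range(observation_tree_depth + 1))  # 1 + 4 = 5
state_size = num_features_per_node * n_nodes                              # 11 * 5 = 55

x_dim, y_dim, n_agents, n_cities = 30, 30, 10, 2
max_steps = int(4 * 2 * (x_dim + y_dim + (n_agents / n_cities)))          # int(8 * 65.0) = 520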
results = []
if render:
results.append(eval_policy(env_params_dict, file, eval_per_thread, max_steps, action_size, state_size, 0, render, allow_skipping))
results.append(eval_policy(test1_params, file, eval_per_thread, max_steps, action_size, state_size, 0, render, allow_skipping))
else:
with Pool() as p:
results = p.starmap(eval_policy,
[(env_params_dict, file, 1, max_steps, action_size, state_size, seed * nb_threads, render, allow_skipping)
[(test1_params, file, 1, max_steps, action_size, state_size, seed * nb_threads, render, allow_skipping)
for seed in
range(total_nb_eval)])
@@ -11,25 +11,21 @@ class DuelingQNetwork(nn.Module):
# value network
self.fc1_val = nn.Linear(state_size, hidsize1)
self.fc2_val = nn.Linear(hidsize1, hidsize2)
self.fc3_val = nn.Linear(hidsize1, hidsize2)
self.fc4_val = nn.Linear(hidsize2, 1)
# advantage network
self.fc1_adv = nn.Linear(state_size, hidsize1)
self.fc2_adv = nn.Linear(hidsize1, hidsize2)
self.fc3_adv = nn.Linear(hidsize1, hidsize2)
self.fc4_adv = nn.Linear(hidsize2, action_size)
def forward(self, x):
val = F.relu(self.fc1_val(x))
val = F.relu(self.fc2_val(val))
#val = F.relu(self.fc3_val(val))
val = self.fc4_val(val)
# advantage calculation
adv = F.relu(self.fc1_adv(x))
adv = F.relu(self.fc2_adv(adv))
#adv = F.relu(self.fc3_adv(adv))
adv = self.fc4_adv(adv)
return val + adv - adv.mean()
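The class above implements the dueling head: a scalar state value V(s) and per-action advantages A(s, a) are combined as Q(s, a) = V(s) + A(s, a) - mean_a A(s, a). A self-contained sketch of that aggregation with illustrative layer sizes (not the ones used in this repo):

import torch
import torch.nn as nn


class DuelingHead(nn.Module):
    # Minimal dueling aggregation: Q(s, a) = V(s) + A(s, a) - mean_a A(s, a)
    def __init__(self, state_size, action_size, hidden=128):
        super().__init__()
        self.val = nn.Sequential(nn.Linear(state_size, hidden), nn.ReLU(), nn.Linear(hidden, 1))
        self.adv = nn.Sequential(nn.Linear(state_size, hidden), nn.ReLU(), nn.Linear(hidden, action_size))

    def forward(self, x):
        val = self.val(x)                                  # (batch, 1)
        adv = self.adv(x)                                  # (batch, action_size)
        return val + adv - adv.mean(dim=1, keepdim=True)   # broadcasts to (batch, action_size)


q_values = DuelingHead(state_size=55, action_size=5)(torch.randn(2, 55))   # shape (2, 5)

Note the sketch subtracts the per-sample advantage mean (dim=1); the version in the diff calls adv.mean() over the whole batch, which also works but couples samples within a batch.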
from datetime import datetime
import os
import random
import sys
@@ -22,12 +23,13 @@ from flatland.envs.predictions import ShortestPathPredictorForRailEnv
base_dir = Path(__file__).resolve().parent.parent
sys.path.append(str(base_dir))
from utils.timer import Timer
from utils.observation_utils import normalize_observation
from reinforcement_learning.timer import Timer
from reinforcement_learning.dddqn_policy import DDDQNPolicy
try:
import wandb
wandb.init(sync_tensorboard=True)
except ImportError:
print("Install wandb to log to Weights & Biases")
@@ -50,6 +52,10 @@ def train_agent(env_params, train_params):
max_rails_in_city = env_params.max_rails_in_city
seed = env_params.seed
# Unique ID for this training
now = datetime.now()
training_id = now.strftime('%y%m%d%H%M%S')
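The format string '%y%m%d%H%M%S' yields a compact, lexicographically sortable timestamp that doubles as the run identifier in the checkpoint names further down. A quick illustration (the example value is made up):

from datetime import datetime

training_id = datetime.now().strftime('%y%m%d%H%M%S')        # e.g. '200618143027'
checkpoint_path = './checkpoints/{}-{}.pth'.format(training_id, 500)
# -> './checkpoints/200618143027-500.pth'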
# Observation parameters
observation_tree_depth = env_params.observation_tree_depth
observation_radius = env_params.observation_radius
@@ -64,8 +70,8 @@ def train_agent(env_params, train_params):
n_eval_episodes = train_params.n_evaluation_episodes
# TODO make command line parameters
replay_buffer_path = "replay_buffers/rb-100.pkl"
save_replay_buffer = True
replay_buffer_path = "replay_buffers/rb-1500.pkl"
save_replay_buffer = False
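The save_replay_buffer flag now defaults to False and the buffer path points at a different snapshot. DDDQNPolicy.save_replay_buffer itself is not shown in this diff; the sketch below is only an assumption of what persisting a buffer of experience tuples can look like:

import pickle

def save_replay_buffer(memory, path):
    # 'memory' is assumed to be an iterable of experience tuples (state, action, reward, next_state, done)
    with open(path, 'wb') as f:
        pickle.dump(list(memory), f)

def load_replay_buffer(path):
    with open(path, 'rb') as f:
        return pickle.load(f)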
# Set the seeds
random.seed(seed)
@@ -149,8 +155,8 @@ def train_agent(env_params, train_params):
print("\n💾 Replay buffer status: {}/{} experiences".format(len(policy.memory.memory), train_params.buffer_size))
hdd = psutil.disk_usage('/')
if save_replay_buffer and (hdd.free / (2**30)) < 500.0:
print("⚠️ Careful! Saving replay buffers will quickly consume a lot of disk space. You have {:.2f}gb left.".format(hdd.free / (2**30)))
if save_replay_buffer and (hdd.free / (2 ** 30)) < 500.0:
print("⚠️ Careful! Saving replay buffers will quickly consume a lot of disk space. You have {:.2f}gb left.".format(hdd.free / (2 ** 30)))
# TensorBoard writer
writer = SummaryWriter()
@@ -160,8 +166,14 @@ def train_agent(env_params, train_params):
training_timer = Timer()
training_timer.start()
print("\n🚉 Training {} trains on {}x{} grid for {} episodes, evaluating on {} episodes every {} episodes.\n"
.format(env.get_num_agents(), x_dim, y_dim, n_episodes, n_eval_episodes, checkpoint_interval))
print("\n🚉 Training {} trains on {}x{} grid for {} episodes, evaluating on {} episodes every {} episodes. Training id '{}'.\n".format(
env.get_num_agents(),
x_dim, y_dim,
n_episodes,
n_eval_episodes,
checkpoint_interval,
training_id
))
for episode_idx in range(n_episodes + 1):
step_timer = Timer()
@@ -260,8 +272,10 @@ def train_agent(env_params, train_params):
# Print logs
if episode_idx % checkpoint_interval == 0:
torch.save(policy.qnetwork_local, './checkpoints/multi-' + str(episode_idx) + '.pth')
policy.save_replay_buffer('./replay_buffers/rb-' + str(episode_idx) + '.pkl')
torch.save(policy.qnetwork_local, './checkpoints/' + training_id + '-' + str(episode_idx) + '.pth')
if save_replay_buffer:
policy.save_replay_buffer('./replay_buffers/rb-' + training_id + '-' + str(episode_idx) + '.pkl')
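torch.save is given the whole qnetwork_local module rather than a state_dict, which is why run.py below can restore it with a bare torch.load. A small round-trip sketch of that choice (the path and the stand-in network are illustrative):

import os
import tempfile
import torch
import torch.nn as nn

net = nn.Linear(55, 5)                                  # stand-in for the real Q-network
path = os.path.join(tempfile.gettempdir(), 'example-0.pth')
torch.save(net, path)                                   # pickles the full module, class reference included
restored = torch.load(path)                             # recent PyTorch may need weights_only=False here
assert isinstance(restored, nn.Linear)

The trade-off is that full-module pickles are tied to the class import path at load time; saving state_dict() is the more portable option.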
if train_params.render:
env_renderer.close_window()
@@ -366,7 +380,7 @@ def eval_policy(env, policy, n_eval_episodes, max_steps):
for agent in env.get_agent_handles():
if obs[agent]:
# TODO pass parameters properly
agent_obs[agent] = normalize_observation(obs[agent], tree_depth=2, observation_radius=10)
agent_obs[agent] = normalize_observation(obs[agent], tree_depth=1, observation_radius=10)
action = 0
if info['action_required'][agent]:
@@ -423,7 +437,7 @@ if __name__ == "__main__":
environment_parameters = {
# small_v0 config
"n_agents": 2,
"n_agents": 5,
"x_dim": 35,
"y_dim": 35,
"n_cities": 4,
@@ -431,9 +445,9 @@ if __name__ == "__main__":
"max_rails_in_city": 3,
"seed": 42,
"observation_tree_depth": 2,
"observation_tree_depth": 1,
"observation_radius": 10,
"observation_max_path_depth": 30
"observation_max_path_depth": 20
}
print("\nEnvironment parameters:")
This diff is collapsed.
import sys
from argparse import Namespace
from multiprocessing.pool import Pool
from pathlib import Path
import numpy as np
@@ -8,10 +7,10 @@ import time
import torch
from flatland.core.env_observation_builder import DummyObservationBuilder
from flatland.envs.observations import TreeObsForRailEnv
from flatland.evaluators.client import FlatlandRemoteClient
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from reinforcement_learning.observations import TreeObsForRailEnv
from utils.deadlock_check import check_if_all_blocked
base_dir = Path(__file__).resolve().parent.parent
@@ -20,22 +19,36 @@ sys.path.append(str(base_dir))
from reinforcement_learning.dddqn_policy import DDDQNPolicy
from utils.observation_utils import normalize_observation
##########################
# EVALUATION PARAMETERS #
# TODO check if the parallel version works properly
#
# def infer_action(obs):
# norm_obs = normalize_observation(obs, tree_depth=observation_tree_depth, observation_radius=observation_radius)
# return policy.act(norm_obs, eps=0.0)
#
#
# def rl_controller(obs, number_of_agents):
# obs_list = []
# for agent in range(number_of_agents):
# if obs[agent] and info['action_required'][agent]:
# obs_list.append(obs[agent])
#
# return dict(zip(range(number_of_agents), pool.map(infer_action, obs_list)))
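The commented-out block above is the idea behind the TODO: fan per-agent inference out to a worker pool. A hedged sketch of one way that could be wired up, with each worker loading its own copy of the network; all names here (the initializer, _infer_one, parallel_controller) are illustrative, not part of this repo:

from multiprocessing import Pool

_policy_net = None

def _init_worker(checkpoint_path):
    # Torch modules don't share cheaply across processes, so each worker loads its own copy.
    global _policy_net
    import torch
    _policy_net = torch.load(checkpoint_path)
    _policy_net.eval()

def _infer_one(norm_obs):
    import torch
    with torch.no_grad():
        q = _policy_net(torch.tensor(norm_obs, dtype=torch.float32))
    return int(torch.argmax(q))

def parallel_controller(norm_obs_by_agent, pool):
    handles = list(norm_obs_by_agent.keys())
    actions = pool.map(_infer_one, [norm_obs_by_agent[h] for h in handles])
    return dict(zip(handles, actions))

# pool = Pool(initializer=_init_worker, initargs=("checkpoints/sample-checkpoint.pth",))

Whether this beats simply batching all agent observations into a single forward pass is unclear (hence the TODO kept in the file); batching is usually the cheaper win.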
# Checkpoint to use
checkpoint = "checkpoints/sample-checkpoint.pth"
# Observation parameters
# These need to match your training parameters!
observation_tree_depth = 1
observation_radius = 10
observation_max_path_depth = 20
remote_client = FlatlandRemoteClient()
# Observation builder
predictor = ShortestPathPredictorForRailEnv(observation_max_path_depth)
tree_observation = TreeObsForRailEnv(max_depth=observation_tree_depth, predictor=predictor)
# Calculates state and action sizes
num_features_per_node = tree_observation.observation_dim
n_nodes = sum([np.power(4, i) for i in range(observation_tree_depth + 1)])
state_size = num_features_per_node * n_nodes
action_size = 5
# Creates the policy. No GPU on evaluation server.
policy = DDDQNPolicy(state_size, action_size, Namespace(**{'use_gpu': False}), evaluation_mode=True)
policy.qnetwork_local = torch.load(checkpoint)
# Controller that uses the RL policy
def rl_controller(obs, number_of_agents):
action_dict = {}
for agent in range(number_of_agents):
@@ -47,166 +60,114 @@ def rl_controller(obs, number_of_agents):
return action_dict
if __name__ == "__main__":
#####################################################################
# Instantiate a Remote Client
#####################################################################
#####################################################################
# Main evaluation loop
#####################################################################
evaluation_number = 0
remote_client = FlatlandRemoteClient()
while True:
evaluation_number += 1
# Checkpoint to use
checkpoint = "checkpoints/multi-300.pth"
# We use a dummy observation and call TreeObsForRailEnv ourselves.
# This way we decide if we want to calculate the observations or not,
# instead of having them calculated every time we perform an env step.
time_start = time.time()
observation, info = remote_client.env_create(
obs_builder_object=DummyObservationBuilder()
)
env_creation_time = time.time() - time_start
# Evaluation is faster on CPU (except if you use a really huge network)
parameters = {
'use_gpu': False
}
if not observation:
# If the remote_client returns False on a `env_create` call,
# then it basically means that your agent has already been
# evaluated on all the required evaluation environments,
# and hence it's safe to break out of the main evaluation loop.
break
# Observation parameters
observation_tree_depth = 1
observation_radius = 10
observation_max_path_depth = 20
local_env = remote_client.env
nb_agents = len(local_env.agents)
max_nb_steps = local_env._max_episode_steps
# Observation builder
predictor = ShortestPathPredictorForRailEnv(observation_max_path_depth)
tree_observation = TreeObsForRailEnv(max_depth=observation_tree_depth, predictor=predictor)
tree_observation.set_env(local_env)
tree_observation.reset()
observation = tree_observation.get_many(list(range(nb_agents)))
num_features_per_node = tree_observation.observation_dim
nr_nodes = 0
for i in range(observation_tree_depth + 1):
nr_nodes += np.power(4, i)
state_size = num_features_per_node * nr_nodes
action_size = 5
print("Evaluation {}: {} agents in {}x{}".format(evaluation_number, nb_agents, local_env.width, local_env.height))
policy = DDDQNPolicy(state_size, action_size, Namespace(**parameters), evaluation_mode=True)
policy.qnetwork_local = torch.load(checkpoint)
# Now we enter into another infinite loop where we
# compute the actions for all the individual steps in this episode
# until the episode is `done`
time_taken_by_controller = []
time_taken_per_step = []
steps = 0
# Controller
pool = Pool()
#####################################################################
# Main evaluation loop
#
# This iterates over an arbitrary number of env evaluations
#####################################################################
evaluation_number = 0
while True:
evaluation_number += 1
time_start = time.time()
observation, info = remote_client.env_create(
obs_builder_object=DummyObservationBuilder()
)
env_creation_time = time.time() - time_start
if not observation:
# If the remote_client returns False on a `env_create` call,
# then it basically means that your agent has already been
# evaluated on all the required evaluation environments,
# and hence it's safe to break out of the main evaluation loop
break
local_env = remote_client.env
nb_agents = len(local_env.agents)
tree_observation.set_env(local_env)
tree_observation.reset()
tree_observation.prepare_get_many(list(range(nb_agents)))
observation = tree_observation.get_many(list(range(nb_agents)))
nb_cities = 2 # FIXME get true value
max_nb_steps = int(4 * 2 * (local_env.width + local_env.height + (nb_agents / nb_cities)))
print("Evaluation {}: {} agents in {}x{}".format(evaluation_number, nb_agents, local_env.width, local_env.height))
# Now we enter into another infinite loop where we
# compute the actions for all the individual steps in this episode
# until the episode is `done`
time_taken_by_controller = []
time_taken_per_step = []
steps = 0
while True:
prepare_time, obs_time, agent_time, step_time = 0, 0, 0, 0
#####################################################################
# Evaluation of a single episode
#####################################################################
steps += 1
if nb_agents < 100:
time_start = time.time()
action = rl_controller(observation, nb_agents)
agent_time = time.time() - time_start
time_taken_by_controller.append(agent_time)
time_start = time.time()
_, all_rewards, done, info = remote_client.env_step(action)
step_time = time.time() - time_start
time_taken_per_step.append(step_time)
time_start = time.time()
tree_observation.prepare_get_many(list(range(nb_agents)))
prepare_time = time.time() - time_start
time_start = time.time()
observation_list = []
for h in range(nb_agents):
observation_list.append(tree_observation.get(h))
observation = dict(zip(range(nb_agents), observation_list))
obs_time = time.time() - time_start
else:
# too many agents: just wait for it to pass 🏃💨
time_start = time.time()
_, all_rewards, done, info = remote_client.env_step({})
step_time = time.time() - time_start
time_taken_per_step.append(step_time)
nb_agents_done = sum(done[idx] for idx in local_env.get_agent_handles())
print("Step {}/{}\tAgents done: {}\t Predictor time {:.3f}s\t Obs time {:.3f}s\t Inference time {:.3f}s\t Step time {:.3f}s".format(
#####################################################################
# Evaluation of a single episode
#####################################################################
steps += 1
obs_time, agent_time, step_time = 0.0, 0.0, 0.0
if nb_agents < 100 and not check_if_all_blocked(env=local_env):
time_start = time.time()
action = rl_controller(observation, nb_agents)
agent_time = time.time() - time_start
time_taken_by_controller.append(agent_time)
time_start = time.time()
_, all_rewards, done, info = remote_client.env_step(action)
step_time = time.time() - time_start
time_taken_per_step.append(step_time)
time_start = time.time()
observation_list = tree_observation.get_many(list(range(nb_agents)))
obs_time = time.time() - time_start
else:
# Too many agents or fully deadlocked: no-op to finish the episode ASAP 🏃💨
time_start = time.time()
_, all_rewards, done, info = remote_client.env_step({})
step_time = time.time() - time_start
time_taken_per_step.append(step_time)
nb_agents_done = sum(done[idx] for idx in local_env.get_agent_handles())
if not done['__all__']:
print("Step {}/{}\tAgents done: {}\t Obs time {:.3f}s\t Inference time {:.3f}s\t Step time {:.3f}s".format(
str(steps).zfill(4),
max_nb_steps,
nb_agents_done,
prepare_time,
obs_time,
agent_time,
step_time
), end="" if done['__all__'] else "\r")
# if check_if_all_blocked(local_env):
# print("DEADLOCKED!!")
# print(evaluation_number, steps, done['__all__'])
if done['__all__']:
print()
print("Reward : ", sum(list(all_rewards.values())))
#
# When done['__all__'] == True, then the evaluation of this
# particular Env instantiation is complete, and we can break out
# of this loop, and move on to the next Env evaluation
break
np_time_taken_by_controller = np.array(time_taken_by_controller)
np_time_taken_per_step = np.array(time_taken_per_step)
print("=" * 100)
print("=" * 100)
print("Evaluation Number : ", evaluation_number)
print("Current Env Path : ", remote_client.current_env_path)
print("Env Creation Time : ", env_creation_time)
print("Number of Steps : ", steps)
print("Mean/Std of Time taken by Controller : ", np_time_taken_by_controller.mean(), np_time_taken_by_controller.std())
print("Mean/Std of Time per Step : ", np_time_taken_per_step.mean(), np_time_taken_per_step.std())
print("=" * 100)
print("Evaluation of all environments complete...")
########################################################################
# Submit your Results
#
# Please do not forget to include this call, as this triggers the
# final computation of the score statistics, video generation, etc
# and is necessary to have your submission marked as successfully evaluated
########################################################################
print(remote_client.submit())
), end="\r")
else:
print()
print("Reward : ", sum(list(all_rewards.values())))
#
# When done['__all__'] == True, then the evaluation of this
# particular Env instantiation is complete, and we can break out
# of this loop, and move on to the next Env evaluation
break
np_time_taken_by_controller = np.array(time_taken_by_controller)
np_time_taken_per_step = np.array(time_taken_per_step)
print("=" * 100)
print("=" * 100)
print("Evaluation Number : ", evaluation_number)
print("Current Env Path : ", remote_client.current_env_path)
print("Env Creation Time : ", env_creation_time)
print("Number of Steps : ", steps)
print("Mean/Std of Time taken by Controller : ", np_time_taken_by_controller.mean(), np_time_taken_by_controller.std())
print("Mean/Std of Time per Step : ", np_time_taken_per_step.mean(), np_time_taken_per_step.std())
print("=" * 100)
print("Evaluation of all environments complete...")
########################################################################
# Submit your Results
#
# Please do not forget to include this call, as this triggers the
# final computation of the score statistics, video generation, etc
# and is necessary to have your submission marked as successfully evaluated
########################################################################
print(remote_client.submit())
@@ -6,7 +6,6 @@ def check_if_all_blocked(env):
"""
Checks whether all the agents are blocked (full deadlock situation).
In that case it is pointless to keep running inference as no agent will be able to move.
FIXME still experimental!
:param env: current environment
:return:
"""
@@ -40,4 +39,4 @@ def check_if_all_blocked(env):
return False
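The body of check_if_all_blocked sits outside this hunk, so the following is only a conceptual sketch of the idea the docstring describes; next_free_cells is a hypothetical helper, not a function in this repo or in flatland:

def all_agents_blocked(env, next_free_cells):
    # next_free_cells(env, agent) is hypothetical: the cells the agent could legally move into.
    active = [agent for agent in env.agents if agent.position is not None]   # agents already on the grid
    if not active:
        return False
    return all(len(next_free_cells(env, agent)) == 0 for agent in active)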