Commit 0d6a9e54 authored by MasterScrat

Better checkpoint (uses tree depth 2), adding action cache, cleanup

parent 85067914
@@ -7,7 +7,6 @@ dependencies:
- tk=8.6.8
- cairo=1.16.0
- cairocffi=1.1.0
- cairosvg=2.4.2
- cffi=1.12.3
- cssselect2=0.2.1
- defusedxml=0.6.0
@@ -9,8 +9,6 @@ from pprint import pprint
import numpy as np
import torch
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
@@ -22,53 +20,13 @@ from flatland.utils.rendertools import RenderTool
base_dir = Path(__file__).resolve().parent.parent
sys.path.append(str(base_dir))
from utils.deadlock_check import check_if_all_blocked
from utils.timer import Timer
from utils.observation_utils import normalize_observation
from reinforcement_learning.dddqn_policy import DDDQNPolicy
def check_if_all_blocked(env):
"""
Checks whether all the agents are blocked (full deadlock situation).
In that case it is pointless to keep running inference as no agent will be able to move.
FIXME still experimental!
:param env: current environment
:return:
"""
# First build a map of agents in each position
location_has_agent = {}
for agent in env.agents:
if agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE] and agent.position:
location_has_agent[tuple(agent.position)] = 1
# Looks for any agent that can still move
for handle in env.get_agent_handles():
agent = env.agents[handle]
if agent.status == RailAgentStatus.READY_TO_DEPART:
agent_virtual_position = agent.initial_position
elif agent.status == RailAgentStatus.ACTIVE:
agent_virtual_position = agent.position
elif agent.status == RailAgentStatus.DONE:
agent_virtual_position = agent.target
else:
continue
possible_transitions = env.rail.get_transitions(*agent_virtual_position, agent.direction)
orientation = agent.direction
for branch_direction in [(orientation + i) % 4 for i in range(-1, 3)]:
if possible_transitions[branch_direction]:
new_position = get_new_position(agent_virtual_position, branch_direction)
if new_position not in location_has_agent:
return False
# No agent can move at all: full deadlock!
return True
def eval_policy(env_params, checkpoint, n_eval_episodes, max_steps, action_size, state_size, seed, render, allow_skipping):
def eval_policy(env_params, checkpoint, n_eval_episodes, max_steps, action_size, state_size, seed, render, allow_skipping, allow_caching):
# Evaluation is faster on CPU (except if you use a really huge policy network)
parameters = {
'use_gpu': False
@@ -160,6 +118,10 @@ def eval_policy(env_params, checkpoint, n_eval_episodes, max_steps, action_size,
final_step = 0
skipped = 0
nb_hit = 0
agent_last_obs = {}
agent_last_action = {}
for step in range(max_steps - 1):
if allow_skipping and check_if_all_blocked(env):
# FIXME why -1? bug where all agents are "done" after max_steps!
@@ -171,16 +133,25 @@ def eval_policy(env_params, checkpoint, n_eval_episodes, max_steps, action_size,
agent_timer.start()
for agent in env.get_agent_handles():
if obs[agent]:
preproc_timer.start()
agent_obs[agent] = normalize_observation(obs[agent], tree_depth=observation_tree_depth, observation_radius=observation_radius)
preproc_timer.end()
if obs[agent] and info['action_required'][agent]:
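# Action cache: if this agent's observation is identical to the previous step's,
# reuse the cached action and skip preprocessing and inference (only safe for a deterministic policy).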
if agent in agent_last_obs and np.all(agent_last_obs[agent] == obs[agent]):
nb_hit += 1
action = agent_last_action[agent]
else:
preproc_timer.start()
norm_obs = normalize_observation(obs[agent], tree_depth=observation_tree_depth, observation_radius=observation_radius)
preproc_timer.end()
if info['action_required'][agent]:
inference_timer.start()
action_dict.update({agent: policy.act(agent_obs[agent], eps=0.0)})
action = policy.act(norm_obs, eps=0.0)
inference_timer.end()
action_dict.update({agent: action})
if allow_caching:
agent_last_obs[agent] = obs[agent]
agent_last_action[agent] = action
agent_timer.end()
step_timer.start()
@@ -224,12 +195,16 @@ def eval_policy(env_params, checkpoint, n_eval_episodes, max_steps, action_size,
if skipped > 0:
skipped_text = "\t⚡ Skipped {}".format(skipped)
hit_text = ""
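# Cache hits are reported as a share of all agent-steps (n_agents * final_step).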
if nb_hit > 0:
hit_text = "\t⚡ Hit {} ({:.1f}%)".format(nb_hit, (100*nb_hit)/(n_agents*final_step))
print(
"☑️ Score: {:.3f} \tDone: {:.1f}% \tNb steps: {:.3f} "
"\t🍭 Seed: {}"
"\t🚉 Env: {:.3f}s "
"\t🤖 Agent: {:.3f}s (per step: {:.3f}s) \t[preproc: {:.3f}s \tinfer: {:.3f}s]"
"{}".format(
"{}{}".format(
normalized_score,
completion * 100.0,
final_step,
@@ -239,14 +214,15 @@ def eval_policy(env_params, checkpoint, n_eval_episodes, max_steps, action_size,
agent_timer.get() / final_step,
preproc_timer.get(),
inference_timer.get(),
skipped_text
skipped_text,
hit_text
)
)
return scores, completions, nb_steps, agent_times, step_times
def evaluate_agents(file, n_evaluation_episodes, use_gpu, render, allow_skipping):
def evaluate_agents(file, n_evaluation_episodes, use_gpu, render, allow_skipping, allow_caching):
nb_threads = 1
eval_per_thread = n_evaluation_episodes
@@ -262,6 +238,22 @@ def evaluate_agents(file, n_evaluation_episodes, use_gpu, render, allow_skipping
# Observation parameters need to match the ones used during training!
# small_v0
small_v0_params = {
# sample configuration
"n_agents": 5,
"x_dim": 25,
"y_dim": 25,
"n_cities": 4,
"max_rails_between_cities": 2,
"max_rails_in_city": 3,
# observations
"observation_tree_depth": 2,
"observation_radius": 10,
"observation_max_path_depth": 20
}
# Test_0
test0_params = {
# sample configuration
@@ -273,7 +265,7 @@ def evaluate_agents(file, n_evaluation_episodes, use_gpu, render, allow_skipping
"max_rails_in_city": 3,
# observations
"observation_tree_depth": 1,
"observation_tree_depth": 2,
"observation_radius": 10,
"observation_max_path_depth": 20
}
@@ -289,15 +281,32 @@ def evaluate_agents(file, n_evaluation_episodes, use_gpu, render, allow_skipping
"max_rails_in_city": 3,
# observations
"observation_tree_depth": 1,
"observation_tree_depth": 2,
"observation_radius": 10,
"observation_max_path_depth": 10
}
# Test_5
test5_params = {
# environment
"n_agents": 80,
"x_dim": 35,
"y_dim": 35,
"n_cities": 5,
"max_rails_between_cities": 2,
"max_rails_in_city": 4,
# observations
"observation_tree_depth": 2,
"observation_radius": 10,
"observation_max_path_depth": 20
}
params = test5_params
env_params = Namespace(**test0_params)
print("Environment parameters:")
pprint(test1_params)
pprint(params)
# Calculate space dimensions and max steps
max_steps = int(4 * 2 * (env_params.x_dim + env_params.y_dim + (env_params.n_agents / env_params.n_cities)))
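# Illustration: with the Test_5 parameters above (35x35 grid, 80 agents, 5 cities) this formula
# evaluates to int(4 * 2 * (35 + 35 + 80 / 5)) = 688 steps.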
@@ -310,12 +319,12 @@ def evaluate_agents(file, n_evaluation_episodes, use_gpu, render, allow_skipping
results = []
if render:
results.append(eval_policy(test1_params, file, eval_per_thread, max_steps, action_size, state_size, 0, render, allow_skipping))
results.append(eval_policy(params, file, eval_per_thread, max_steps, action_size, state_size, 0, render, allow_skipping, allow_caching))
else:
with Pool() as p:
results = p.starmap(eval_policy,
[(test1_params, file, 1, max_steps, action_size, state_size, seed * nb_threads, render, allow_skipping)
[(params, file, 1, max_steps, action_size, state_size, seed * nb_threads, render, allow_skipping, allow_caching)
for seed in
range(total_nb_eval)])
@@ -355,8 +364,9 @@ if __name__ == "__main__":
parser.add_argument("--use_gpu", dest="use_gpu", help="use GPU if available", action='store_true')
parser.add_argument("--render", help="render a single episode", action='store_true')
parser.add_argument("--allow_skipping", help="skips to the end of the episode if all agents are deadlocked", action='store_true')
parser.add_argument("--allow_caching", help="caches the last observation-action pair", action='store_true')
args = parser.parse_args()
os.environ["OMP_NUM_THREADS"] = str(1)
evaluate_agents(file=args.file, n_evaluation_episodes=args.n_evaluation_episodes, use_gpu=args.use_gpu, render=args.render,
allow_skipping=args.allow_skipping)
allow_skipping=args.allow_skipping, allow_caching=args.allow_caching)
@@ -367,8 +367,6 @@ def format_action_prob(action_probs):
def eval_policy(env, policy, train_params, obs_params):
print("eval in {}x{}".format(env.width, env.height))
n_eval_episodes = train_params.n_evaluation_episodes
max_steps = env._max_episode_steps
tree_depth = obs_params.observation_tree_depth
@@ -26,13 +26,18 @@ from utils.observation_utils import normalize_observation
checkpoint = "checkpoints/sample-checkpoint.pth"
# Beyond this number of agents, skip the episode
max_num_agents_handled = 80
max_num_agents_handled = 79
# Observation parameters
# These need to match your training parameters!
observation_tree_depth = 1
observation_tree_depth = 2
observation_radius = 10
observation_max_path_depth = 20
observation_max_path_depth = 30
# Use action cache:
# Saves the last observation-action pair for each agent
# Only works for deterministic agents!
use_action_cache = True
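# The cache is safe here because policy.act(..., eps=0.0) acts greedily, so an unchanged
# observation should always map to the same action.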
####################################################
remote_client = FlatlandRemoteClient()
@@ -42,28 +47,14 @@ predictor = ShortestPathPredictorForRailEnv(observation_max_path_depth)
tree_observation = TreeObsForRailEnv(max_depth=observation_tree_depth, predictor=predictor)
# Calculates state and action sizes
num_features_per_node = tree_observation.observation_dim
n_nodes = sum([np.power(4, i) for i in range(observation_tree_depth + 1)])
state_size = num_features_per_node * n_nodes
state_size = tree_observation.observation_dim * n_nodes
action_size = 5
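# Illustrative sanity check (assuming 11 features per node, the default for TreeObsForRailEnv
# in recent flatland releases): n_nodes = 4**0 + 4**1 + 4**2 = 21, so state_size = 11 * 21 = 231.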
# Creates the policy. No GPU on evaluation server.
policy = DDDQNPolicy(state_size, action_size, Namespace(**{'use_gpu': False}), evaluation_mode=True)
policy.qnetwork_local = torch.load(checkpoint)
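# Note: the checkpoint is loaded as a full serialized module; if it was saved from a GPU run,
# torch.load(checkpoint, map_location='cpu') may be needed on the CPU-only evaluation server.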
# Controller that uses the RL policy
def rl_controller(obs, number_of_agents):
action_dict = {}
for agent in range(number_of_agents):
if obs[agent] and info['action_required'][agent]:
norm_obs = normalize_observation(obs[agent], tree_depth=observation_tree_depth, observation_radius=observation_radius)
action = policy.act(norm_obs, eps=0.0)
action_dict[agent] = action
return action_dict
#####################################################################
# Main evaluation loop
#####################################################################
@@ -105,21 +96,40 @@ while True:
time_taken_per_step = []
steps = 0
agent_last_obs = {}
agent_last_action = {}
nb_hit = 0
while True:
#####################################################################
# Evaluation of a single episode
#####################################################################
steps += 1
obs_time, agent_time, step_time = 0.0, 0.0, 0.0
no_ops_mode = False
if nb_agents <= max_num_agents_handled:# and not check_if_all_blocked(env=local_env):
if nb_agents <= max_num_agents_handled and not check_if_all_blocked(env=local_env):
time_start = time.time()
action = rl_controller(observation, nb_agents)
action_dict = {}
for agent in range(nb_agents):
if observation[agent] and info['action_required'][agent]:
if agent in agent_last_obs and np.all(agent_last_obs[agent] == observation[agent]):
action = agent_last_action[agent]
nb_hit += 1
else:
norm_obs = normalize_observation(observation[agent], tree_depth=observation_tree_depth, observation_radius=observation_radius)
action = policy.act(norm_obs, eps=0.0)
action_dict[agent] = action
if use_action_cache:
agent_last_obs[agent] = observation[agent]
agent_last_action[agent] = action
agent_time = time.time() - time_start
time_taken_by_controller.append(agent_time)
time_start = time.time()
_, all_rewards, done, info = remote_client.env_step(action)
_, all_rewards, done, info = remote_client.env_step(action_dict)
step_time = time.time() - time_start
time_taken_per_step.append(step_time)
@@ -129,6 +139,8 @@ while True:
else:
# Too many agents or fully deadlocked: no-op to finish the episode ASAP 🏃💨
no_ops_mode = True
time_start = time.time()
_, all_rewards, done, info = remote_client.env_step({})
step_time = time.time() - time_start
@@ -137,13 +149,15 @@ while True:
nb_agents_done = sum(done[idx] for idx in local_env.get_agent_handles())
if not done['__all__']:
print("Step {}/{}\tAgents done: {}\t Obs time {:.3f}s\t Inference time {:.3f}s\t Step time {:.3f}s".format(
print("Step {}/{}\tAgents done: {}\t Obs time {:.3f}s\t Inference time {:.5f}s\t Step time {:.3f}s\t Cache hits {}\t No-ops? {}".format(
str(steps).zfill(4),
max_nb_steps,
nb_agents_done,
obs_time,
agent_time,
step_time
step_time,
nb_hit,
no_ops_mode
), end="\r")
else:
print()