Commit 1397a82c authored by MasterScrat

New checkpoint

parent da79652a
import os
import sys
from argparse import Namespace
from pathlib import Path
@@ -10,6 +11,7 @@ from flatland.core.env_observation_builder import DummyObservationBuilder
from flatland.envs.observations import TreeObsForRailEnv
from flatland.evaluators.client import FlatlandRemoteClient
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.evaluators.client import TimeoutException
from utils.deadlock_check import check_if_all_blocked
@@ -19,23 +21,14 @@ sys.path.append(str(base_dir))
from reinforcement_learning.dddqn_policy import DDDQNPolicy
from utils.observation_utils import normalize_observation
# TODO:
# - add timeout handling
# - keep only code related to training and running the DDQN agent
# Writeup how to improve:
# - use FastTreeObsForRailEnv instead of TreeObsForRailEnv
# - add PER (using custom code, or using cpprb)
# - other improvements from the Rainbow paper, e.g. n-step returns (https://arxiv.org/abs/1710.02298), as sketched below
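# --- Illustrative sketch (not part of the original script) ---
# One Rainbow-style improvement listed above is n-step bootstrapping: accumulate
# n rewards before bootstrapping from the target network instead of using the
# plain 1-step target. The helper below is a minimal, hypothetical example; the
# name, signature and default gamma are assumptions, not starter-kit code.
def n_step_return(rewards, bootstrap_value, gamma=0.99):
    """Return rewards[0] + gamma*rewards[1] + ... + gamma**n * bootstrap_value."""
    g = bootstrap_value
    for r in reversed(rewards):
        g = r + gamma * g
    return g
# Example: n_step_return([1.0, 0.0, -1.0], bootstrap_value=2.0, gamma=0.9) -> 1.648
# --------------------------------------------------------------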
####################################################
# EVALUATION PARAMETERS
# Print detailed logs (disable when submitting)
# Print per-step logs
VERBOSE = True
# Checkpoint to use (remember to push it!)
checkpoint = "checkpoints/sample-checkpoint.pth"
checkpoint = "checkpoints/201014015722-1500.pth"
# Use last action cache
USE_ACTION_CACHE = True
@@ -60,7 +53,11 @@ action_size = 5
# Creates the policy. No GPU on evaluation server.
policy = DDDQNPolicy(state_size, action_size, Namespace(**{'use_gpu': False}), evaluation_mode=True)
policy.qnetwork_local = torch.load(checkpoint)
if os.path.isfile(checkpoint):
policy.qnetwork_local = torch.load(checkpoint)
else:
print("Checkpoint not found, using untrained policy! (path: {})".format(checkpoint))
#####################################################################
# Main evaluation loop
@@ -71,7 +68,7 @@ while True:
evaluation_number += 1
# We use a dummy observation and call TreeObsForRailEnv ourselves when needed.
# This way we decide if we want to calculate the observations or not instead
# of having them calculated every time we perform an env step.
time_start = time.time()
observation, info = remote_client.env_create(
@@ -86,6 +83,9 @@ while True:
# and hence it's safe to break out of the main evaluation loop.
break
print("Env Path : ", remote_client.current_env_path)
print("Env Creation Time : ", env_creation_time)
local_env = remote_client.env
nb_agents = len(local_env.agents)
max_nb_steps = local_env._max_episode_steps
@@ -99,9 +99,11 @@ while True:
# Now we enter into another infinite loop where we
# compute the actions for all the individual steps in this episode
# until the episode is `done`
steps = 0
# Bookkeeping
time_taken_by_controller = []
time_taken_per_step = []
steps = 0
# Action cache: keep track of last observation to avoid running the same inference multiple times.
# This only makes sense for deterministic policies.
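# (Based on its usage below, agent_last_obs maps each agent handle to its most
#  recent observation, and agent_last_action to the action inferred for it.)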
@@ -110,83 +112,84 @@ while True:
nb_hit = 0
while True:
#####################################################################
# Evaluation of a single episode
#####################################################################
steps += 1
obs_time, agent_time, step_time = 0.0, 0.0, 0.0
no_ops_mode = False
if not check_if_all_blocked(env=local_env):
time_start = time.time()
action_dict = {}
for agent in range(nb_agents):
if observation[agent] and info['action_required'][agent]:
if agent in agent_last_obs and np.all(agent_last_obs[agent] == observation[agent]):
# cache hit
action = agent_last_action[agent]
nb_hit += 1
else:
# otherwise, run normalization and inference
norm_obs = normalize_observation(observation[agent], tree_depth=observation_tree_depth, observation_radius=observation_radius)
action = policy.act(norm_obs, eps=0.0)
action_dict[agent] = action
if USE_ACTION_CACHE:
agent_last_obs[agent] = observation[agent]
agent_last_action[agent] = action
agent_time = time.time() - time_start
time_taken_by_controller.append(agent_time)
time_start = time.time()
_, all_rewards, done, info = remote_client.env_step(action_dict)
step_time = time.time() - time_start
time_taken_per_step.append(step_time)
time_start = time.time()
observation = tree_observation.get_many(list(range(nb_agents)))
obs_time = time.time() - time_start
else:
# Fully deadlocked: perform no-ops to finish the episode ASAP 🏃💨
no_ops_mode = True
time_start = time.time()
_, all_rewards, done, info = remote_client.env_step({})
step_time = time.time() - time_start
time_taken_per_step.append(step_time)
nb_agents_done = sum(done[idx] for idx in local_env.get_agent_handles())
if not done['__all__'] and VERBOSE:
print("Step {}/{}\tAgents done: {}\t Obs time {:.3f}s\t Inference time {:.5f}s\t Step time {:.3f}s\t Cache hits {}\t No-ops? {}".format(
str(steps).zfill(4),
max_nb_steps,
nb_agents_done,
obs_time,
agent_time,
step_time,
nb_hit,
no_ops_mode
), end="\r")
else:
print()
print("Reward : ", sum(list(all_rewards.values())))
#
# When done['__all__'] == True, then the evaluation of this
# particular Env instantiation is complete, and we can break out
# of this loop, and move on to the next Env evaluation
try:
#####################################################################
# Evaluation of a single episode
#####################################################################
steps += 1
obs_time, agent_time, step_time = 0.0, 0.0, 0.0
no_ops_mode = False
if not check_if_all_blocked(env=local_env):
time_start = time.time()
action_dict = {}
for agent in range(nb_agents):
if observation[agent] and info['action_required'][agent]:
if agent in agent_last_obs and np.all(agent_last_obs[agent] == observation[agent]):
# cache hit
action = agent_last_action[agent]
nb_hit += 1
else:
# otherwise, run normalization and inference
norm_obs = normalize_observation(observation[agent], tree_depth=observation_tree_depth, observation_radius=observation_radius)
action = policy.act(norm_obs, eps=0.0)
action_dict[agent] = action
if USE_ACTION_CACHE:
agent_last_obs[agent] = observation[agent]
agent_last_action[agent] = action
agent_time = time.time() - time_start
time_taken_by_controller.append(agent_time)
time_start = time.time()
_, all_rewards, done, info = remote_client.env_step(action_dict)
step_time = time.time() - time_start
time_taken_per_step.append(step_time)
time_start = time.time()
observation = tree_observation.get_many(list(range(nb_agents)))
obs_time = time.time() - time_start
else:
# Fully deadlocked: perform no-ops
no_ops_mode = True
time_start = time.time()
_, all_rewards, done, info = remote_client.env_step({})
step_time = time.time() - time_start
time_taken_per_step.append(step_time)
nb_agents_done = sum(done[idx] for idx in local_env.get_agent_handles())
if VERBOSE or done['__all__']:
print("Step {}/{}\tAgents done: {}\t Obs time {:.3f}s\t Inference time {:.5f}s\t Step time {:.3f}s\t Cache hits {}\t No-ops? {}".format(
str(steps).zfill(4),
max_nb_steps,
nb_agents_done,
obs_time,
agent_time,
step_time,
nb_hit,
no_ops_mode
), end="\r")
if done['__all__']:
# When done['__all__'] == True, then the evaluation of this
# particular Env instantiation is complete, and we can break out
# of this loop, and move on to the next Env evaluation
print()
break
except TimeoutException as err:
# A timeout occurred: we won't get any reward for this episode :-(
# Skip to the next episode, as further actions in this one will be ignored.
# The whole evaluation will be stopped if there are 10 consecutive timeouts.
print("Timeout! Will skip this episode and go to the next.", err)
break
np_time_taken_by_controller = np.array(time_taken_by_controller)
np_time_taken_per_step = np.array(time_taken_per_step)
print("=" * 100)
print("=" * 100)
print("Evaluation Number : ", evaluation_number)
print("Current Env Path : ", remote_client.current_env_path)
print("Env Creation Time : ", env_creation_time)
print("Number of Steps : ", steps)
print("Mean/Std of Time taken by Controller : ", np_time_taken_by_controller.mean(), np_time_taken_by_controller.std())
print("Mean/Std of Time per Step : ", np_time_taken_per_step.mean(), np_time_taken_per_step.std())
print("=" * 100)