Commit 4c1427fe authored by Egli Adrian (IT-SCI-API-PFI)

FastTreeObs working

parent d6103087
reinforcement_learning/dddqn_policy.py:

@@ -22,7 +22,7 @@ class DDDQNPolicy(Policy):
         self.state_size = state_size
         self.action_size = action_size
         self.double_dqn = True
-        self.hidsize = 1
+        self.hidsize = 128

         if not evaluation_mode:
             self.hidsize = parameters.hidden_size
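Why this default matters: in evaluation mode the constructor never reads parameters.hidden_size, so the old default of 1 built a network whose layer shapes could not match a checkpoint trained with a hidden size of 128. A minimal sketch of the failure, using a bare nn.Linear rather than the repo's DuelingQNetwork, with hypothetical sizes:

import torch
import torch.nn as nn

# Hypothetical shapes, for illustration only.
trained = nn.Linear(32, 128)                  # trained with hidden size 128
torch.save(trained.state_dict(), "demo.pth")

too_small = nn.Linear(32, 1)                  # old evaluation default: hidsize = 1
try:
    too_small.load_state_dict(torch.load("demo.pth"))
except RuntimeError as err:
    print("size mismatch, as expected:", err)

matching = nn.Linear(32, 128)                 # new default: hidsize = 128
matching.load_state_dict(torch.load("demo.pth"))  # loads cleanly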
@@ -34,7 +34,7 @@ class DDDQNPolicy(Policy):
         self.gamma = parameters.gamma
         self.buffer_min_size = parameters.buffer_min_size

         # Device
         if parameters.use_gpu and torch.cuda.is_available():
             self.device = torch.device("cuda:0")
             # print("🐇 Using GPU")
@@ -43,7 +43,8 @@ class DDDQNPolicy(Policy):
             # print("🐢 Using CPU")

         # Q-Network
-        self.qnetwork_local = DuelingQNetwork(state_size, action_size, hidsize1=self.hidsize, hidsize2=self.hidsize).to(self.device)
+        self.qnetwork_local = DuelingQNetwork(state_size, action_size, hidsize1=self.hidsize, hidsize2=self.hidsize).to(
+            self.device)

         if not evaluation_mode:
             self.qnetwork_target = copy.deepcopy(self.qnetwork_local)
@@ -119,15 +120,22 @@ class DDDQNPolicy(Policy):
         torch.save(self.qnetwork_target.state_dict(), filename + ".target")

     def load(self, filename):
-        if os.path.exists(filename + ".local") and os.path.exists(filename + ".target"):
-            self.qnetwork_local.load_state_dict(torch.load(filename + ".local"))
-            self.qnetwork_target.load_state_dict(torch.load(filename + ".target"))
-        else:
-            if os.path.exists(filename):
-                self.qnetwork_local.load_state_dict(torch.load(filename))
-                self.qnetwork_target.load_state_dict(torch.load(filename))
-            else:
-                raise FileNotFoundError("Couldn't load policy from: '{}', '{}'".format(filename + ".local", filename + ".target"))
+        try:
+            if os.path.exists(filename + ".local") and os.path.exists(filename + ".target"):
+                self.qnetwork_local.load_state_dict(torch.load(filename + ".local"))
+                print("qnetwork_local loaded ('{}')".format(filename + ".local"))
+                if self.evaluation_mode:
+                    self.qnetwork_target = copy.deepcopy(self.qnetwork_local)
+                else:
+                    self.qnetwork_target.load_state_dict(torch.load(filename + ".target"))
+                    print("qnetwork_target loaded ('{}')".format(filename + ".target"))
+            else:
+                print(">> Checkpoint not found, using untrained policy! ('{}', '{}')".format(filename + ".local",
+                                                                                             filename + ".target"))
+        except Exception as exc:
+            print(exc)
+            print("Couldn't load policy from ('{}', '{}'), using untrained policy!".format(filename + ".local",
+                                                                                           filename + ".target"))

     def save_replay_buffer(self, filename):
         memory = self.memory.memory
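Taken together with the run-script change below, the effect is that a missing or corrupt checkpoint now degrades to an untrained policy instead of crashing the submission. A minimal usage sketch, with a hypothetical state size and checkpoint path:

from argparse import Namespace

from reinforcement_learning.dddqn_policy import DDDQNPolicy

# Hypothetical values, for illustration only.
policy = DDDQNPolicy(32, 5, Namespace(**{'use_gpu': False}), evaluation_mode=True)
policy.load("checkpoints/does-not-exist.pth")  # warns and keeps the untrained weights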
The remaining hunks are in the evaluation client script that talks to FlatlandRemoteClient:
import os
import sys
import time
from argparse import Namespace
from pathlib import Path

import numpy as np
import torch
from flatland.core.env_observation_builder import DummyObservationBuilder
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.evaluators.client import FlatlandRemoteClient
from flatland.evaluators.client import TimeoutException

from utils.deadlock_check import check_if_all_blocked
from utils.fast_tree_obs import FastTreeObs

base_dir = Path(__file__).resolve().parent.parent
sys.path.append(str(base_dir))

from reinforcement_learning.dddqn_policy import DDDQNPolicy
from utils.observation_utils import normalize_observation
####################################################
# EVALUATION PARAMETERS
@@ -28,7 +24,7 @@ from utils.observation_utils import normalize_observation
 VERBOSE = True

 # Checkpoint to use (remember to push it!)
-checkpoint = "checkpoints/201103150429-2500.pth"
+checkpoint = "./checkpoints/201103160541-1800.pth"

 # Use last action cache
 USE_ACTION_CACHE = True
@@ -44,20 +40,15 @@ remote_client = FlatlandRemoteClient()
 # Observation builder
-predictor = ShortestPathPredictorForRailEnv(observation_max_path_depth)
-tree_observation = TreeObsForRailEnv(max_depth=observation_tree_depth, predictor=predictor)
+tree_observation = FastTreeObs(max_depth=observation_tree_depth)

 # Calculates state and action sizes
-n_nodes = sum([np.power(4, i) for i in range(observation_tree_depth + 1)])
-state_size = tree_observation.observation_dim * n_nodes
+state_size = tree_observation.observation_dim
 action_size = 5
 # Creates the policy. No GPU on evaluation server.
 policy = DDDQNPolicy(state_size, action_size, Namespace(**{'use_gpu': False}), evaluation_mode=True)
-if os.path.isfile(checkpoint):
-    policy.load(checkpoint)
-else:
-    print("Checkpoint not found, using untrained policy! (path: {})".format(checkpoint))
+policy.load(checkpoint)
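For context on the state_size change: the classic tree observation is flattened over every node of a depth-limited 4-ary tree, while FastTreeObs emits one fixed-length feature vector per agent. A worked comparison, assuming the starter kit's usual observation_tree_depth = 2 and the classic observation's 11 features per node (both values are assumptions here, since neither appears in this excerpt):

# Assumed values: tree depth 2, 11 features per tree node.
observation_tree_depth = 2
features_per_node = 11

n_nodes = sum(4 ** i for i in range(observation_tree_depth + 1))  # 1 + 4 + 16 = 21
classic_state_size = features_per_node * n_nodes                  # 231 network inputs
print(n_nodes, classic_state_size)                                # 21 231

# FastTreeObs skips the per-node expansion entirely, so the network input
# size collapses to tree_observation.observation_dim, a much smaller fixed count.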
#####################################################################
# Main evaluation loop
@@ -124,15 +115,13 @@ while True:
     time_start = time.time()
     action_dict = {}
     for agent in range(nb_agents):
-        if observation[agent] and info['action_required'][agent]:
+        if info['action_required'][agent]:
             if agent in agent_last_obs and np.all(agent_last_obs[agent] == observation[agent]):
                 # cache hit
                 action = agent_last_action[agent]
                 nb_hit += 1
             else:
-                # otherwise, run normalization and inference
-                norm_obs = normalize_observation(observation[agent], tree_depth=observation_tree_depth, observation_radius=observation_radius)
-                action = policy.act(norm_obs, eps=0.0)
+                action = policy.act(observation[agent], eps=0.0)

             action_dict[agent] = action
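Two things happen in this hunk. The dropped `observation[agent]` guard is likely necessary because FastTreeObs returns a plain numpy array, and truth-testing a multi-element array raises rather than acting as a presence check; and with the normalization step gone, the raw observation feeds the network directly. A small sketch of both points, with illustrative values:

import numpy as np

obs = np.zeros(5)  # stand-in for a FastTreeObs feature vector
try:
    if obs:  # what the old `if observation[agent] and ...` guard would do
        pass
except ValueError as err:
    print(err)  # truth value of a multi-element array is ambiguous

# The action cache is a plain dict comparison: identical consecutive
# observations reuse the last action and skip the forward pass.
agent_last_obs = {0: obs.copy()}
cache_hit = 0 in agent_last_obs and np.all(agent_last_obs[0] == obs)
print(cache_hit)  # True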
@@ -163,16 +152,17 @@ while True:
     nb_agents_done = sum(done[idx] for idx in local_env.get_agent_handles())

     if VERBOSE or done['__all__']:
-        print("Step {}/{}\tAgents done: {}\t Obs time {:.3f}s\t Inference time {:.5f}s\t Step time {:.3f}s\t Cache hits {}\t No-ops? {}".format(
-            str(steps).zfill(4),
-            max_nb_steps,
-            nb_agents_done,
-            obs_time,
-            agent_time,
-            step_time,
-            nb_hit,
-            no_ops_mode
-        ), end="\r")
+        print(
+            "Step {}/{}\tAgents done: {}\t Obs time {:.3f}s\t Inference time {:.5f}s\t Step time {:.3f}s\t Cache hits {}\t No-ops? {}".format(
+                str(steps).zfill(4),
+                max_nb_steps,
+                nb_agents_done,
+                obs_time,
+                agent_time,
+                step_time,
+                nb_hit,
+                no_ops_mode
+            ), end="\r")
     if done['__all__']:
         # When done['__all__'] == True, then the evaluation of this

@@ -190,7 +180,8 @@ while True:
     np_time_taken_by_controller = np.array(time_taken_by_controller)
     np_time_taken_per_step = np.array(time_taken_per_step)
-    print("Mean/Std of Time taken by Controller : ", np_time_taken_by_controller.mean(), np_time_taken_by_controller.std())
+    print("Mean/Std of Time taken by Controller : ", np_time_taken_by_controller.mean(),
+          np_time_taken_by_controller.std())
     print("Mean/Std of Time per Step : ", np_time_taken_per_step.mean(), np_time_taken_per_step.std())
     print("=" * 100)