diff --git a/checkpoints/201014015722-1500.pth b/checkpoints/201014015722-1500.pth
deleted file mode 100644
index cbef409414195189caa8e536c185c7059e42cefe..0000000000000000000000000000000000000000
Binary files a/checkpoints/201014015722-1500.pth and /dev/null differ
diff --git a/reinforcement_learning/dddqn_policy.py b/reinforcement_learning/dddqn_policy.py
index c1177b91bc621e517a8706729370f502566b80e6..6e54f3d815caad71e18d4533cd53504fc4d31602 100644
--- a/reinforcement_learning/dddqn_policy.py
+++ b/reinforcement_learning/dddqn_policy.py
@@ -22,7 +22,7 @@ class DDDQNPolicy(Policy):
         self.state_size = state_size
         self.action_size = action_size
         self.double_dqn = True
-        self.hidsize = 1
+        self.hidsize = 128
 
         if not evaluation_mode:
             self.hidsize = parameters.hidden_size
@@ -34,7 +34,7 @@ class DDDQNPolicy(Policy):
             self.gamma = parameters.gamma
             self.buffer_min_size = parameters.buffer_min_size
 
-        # Device
+        # Device
         if parameters.use_gpu and torch.cuda.is_available():
             self.device = torch.device("cuda:0")
             # print("🐇 Using GPU")
@@ -43,7 +43,8 @@
             # print("🐢 Using CPU")
 
         # Q-Network
-        self.qnetwork_local = DuelingQNetwork(state_size, action_size, hidsize1=self.hidsize, hidsize2=self.hidsize).to(self.device)
+        self.qnetwork_local = DuelingQNetwork(state_size, action_size, hidsize1=self.hidsize, hidsize2=self.hidsize).to(
+            self.device)
 
         if not evaluation_mode:
             self.qnetwork_target = copy.deepcopy(self.qnetwork_local)
@@ -119,15 +120,22 @@
         torch.save(self.qnetwork_target.state_dict(), filename + ".target")
 
     def load(self, filename):
-        if os.path.exists(filename + ".local") and os.path.exists(filename + ".target"):
-            self.qnetwork_local.load_state_dict(torch.load(filename + ".local"))
-            self.qnetwork_target.load_state_dict(torch.load(filename + ".target"))
-        else:
-            if os.path.exists(filename):
-                self.qnetwork_local.load_state_dict(torch.load(filename))
-                self.qnetwork_target.load_state_dict(torch.load(filename))
+        try:
+            if os.path.exists(filename + ".local") and os.path.exists(filename + ".target"):
+                self.qnetwork_local.load_state_dict(torch.load(filename + ".local"))
+                print("qnetwork_local loaded ('{}')".format(filename + ".local"))
+                if self.evaluation_mode:
+                    self.qnetwork_target = copy.deepcopy(self.qnetwork_local)
+                else:
+                    self.qnetwork_target.load_state_dict(torch.load(filename + ".target"))
+                    print("qnetwork_target loaded ('{}')".format(filename + ".target"))
             else:
-                raise FileNotFoundError("Couldn't load policy from: '{}', '{}'".format(filename + ".local", filename + ".target"))
+                print(">> Checkpoint not found, using untrained policy! ('{}', '{}')".format(filename + ".local",
+                                                                                             filename + ".target"))
+        except Exception as exc:
+            print(exc)
+            print("Couldn't load policy, using untrained policy! ('{}', '{}')".format(filename + ".local",
('{}', '{}')".format(filename + ".local", + filename + ".target")) def save_replay_buffer(self, filename): memory = self.memory.memory diff --git a/run.py b/run.py index 918d78cbc0b7c76a2414c0b431e83cda1d1ef043..f11c94168e9ba1856665273e716361d1a2077b50 100644 --- a/run.py +++ b/run.py @@ -1,25 +1,21 @@ -import os import sys +import time from argparse import Namespace from pathlib import Path import numpy as np -import time - -import torch from flatland.core.env_observation_builder import DummyObservationBuilder -from flatland.envs.observations import TreeObsForRailEnv -from flatland.evaluators.client import FlatlandRemoteClient from flatland.envs.predictions import ShortestPathPredictorForRailEnv +from flatland.evaluators.client import FlatlandRemoteClient from flatland.evaluators.client import TimeoutException from utils.deadlock_check import check_if_all_blocked +from utils.fast_tree_obs import FastTreeObs base_dir = Path(__file__).resolve().parent.parent sys.path.append(str(base_dir)) from reinforcement_learning.dddqn_policy import DDDQNPolicy -from utils.observation_utils import normalize_observation #################################################### # EVALUATION PARAMETERS @@ -28,7 +24,7 @@ from utils.observation_utils import normalize_observation VERBOSE = True # Checkpoint to use (remember to push it!) -checkpoint = "checkpoints/201103150429-2500.pth" +checkpoint = "./checkpoints/201103160541-1800.pth" # Use last action cache USE_ACTION_CACHE = True @@ -44,20 +40,15 @@ remote_client = FlatlandRemoteClient() # Observation builder predictor = ShortestPathPredictorForRailEnv(observation_max_path_depth) -tree_observation = TreeObsForRailEnv(max_depth=observation_tree_depth, predictor=predictor) +tree_observation = FastTreeObs(max_depth=observation_tree_depth) # Calculates state and action sizes -n_nodes = sum([np.power(4, i) for i in range(observation_tree_depth + 1)]) -state_size = tree_observation.observation_dim * n_nodes +state_size = tree_observation.observation_dim action_size = 5 # Creates the policy. No GPU on evaluation server. policy = DDDQNPolicy(state_size, action_size, Namespace(**{'use_gpu': False}), evaluation_mode=True) - -if os.path.isfile(checkpoint): - policy.load(checkpoint) -else: - print("Checkpoint not found, using untrained policy! (path: {})".format(checkpoint)) +policy.load(checkpoint) ##################################################################### # Main evaluation loop @@ -124,15 +115,13 @@ while True: time_start = time.time() action_dict = {} for agent in range(nb_agents): - if observation[agent] and info['action_required'][agent]: + if info['action_required'][agent]: if agent in agent_last_obs and np.all(agent_last_obs[agent] == observation[agent]): # cache hit action = agent_last_action[agent] nb_hit += 1 else: - # otherwise, run normalization and inference - norm_obs = normalize_observation(observation[agent], tree_depth=observation_tree_depth, observation_radius=observation_radius) - action = policy.act(norm_obs, eps=0.0) + action = policy.act(observation[agent], eps=0.0) action_dict[agent] = action @@ -163,16 +152,17 @@ while True: nb_agents_done = sum(done[idx] for idx in local_env.get_agent_handles()) if VERBOSE or done['__all__']: - print("Step {}/{}\tAgents done: {}\t Obs time {:.3f}s\t Inference time {:.5f}s\t Step time {:.3f}s\t Cache hits {}\t No-ops? 
{}".format( - str(steps).zfill(4), - max_nb_steps, - nb_agents_done, - obs_time, - agent_time, - step_time, - nb_hit, - no_ops_mode - ), end="\r") + print( + "Step {}/{}\tAgents done: {}\t Obs time {:.3f}s\t Inference time {:.5f}s\t Step time {:.3f}s\t Cache hits {}\t No-ops? {}".format( + str(steps).zfill(4), + max_nb_steps, + nb_agents_done, + obs_time, + agent_time, + step_time, + nb_hit, + no_ops_mode + ), end="\r") if done['__all__']: # When done['__all__'] == True, then the evaluation of this @@ -190,7 +180,8 @@ while True: np_time_taken_by_controller = np.array(time_taken_by_controller) np_time_taken_per_step = np.array(time_taken_per_step) - print("Mean/Std of Time taken by Controller : ", np_time_taken_by_controller.mean(), np_time_taken_by_controller.std()) + print("Mean/Std of Time taken by Controller : ", np_time_taken_by_controller.mean(), + np_time_taken_by_controller.std()) print("Mean/Std of Time per Step : ", np_time_taken_per_step.mean(), np_time_taken_per_step.std()) print("=" * 100)