diff --git a/reinforcement_learning/dddqn_policy.py b/reinforcement_learning/dddqn_policy.py
index 6e54f3d815caad71e18d4533cd53504fc4d31602..2cf7ad25e4f581c86cd53bc49671669b8820ab8f 100644
--- a/reinforcement_learning/dddqn_policy.py
+++ b/reinforcement_learning/dddqn_policy.py
@@ -124,9 +124,7 @@ class DDDQNPolicy(Policy):
         if os.path.exists(filename + ".local") and os.path.exists(filename + ".target"):
             self.qnetwork_local.load_state_dict(torch.load(filename + ".local"))
             print("qnetwork_local loaded ('{}')".format(filename + ".local"))
-            if self.evaluation_mode:
-                self.qnetwork_target = copy.deepcopy(self.qnetwork_local)
-            else:
+            if not self.evaluation_mode:
                 self.qnetwork_target.load_state_dict(torch.load(filename + ".target"))
                 print("qnetwork_target loaded ('{}' )".format(filename + ".target"))
         else:
diff --git a/run.py b/run.py
index fec8bc01cff56b1733a016666eb703ae560353c0..637bc7953f2dc9527f23c79809d0eab13c5e5268 100644
--- a/run.py
+++ b/run.py
@@ -25,7 +25,7 @@ from reinforcement_learning.dddqn_policy import DDDQNPolicy
 VERBOSE = True
 
 # Checkpoint to use (remember to push it!)
-checkpoint = "./checkpoints/201103160541-1800.pth"
+checkpoint = "./checkpoints/201103172118-0.pth"
 
 # Use last action cache
 USE_ACTION_CACHE = True
diff --git a/utils/fast_tree_obs.py b/utils/fast_tree_obs.py
index 8fe6caf2535f448de561ae270ef6022d1795f12a..b2a4bf72353fc87d8dcde2f430f1277d27c422bb 100755
--- a/utils/fast_tree_obs.py
+++ b/utils/fast_tree_obs.py
@@ -2,8 +2,9 @@ import numpy as np
 from flatland.core.env_observation_builder import ObservationBuilder
 from flatland.core.grid.grid4_utils import get_new_position
 from flatland.envs.agent_utils import RailAgentStatus
-from flatland.envs.rail_env import fast_count_nonzero, fast_argmax
+from flatland.envs.rail_env import fast_count_nonzero, fast_argmax, RailEnvActions
 
+from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent
 
 """
 LICENCE for the FastTreeObs Observation Builder
@@ -17,11 +18,12 @@ Author: Adrian Egli (adrian.egli@gmail.com)
 [Researchgate](https://www.linkedin.com/in/adrian-egli-733a9544/)
 """
 
+
 class FastTreeObs(ObservationBuilder):
 
     def __init__(self, max_depth):
         self.max_depth = max_depth
-        self.observation_dim = 27
+        self.observation_dim = 30
 
     def build_data(self):
         if self.env is not None:
@@ -32,6 +34,9 @@ class FastTreeObs(ObservationBuilder):
         self.debug_render_path_list = []
         if self.env is not None:
             self.find_all_cell_where_agent_can_choose()
+            self.dead_lock_avoidance_agent = DeadLockAvoidanceAgent(self.env)
+        else:
+            self.dead_lock_avoidance_agent = None
 
     def find_all_cell_where_agent_can_choose(self):
         switches = {}
@@ -238,6 +243,10 @@ class FastTreeObs(ObservationBuilder):
         # observation[23] : If there is a switch on the path which agent can not use -> 1
         # observation[24] : If there is a switch on the path which agent can not use -> 1
         # observation[25] : If there is a switch on the path which agent can not use -> 1
+        # observation[26] : If the dead-lock avoidance agent predicts a deadlock -> 1
+
+        if handle == 0:
+            self.dead_lock_avoidance_agent.start_step()
 
         observation = np.zeros(self.observation_dim)
         visited = []
@@ -292,6 +301,12 @@ class FastTreeObs(ObservationBuilder):
         observation[8] = int(agents_near_to_switch)
         observation[9] = int(agents_near_to_switch_all)
 
+        action = self.dead_lock_avoidance_agent.act([handle],0.0)
+        observation[26] = int(action == RailEnvActions.STOP_MOVING)
+        observation[27] = int(action == RailEnvActions.MOVE_LEFT)
+        observation[28] = int(action == RailEnvActions.MOVE_FORWARD)
+        observation[29] = int(action == RailEnvActions.MOVE_RIGHT)
+
         self.env.dev_obs_dict.update({handle: visited})
 
-        return observation
\ No newline at end of file
+        return observation