Skip to content
Snippets Groups Projects
Commit 2ebc782d authored by Egli Adrian (IT-SCI-API-PFI)
Browse files

DeadLockAvoidanceAgent -> observation

parent 78d1c9ff
No related branches found
No related tags found
No related merge requests found
...@@ -124,9 +124,7 @@ class DDDQNPolicy(Policy): ...@@ -124,9 +124,7 @@ class DDDQNPolicy(Policy):
if os.path.exists(filename + ".local") and os.path.exists(filename + ".target"): if os.path.exists(filename + ".local") and os.path.exists(filename + ".target"):
self.qnetwork_local.load_state_dict(torch.load(filename + ".local")) self.qnetwork_local.load_state_dict(torch.load(filename + ".local"))
print("qnetwork_local loaded ('{}')".format(filename + ".local")) print("qnetwork_local loaded ('{}')".format(filename + ".local"))
if self.evaluation_mode: if not self.evaluation_mode:
self.qnetwork_target = copy.deepcopy(self.qnetwork_local)
else:
self.qnetwork_target.load_state_dict(torch.load(filename + ".target")) self.qnetwork_target.load_state_dict(torch.load(filename + ".target"))
print("qnetwork_target loaded ('{}' )".format(filename + ".target")) print("qnetwork_target loaded ('{}' )".format(filename + ".target"))
else: else:
......
...@@ -25,7 +25,7 @@ from reinforcement_learning.dddqn_policy import DDDQNPolicy ...@@ -25,7 +25,7 @@ from reinforcement_learning.dddqn_policy import DDDQNPolicy
VERBOSE = True VERBOSE = True
# Checkpoint to use (remember to push it!) # Checkpoint to use (remember to push it!)
checkpoint = "./checkpoints/201103160541-1800.pth" checkpoint = "./checkpoints/201103172118-0.pth"
# Use last action cache # Use last action cache
USE_ACTION_CACHE = True USE_ACTION_CACHE = True
......
...@@ -2,8 +2,9 @@ import numpy as np ...@@ -2,8 +2,9 @@ import numpy as np
from flatland.core.env_observation_builder import ObservationBuilder from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.grid.grid4_utils import get_new_position from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.agent_utils import RailAgentStatus from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.rail_env import fast_count_nonzero, fast_argmax from flatland.envs.rail_env import fast_count_nonzero, fast_argmax, RailEnvActions
from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent
""" """
LICENCE for the FastTreeObs Observation Builder LICENCE for the FastTreeObs Observation Builder
...@@ -17,11 +18,12 @@ Author: Adrian Egli (adrian.egli@gmail.com) ...@@ -17,11 +18,12 @@ Author: Adrian Egli (adrian.egli@gmail.com)
[Researchgate](https://www.linkedin.com/in/adrian-egli-733a9544/) [Researchgate](https://www.linkedin.com/in/adrian-egli-733a9544/)
""" """
class FastTreeObs(ObservationBuilder): class FastTreeObs(ObservationBuilder):
def __init__(self, max_depth): def __init__(self, max_depth):
self.max_depth = max_depth self.max_depth = max_depth
self.observation_dim = 27 self.observation_dim = 30
def build_data(self): def build_data(self):
if self.env is not None: if self.env is not None:
...@@ -32,6 +34,9 @@ class FastTreeObs(ObservationBuilder): ...@@ -32,6 +34,9 @@ class FastTreeObs(ObservationBuilder):
self.debug_render_path_list = [] self.debug_render_path_list = []
if self.env is not None: if self.env is not None:
self.find_all_cell_where_agent_can_choose() self.find_all_cell_where_agent_can_choose()
self.dead_lock_avoidance_agent = DeadLockAvoidanceAgent(self.env)
else:
self.dead_lock_avoidance_agent = None
def find_all_cell_where_agent_can_choose(self): def find_all_cell_where_agent_can_choose(self):
switches = {} switches = {}
...@@ -238,6 +243,10 @@ class FastTreeObs(ObservationBuilder): ...@@ -238,6 +243,10 @@ class FastTreeObs(ObservationBuilder):
# observation[23] : If there is a switch on the path which agent can not use -> 1 # observation[23] : If there is a switch on the path which agent can not use -> 1
# observation[24] : If there is a switch on the path which agent can not use -> 1 # observation[24] : If there is a switch on the path which agent can not use -> 1
# observation[25] : If there is a switch on the path which agent can not use -> 1 # observation[25] : If there is a switch on the path which agent can not use -> 1
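# observation[26] : If the dead-lock avoidance agent predicts a deadlock (its action is STOP_MOVING) -> 1
# observation[27] : If the dead-lock avoidance agent's action is MOVE_LEFT -> 1
# observation[28] : If the dead-lock avoidance agent's action is MOVE_FORWARD -> 1
# observation[29] : If the dead-lock avoidance agent's action is MOVE_RIGHT -> 1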
if handle == 0:
self.dead_lock_avoidance_agent.start_step()
observation = np.zeros(self.observation_dim) observation = np.zeros(self.observation_dim)
visited = [] visited = []
...@@ -292,6 +301,12 @@ class FastTreeObs(ObservationBuilder): ...@@ -292,6 +301,12 @@ class FastTreeObs(ObservationBuilder):
observation[8] = int(agents_near_to_switch) observation[8] = int(agents_near_to_switch)
observation[9] = int(agents_near_to_switch_all) observation[9] = int(agents_near_to_switch_all)
action = self.dead_lock_avoidance_agent.act([handle],0.0)
observation[26] = int(action == RailEnvActions.STOP_MOVING)
observation[27] = int(action == RailEnvActions.MOVE_LEFT)
observation[28] = int(action == RailEnvActions.MOVE_FORWARD)
observation[29] = int(action == RailEnvActions.MOVE_RIGHT)
self.env.dev_obs_dict.update({handle: visited}) self.env.dev_obs_dict.update({handle: visited})
return observation return observation
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment