From e28c57e5a9c17aba9415d3f80d7cb5e6a998d9c9 Mon Sep 17 00:00:00 2001
From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch>
Date: Tue, 5 Jan 2021 10:25:35 +0100
Subject: [PATCH] Policy updated

---
 .../deadlockavoidance_with_decision_agent.py | 36 +++++-----
 .../multi_agent_training.py                  |  3 +-
 .../multi_decision_agent.py                  |  2 +-
 utils/fast_tree_obs.py                       | 72 ++++++++-----
 4 files changed, 51 insertions(+), 62 deletions(-)

diff --git a/reinforcement_learning/deadlockavoidance_with_decision_agent.py b/reinforcement_learning/deadlockavoidance_with_decision_agent.py
index e9a6f8e..550e73e 100644
--- a/reinforcement_learning/deadlockavoidance_with_decision_agent.py
+++ b/reinforcement_learning/deadlockavoidance_with_decision_agent.py
@@ -2,8 +2,8 @@ from flatland.envs.agent_utils import RailAgentStatus
 from flatland.envs.rail_env import RailEnv, RailEnvActions
 
 from reinforcement_learning.policy import HybridPolicy
+from reinforcement_learning.ppo_agent import PPOPolicy
 from utils.agent_action_config import map_rail_env_action
-from utils.agent_can_choose_helper import AgentCanChooseHelper
 from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent
 
 
@@ -17,69 +17,69 @@ class DeadLockAvoidanceWithDecisionAgent(HybridPolicy):
         self.action_size = action_size
         self.learning_agent = learning_agent
         self.dead_lock_avoidance_agent = DeadLockAvoidanceAgent(self.env, action_size, False)
-        self.agent_can_choose_helper = AgentCanChooseHelper()
+        self.policy_selector = PPOPolicy(state_size, 2)
+        self.memory = self.learning_agent.memory
         self.loss = self.learning_agent.loss
 
     def step(self, handle, state, action, reward, next_state, done):
+        select = self.policy_selector.act(handle, state, 0.0)
+        self.policy_selector.step(handle, state, select, reward, next_state, done)
         self.dead_lock_avoidance_agent.step(handle, state, action, reward, next_state, done)
         self.learning_agent.step(handle, state, action, reward, next_state, done)
         self.loss = self.learning_agent.loss
 
     def act(self, handle, state, eps=0.):
-        agent = self.env.agents[handle]
-        position = agent.position
-        if position is None:
-            position = agent.initial_position
-        direction = agent.direction
-        if agent.status < RailAgentStatus.DONE:
-            agents_on_switch, agents_near_to_switch, _, _ = \
-                self.agent_can_choose_helper.check_agent_decision(position, direction)
-            if agents_on_switch or agents_near_to_switch:
-                return self.learning_agent.act(handle, state, eps)
-            else:
-                act = self.dead_lock_avoidance_agent.act(handle, state, -1.0)
-                return map_rail_env_action(act)
-        # Agent is still at target cell
-        return map_rail_env_action(RailEnvActions.DO_NOTHING)
+        select = self.policy_selector.act(handle, state, eps)
+        if select == 0:
+            return self.learning_agent.act(handle, state, eps)
+        return self.dead_lock_avoidance_agent.act(handle, state, -1.0)
 
     def save(self, filename):
         self.dead_lock_avoidance_agent.save(filename)
         self.learning_agent.save(filename)
+        self.policy_selector.save(filename + '.selector')
 
     def load(self, filename):
         self.dead_lock_avoidance_agent.load(filename)
         self.learning_agent.load(filename)
+        self.policy_selector.load(filename + '.selector')
 
     def start_step(self, train):
         self.dead_lock_avoidance_agent.start_step(train)
         self.learning_agent.start_step(train)
+        self.policy_selector.start_step(train)
 
     def end_step(self, train):
         self.dead_lock_avoidance_agent.end_step(train)
         self.learning_agent.end_step(train)
+        self.policy_selector.end_step(train)
 
     def start_episode(self, train):
         self.dead_lock_avoidance_agent.start_episode(train)
         self.learning_agent.start_episode(train)
+        self.policy_selector.start_episode(train)
 
     def end_episode(self, train):
         self.dead_lock_avoidance_agent.end_episode(train)
         self.learning_agent.end_episode(train)
+        self.policy_selector.end_episode(train)
 
     def load_replay_buffer(self, filename):
         self.dead_lock_avoidance_agent.load_replay_buffer(filename)
         self.learning_agent.load_replay_buffer(filename)
+        self.policy_selector.load_replay_buffer(filename + ".selector")
 
     def test(self):
         self.dead_lock_avoidance_agent.test()
         self.learning_agent.test()
+        self.policy_selector.test()
 
     def reset(self, env: RailEnv):
         self.env = env
-        self.agent_can_choose_helper.build_data(env)
         self.dead_lock_avoidance_agent.reset(env)
         self.learning_agent.reset(env)
+        self.policy_selector.reset(env)
 
     def clone(self):
         return self
diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py
index 6734f9c..882cf86 100755
--- a/reinforcement_learning/multi_agent_training.py
+++ b/reinforcement_learning/multi_agent_training.py
@@ -178,6 +178,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
     if False:
         policy = DeadLockAvoidanceAgent(train_env, get_action_size())
     if True:
+        # inter_policy = PPOPolicy(state_size, get_action_size(), use_replay_buffer=False, in_parameters=train_params)
         inter_policy = DDDQNPolicy(state_size, get_action_size(), train_params)
         policy = DeadLockAvoidanceWithDecisionAgent(train_env, state_size, get_action_size(), inter_policy)
     if False:
@@ -234,7 +235,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
 
         # Reset environment
         reset_timer.start()
-        number_of_agents = n_agents  # int(min(n_agents, 1 + np.floor(episode_idx / 500)))
+        number_of_agents = int(min(n_agents, 1 + np.floor(episode_idx / 200)))
         train_env_params.n_agents = episode_idx % number_of_agents + 1
 
         train_env = create_rail_env(train_env_params, tree_observation)
diff --git a/reinforcement_learning/multi_decision_agent.py b/reinforcement_learning/multi_decision_agent.py
index 13d1874..5047bcd 100644
--- a/reinforcement_learning/multi_decision_agent.py
+++ b/reinforcement_learning/multi_decision_agent.py
@@ -22,9 +22,9 @@ class MultiDecisionAgent(LearningPolicy):
 
 
     def step(self, handle, state, action, reward, next_state, done):
-        select = self.policy_selector.act(handle, state, 0.0)
         self.ppo_policy.step(handle, state, action, reward, next_state, done)
         self.dddqn_policy.step(handle, state, action, reward, next_state, done)
+        select = self.policy_selector.act(handle, state, 0.0)
         self.policy_selector.step(handle, state, select, reward, next_state, done)
 
     def act(self, handle, state, eps=0.):
diff --git a/utils/fast_tree_obs.py b/utils/fast_tree_obs.py
index 8172703..f0b6277 100755
--- a/utils/fast_tree_obs.py
+++ b/utils/fast_tree_obs.py
@@ -4,11 +4,10 @@ import numpy as np
 from flatland.core.env_observation_builder import ObservationBuilder
 from flatland.core.grid.grid4_utils import get_new_position
 from flatland.envs.agent_utils import RailAgentStatus
-from flatland.envs.rail_env import fast_count_nonzero, fast_argmax, RailEnvActions
+from flatland.envs.rail_env import fast_count_nonzero, fast_argmax
 
 from utils.agent_can_choose_helper import AgentCanChooseHelper
-from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent
-from utils.deadlock_check import check_for_deadlock, get_agent_positions
+from utils.deadlock_check import get_agent_positions, get_agent_targets
 
 """
 LICENCE for the FastTreeObs Observation Builder
@@ -27,7 +26,7 @@ class FastTreeObs(ObservationBuilder):
 
     def __init__(self, max_depth: Any):
         self.max_depth = max_depth
-        self.observation_dim = 41
+        self.observation_dim = 30
         self.agent_can_choose_helper = None
 
     def debug_render(self, env_renderer):
@@ -65,21 +64,18 @@ class FastTreeObs(ObservationBuilder):
         self.agent_can_choose_helper.build_data(self.env)
         self.debug_render_list = []
         self.debug_render_path_list = []
-        if self.env is not None:
-            self.dead_lock_avoidance_agent = DeadLockAvoidanceAgent(self.env, 5, False)
-        else:
-            self.dead_lock_avoidance_agent = None
 
     def _explore(self, handle, new_position, new_direction, distance_map, depth=0):
         has_opp_agent = 0
         has_same_agent = 0
         has_target = 0
+        has_opp_target = 0
         visited = []
         min_dist = distance_map[handle, new_position[0], new_position[1], new_direction]
 
         # stop exploring (max_depth reached)
        if depth >= self.max_depth:
-            return has_opp_agent, has_same_agent, has_target, visited, min_dist
+            return has_opp_agent, has_same_agent, has_target, has_opp_target, visited, min_dist
 
         # max_explore_steps = 100 -> just to ensure that the exploration ends
         cnt = 0
@@ -92,7 +88,7 @@ class FastTreeObs(ObservationBuilder):
                 if self.env.agents[opp_a].direction != new_direction:
                     # opp agent found -> stop exploring. This would be a strong signal.
                     has_opp_agent = 1
-                    return has_opp_agent, has_same_agent, has_target, visited, min_dist
+                    return has_opp_agent, has_same_agent, has_target, has_opp_target, visited, min_dist
                 else:
                     # same agent found
                     # the agent can follow the agent, because this agent is still moving ahead and there shouldn't
@@ -101,7 +97,8 @@ class FastTreeObs(ObservationBuilder):
                     # target on this branch -> thus the agents should scan further whether there will be an opposite
                     # agent walking on same track
                     has_same_agent = 1
-                    # !NOT stop exploring! return has_opp_agent, has_same_agent, has_switch, visited,min_dist
+                    # !NOT stop exploring!
+                    return has_opp_agent, has_same_agent, has_target, has_opp_target, visited, min_dist
 
             # agents_on_switch == TRUE -> Current cell is a switch where the agent can decide (branch) in exploration
             # agent_near_to_switch == TRUE -> One cell before the switch, where the agent can decide
@@ -112,10 +109,14 @@ class FastTreeObs(ObservationBuilder):
             if agents_near_to_switch:
                 # The exploration was walking on a path where the agent can not decide
                 # Best option would be MOVE_FORWARD -> Skip exploring - just walking
-                return has_opp_agent, has_same_agent, has_target, visited, min_dist
+                return has_opp_agent, has_same_agent, has_target, has_opp_target, visited, min_dist
+
+            if self.env.agents[handle].target in self.agents_target:
+                has_opp_target = 1
 
             if self.env.agents[handle].target == new_position:
                 has_target = 1
+                return has_opp_agent, has_same_agent, has_target, has_opp_target, visited, min_dist
 
             possible_transitions = self.env.rail.get_transitions(*new_position, new_direction)
             if agents_on_switch:
@@ -130,30 +131,30 @@ class FastTreeObs(ObservationBuilder):
                     # --- OPEN RESEARCH QUESTION ---> is this good or shall we use full detailed information as
                    # we did in the TreeObservation (FLATLAND) ?
                     if possible_transitions[dir_loop] == 1:
-                        hoa, hsa, ht, v, m_dist = self._explore(handle,
-                                                                get_new_position(new_position, dir_loop),
-                                                                dir_loop,
-                                                                distance_map,
-                                                                depth + 1)
+                        hoa, hsa, ht, hot, v, m_dist = self._explore(handle,
+                                                                     get_new_position(new_position, dir_loop),
+                                                                     dir_loop,
+                                                                     distance_map,
+                                                                     depth + 1)
                         visited.append(v)
-                        has_opp_agent += max(hoa, has_opp_agent)
-                        has_same_agent += max(hsa, has_same_agent)
+                        has_opp_agent = max(hoa, has_opp_agent)
+                        has_same_agent = max(hsa, has_same_agent)
                         has_target = max(has_target, ht)
+                        has_opp_target = max(has_opp_target, hot)
                         min_dist = min(min_dist, m_dist)
 
-                return has_opp_agent, has_same_agent, has_target, visited, min_dist
+                return has_opp_agent, has_same_agent, has_target, has_opp_target, visited, min_dist
             else:
                 new_direction = fast_argmax(possible_transitions)
                 new_position = get_new_position(new_position, new_direction)
                 min_dist = min(min_dist, distance_map[handle, new_position[0], new_position[1], new_direction])
 
-        return has_opp_agent, has_same_agent, has_target, visited, min_dist
+        return has_opp_agent, has_same_agent, has_target, has_opp_target, visited, min_dist
 
     def get_many(self, handles: Optional[List[int]] = None):
-        self.dead_lock_avoidance_agent.start_step(train=False)
         self.agent_positions = get_agent_positions(self.env)
+        self.agents_target = get_agent_targets(self.env)
         observations = super().get_many(handles)
-        self.dead_lock_avoidance_agent.end_step(train=False)
         return observations
 
     def get(self, handle: int = 0):
@@ -184,8 +185,6 @@ class FastTreeObs(ObservationBuilder):
         # observation[23] : If there is a switch on the path which agent can not use -> 1
         # observation[24] : If there is a switch on the path which agent can not use -> 1
         # observation[25] : If there is a switch on the path which agent can not use -> 1
-        # observation[26] : If there the dead-lock avoidance agent predicts a deadlock -> 1
-        # observation[27] : If there the agent can only walk forward or stop -> 1
 
         observation = np.zeros(self.observation_dim)
         visited = []
@@ -223,24 +222,18 @@ class FastTreeObs(ObservationBuilder):
                     if not (np.math.isinf(new_cell_dist) and np.math.isinf(current_cell_dist)):
                         observation[dir_loop] = int(new_cell_dist < current_cell_dist)
 
-                    has_opp_agent, has_same_agent, has_target, v, min_dist = self._explore(handle,
-                                                                                           new_position,
-                                                                                           branch_direction,
-                                                                                           distance_map)
+                    has_opp_agent, has_same_agent, has_target, has_opp_target, v, min_dist = self._explore(handle,
+                                                                                                           new_position,
+                                                                                                           branch_direction,
+                                                                                                           distance_map)
                     visited.append(v)
 
                     if not (np.math.isinf(min_dist) and np.math.isinf(current_cell_dist)):
-                        observation[31 + dir_loop] = int(min_dist < current_cell_dist)
-                    observation[11 + dir_loop] = int(not np.math.isinf(new_cell_dist))
+                        observation[11 + dir_loop] = int(min_dist < current_cell_dist)
                     observation[15 + dir_loop] = has_opp_agent
                     observation[19 + dir_loop] = has_same_agent
                     observation[23 + dir_loop] = has_target
-                    observation[27 + dir_loop] = int(np.math.isinf(new_cell_dist))
-                    observation[36] = int(check_for_deadlock(handle,
-                                                             self.env,
-                                                             self.agent_positions,
-                                                             new_position,
-                                                             branch_direction))
+                    observation[27 + dir_loop] = has_opp_target
 
             agents_on_switch, \
             agents_near_to_switch, \
@@ -253,11 +246,6 @@ class FastTreeObs(ObservationBuilder):
         observation[9] = int(agents_near_to_switch)
         observation[10] = int(agents_near_to_switch_all)
 
-        action = self.dead_lock_avoidance_agent.act(handle, None, 0.0)
-        observation[35] = int(action == RailEnvActions.STOP_MOVING)
-
-        observation[40] = int(check_for_deadlock(handle, self.env, self.agent_positions))
-
         self.env.dev_obs_dict.update({handle: visited})
 
         observation[np.isinf(observation)] = -1
--
GitLab
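
The central change above replaces the hand-coded switch heuristic in DeadLockAvoidanceWithDecisionAgent with a learned two-action selector. Below is a minimal, self-contained sketch of that selector pattern; it assumes nothing beyond the method signatures visible in the diff, and the Stub class is a random-acting placeholder for the repository's PPOPolicy / DDDQNPolicy / DeadLockAvoidanceAgent, so it illustrates the control flow only, not the real training behaviour.

# Sketch of the selector pattern: a tiny two-action "selector" policy decides,
# per agent handle, whether the learning agent or the deadlock-avoidance
# heuristic supplies the action. StubPolicy is a placeholder, not the repo's API.
import random


class StubPolicy:
    """Stand-in for a learned policy: acts randomly over `action_size` actions."""

    def __init__(self, state_size, action_size):
        self.state_size = state_size
        self.action_size = action_size

    def act(self, handle, state, eps=0.0):
        return random.randrange(self.action_size)

    def step(self, handle, state, action, reward, next_state, done):
        pass  # a real policy would store the transition and learn here


class SelectorHybridPolicy:
    """Dispatches each act() call through a 2-action selector policy."""

    def __init__(self, state_size, action_size):
        self.learning_agent = StubPolicy(state_size, action_size)
        self.heuristic_agent = StubPolicy(state_size, action_size)
        self.policy_selector = StubPolicy(state_size, 2)  # 0 = learner, 1 = heuristic

    def act(self, handle, state, eps=0.0):
        select = self.policy_selector.act(handle, state, eps)
        if select == 0:
            return self.learning_agent.act(handle, state, eps)
        return self.heuristic_agent.act(handle, state, eps)

    def step(self, handle, state, action, reward, next_state, done):
        # The selector is trained on its own choice, re-using the env reward,
        # while both sub-policies observe the action that was actually executed.
        select = self.policy_selector.act(handle, state, 0.0)
        self.policy_selector.step(handle, state, select, reward, next_state, done)
        self.learning_agent.step(handle, state, action, reward, next_state, done)
        self.heuristic_agent.step(handle, state, action, reward, next_state, done)


policy = SelectorHybridPolicy(state_size=30, action_size=5)
action = policy.act(handle=0, state=[0.0] * 30)
policy.step(handle=0, state=[0.0] * 30, action=action, reward=-1.0,
            next_state=[0.0] * 30, done=False)

The design point is that the selector learns a binary choice (which sub-policy to trust) against the same environment reward, while both sub-policies keep receiving the transitions that were actually executed.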
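
The training-loop change in multi_agent_training.py swaps the fixed agent count for a curriculum: the cap on agents grows by one every 200 episodes up to n_agents, and the per-episode agent count then cycles from 1 up to that cap. A small sketch of the schedule follows; n_agents = 10 is an arbitrary illustrative value.

# Prints (episode index, current cap, agents actually used in that episode).
import numpy as np

n_agents = 10
for episode_idx in (0, 199, 200, 401, 999, 5000):
    number_of_agents = int(min(n_agents, 1 + np.floor(episode_idx / 200)))
    episode_n_agents = episode_idx % number_of_agents + 1
    print(episode_idx, number_of_agents, episode_n_agents)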
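
In utils/fast_tree_obs.py the per-branch exploration flags are now combined with max() (a logical OR over branches) instead of the earlier '+=' accumulation, and a new has_opp_target flag is threaded through the 6-tuple returned by _explore(). The toy sketch below shows that aggregation over a hypothetical branch tree; the explore() helper and the node payloads are illustrative only, not the repository's API.

# Node payloads are (has_opp_agent, has_same_agent, has_target, has_opp_target, min_dist).
def explore(node):
    has_opp_agent, has_same_agent, has_target, has_opp_target, min_dist = node["flags"]
    for child in node.get("children", []):
        hoa, hsa, ht, hot, m_dist = explore(child)
        has_opp_agent = max(has_opp_agent, hoa)    # OR over branches, not a sum
        has_same_agent = max(has_same_agent, hsa)
        has_target = max(has_target, ht)
        has_opp_target = max(has_opp_target, hot)
        min_dist = min(min_dist, m_dist)           # shortest distance over branches
    return has_opp_agent, has_same_agent, has_target, has_opp_target, min_dist


tree = {
    "flags": (0, 0, 0, 0, 12.0),
    "children": [
        {"flags": (1, 0, 0, 0, 7.0)},                # branch with an opposing agent
        {"flags": (0, 0, 1, 0, 3.0),                 # branch reaching the own target
         "children": [{"flags": (0, 1, 0, 1, 9.0)}]},
    ],
}
print(explore(tree))  # -> (1, 1, 1, 1, 3.0)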