From 388822a00bd0cf19707bafa91da120a26f7a0180 Mon Sep 17 00:00:00 2001 From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch> Date: Tue, 22 Dec 2020 10:57:08 +0100 Subject: [PATCH] Policy updated --- .../deadlockavoidance_with_decision_agent.py | 4 +++- reinforcement_learning/multi_agent_training.py | 10 +++++++--- utils/dead_lock_avoidance_agent.py | 6 +----- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/reinforcement_learning/deadlockavoidance_with_decision_agent.py b/reinforcement_learning/deadlockavoidance_with_decision_agent.py index a1726f8..e9a6f8e 100644 --- a/reinforcement_learning/deadlockavoidance_with_decision_agent.py +++ b/reinforcement_learning/deadlockavoidance_with_decision_agent.py @@ -10,6 +10,8 @@ from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent class DeadLockAvoidanceWithDecisionAgent(HybridPolicy): def __init__(self, env: RailEnv, state_size, action_size, learning_agent): + print(">> DeadLockAvoidanceWithDecisionAgent") + super(DeadLockAvoidanceWithDecisionAgent, self).__init__() self.env = env self.state_size = state_size self.action_size = action_size @@ -33,7 +35,7 @@ class DeadLockAvoidanceWithDecisionAgent(HybridPolicy): if agent.status < RailAgentStatus.DONE: agents_on_switch, agents_near_to_switch, _, _ = \ self.agent_can_choose_helper.check_agent_decision(position, direction) - if agents_on_switch: + if agents_on_switch or agents_near_to_switch: return self.learning_agent.act(handle, state, eps) else: act = self.dead_lock_avoidance_agent.act(handle, state, -1.0) diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py index 2e22854..6734f9c 100755 --- a/reinforcement_learning/multi_agent_training.py +++ b/reinforcement_learning/multi_agent_training.py @@ -177,12 +177,16 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params): policy = PPOPolicy(state_size, get_action_size(), use_replay_buffer=False, in_parameters=train_params) if False: policy = DeadLockAvoidanceAgent(train_env, get_action_size()) - if False: - inter_policy = PPOPolicy(state_size, get_action_size(), use_replay_buffer=False, in_parameters=train_params) - policy = DeadLockAvoidanceWithDecisionAgent(train_env, state_size, get_action_size(), inter_policy) if True: + inter_policy = DDDQNPolicy(state_size, get_action_size(), train_params) + policy = DeadLockAvoidanceWithDecisionAgent(train_env, state_size, get_action_size(), inter_policy) + if False: policy = MultiDecisionAgent(state_size, get_action_size(), train_params) + # make sure that at least one policy is set + if policy is None: + policy = DDDQNPolicy(state_size, get_action_size(), train_params) + # Load existing policy if train_params.load_policy is not "": policy.load(train_params.load_policy) diff --git a/utils/dead_lock_avoidance_agent.py b/utils/dead_lock_avoidance_agent.py index ed3a3f7..ac7fd0c 100644 --- a/utils/dead_lock_avoidance_agent.py +++ b/utils/dead_lock_avoidance_agent.py @@ -6,7 +6,7 @@ from flatland.core.env_observation_builder import DummyObservationBuilder from flatland.envs.agent_utils import RailAgentStatus from flatland.envs.rail_env import RailEnv, RailEnvActions, fast_count_nonzero -from reinforcement_learning.policy import HeuristicPolicy +from reinforcement_learning.policy import HeuristicPolicy, DummyMemory from utils.agent_action_config import map_rail_env_action from utils.shortest_distance_walker import ShortestDistanceWalker @@ -66,10 +66,6 @@ class DeadlockAvoidanceShortestDistanceWalker(ShortestDistanceWalker): self.shortest_distance_agent_map[(handle, position[0], position[1])] = 1 self.full_shortest_distance_agent_map[(handle, position[0], position[1])] = 1 - - - - class DeadLockAvoidanceAgent(HeuristicPolicy): def __init__(self, env: RailEnv, action_size, show_debug_plot=False): self.env = env -- GitLab