From 0a273a9447d151e16f0fec66877364f3104f90f0 Mon Sep 17 00:00:00 2001 From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch> Date: Tue, 22 Dec 2020 09:40:09 +0100 Subject: [PATCH] Policy updated --- ... deadlockavoidance_with_decision_agent.py} | 2 +- .../multi_agent_training.py | 23 ++--- .../multi_decision_agent.py | 90 +++++++++++++++++++ reinforcement_learning/policy.py | 11 +++ run.py | 4 +- utils/dead_lock_avoidance_agent.py | 5 -- utils/deadlock_check.py | 9 ++ 7 files changed, 126 insertions(+), 18 deletions(-) rename reinforcement_learning/{ppo_deadlockavoidance_agent.py => deadlockavoidance_with_decision_agent.py} (98%) create mode 100644 reinforcement_learning/multi_decision_agent.py diff --git a/reinforcement_learning/ppo_deadlockavoidance_agent.py b/reinforcement_learning/deadlockavoidance_with_decision_agent.py similarity index 98% rename from reinforcement_learning/ppo_deadlockavoidance_agent.py rename to reinforcement_learning/deadlockavoidance_with_decision_agent.py index f891748..a1726f8 100644 --- a/reinforcement_learning/ppo_deadlockavoidance_agent.py +++ b/reinforcement_learning/deadlockavoidance_with_decision_agent.py @@ -7,7 +7,7 @@ from utils.agent_can_choose_helper import AgentCanChooseHelper from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent -class MultiDecisionAgent(HybridPolicy): +class DeadLockAvoidanceWithDecisionAgent(HybridPolicy): def __init__(self, env: RailEnv, state_size, action_size, learning_agent): self.env = env diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py index 4c06133..2e22854 100755 --- a/reinforcement_learning/multi_agent_training.py +++ b/reinforcement_learning/multi_agent_training.py @@ -9,11 +9,10 @@ from pprint import pprint import numpy as np import psutil -from flatland.envs.agent_utils import RailAgentStatus from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv -from flatland.envs.rail_env import RailEnv, RailEnvActions, fast_count_nonzero +from flatland.envs.rail_env import RailEnv, RailEnvActions from flatland.envs.rail_generators import sparse_rail_generator from flatland.envs.schedule_generators import sparse_schedule_generator from flatland.utils.rendertools import RenderTool @@ -21,11 +20,10 @@ from torch.utils.tensorboard import SummaryWriter from reinforcement_learning.dddqn_policy import DDDQNPolicy from reinforcement_learning.ppo_agent import PPOPolicy -from reinforcement_learning.ppo_deadlockavoidance_agent import MultiDecisionAgent -from utils.agent_action_config import get_flatland_full_action_size, get_action_size, map_actions, map_action, \ - map_rail_env_action +from reinforcement_learning.deadlockavoidance_with_decision_agent import DeadLockAvoidanceWithDecisionAgent +from reinforcement_learning.multi_decision_agent import MultiDecisionAgent +from utils.agent_action_config import get_flatland_full_action_size, get_action_size, map_actions, map_action from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent -from utils.deadlock_check import get_agent_positions, check_for_deadlock base_dir = Path(__file__).resolve().parent.parent sys.path.append(str(base_dir)) @@ -172,13 +170,18 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params): completion_window = deque(maxlen=checkpoint_interval) # Double Dueling DQN policy - policy = DDDQNPolicy(state_size, get_action_size(), train_params) - if True: + policy = None + if False: + policy = DDDQNPolicy(state_size, get_action_size(), train_params) + if False: policy = PPOPolicy(state_size, get_action_size(), use_replay_buffer=False, in_parameters=train_params) if False: policy = DeadLockAvoidanceAgent(train_env, get_action_size()) if False: - policy = MultiDecisionAgent(train_env, state_size, get_action_size(), policy) + inter_policy = PPOPolicy(state_size, get_action_size(), use_replay_buffer=False, in_parameters=train_params) + policy = DeadLockAvoidanceWithDecisionAgent(train_env, state_size, get_action_size(), inter_policy) + if True: + policy = MultiDecisionAgent(state_size, get_action_size(), train_params) # Load existing policy if train_params.load_policy is not "": @@ -227,7 +230,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params): # Reset environment reset_timer.start() - number_of_agents = int(min(n_agents, 1 + np.floor(episode_idx / 500))) + number_of_agents = n_agents # int(min(n_agents, 1 + np.floor(episode_idx / 500))) train_env_params.n_agents = episode_idx % number_of_agents + 1 train_env = create_rail_env(train_env_params, tree_observation) diff --git a/reinforcement_learning/multi_decision_agent.py b/reinforcement_learning/multi_decision_agent.py new file mode 100644 index 0000000..13d1874 --- /dev/null +++ b/reinforcement_learning/multi_decision_agent.py @@ -0,0 +1,90 @@ +from flatland.envs.rail_env import RailEnv + +from reinforcement_learning.dddqn_policy import DDDQNPolicy +from reinforcement_learning.policy import LearningPolicy, DummyMemory +from reinforcement_learning.ppo_agent import PPOPolicy + + +class MultiDecisionAgent(LearningPolicy): + + def __init__(self, state_size, action_size, in_parameters=None): + print(">> MultiDecisionAgent") + super(MultiDecisionAgent, self).__init__() + self.state_size = state_size + self.action_size = action_size + self.in_parameters = in_parameters + self.memory = DummyMemory() + self.loss = 0 + + self.ppo_policy = PPOPolicy(state_size, action_size, use_replay_buffer=False, in_parameters=in_parameters) + self.dddqn_policy = DDDQNPolicy(state_size, action_size, in_parameters) + self.policy_selector = PPOPolicy(state_size, 2) + + + def step(self, handle, state, action, reward, next_state, done): + select = self.policy_selector.act(handle, state, 0.0) + self.ppo_policy.step(handle, state, action, reward, next_state, done) + self.dddqn_policy.step(handle, state, action, reward, next_state, done) + self.policy_selector.step(handle, state, select, reward, next_state, done) + + def act(self, handle, state, eps=0.): + select = self.policy_selector.act(handle, state, eps) + if select == 0: + return self.dddqn_policy.act(handle, state, eps) + return self.policy_selector.act(handle, state, eps) + + def save(self, filename): + self.ppo_policy.save(filename) + self.dddqn_policy.save(filename) + self.policy_selector.save(filename) + + def load(self, filename): + self.ppo_policy.load(filename) + self.dddqn_policy.load(filename) + self.policy_selector.load(filename) + + def start_step(self, train): + self.ppo_policy.start_step(train) + self.dddqn_policy.start_step(train) + self.policy_selector.start_step(train) + + def end_step(self, train): + self.ppo_policy.end_step(train) + self.dddqn_policy.end_step(train) + self.policy_selector.end_step(train) + + def start_episode(self, train): + self.ppo_policy.start_episode(train) + self.dddqn_policy.start_episode(train) + self.policy_selector.start_episode(train) + + def end_episode(self, train): + self.ppo_policy.end_episode(train) + self.dddqn_policy.end_episode(train) + self.policy_selector.end_episode(train) + + def load_replay_buffer(self, filename): + self.ppo_policy.load_replay_buffer(filename) + self.dddqn_policy.load_replay_buffer(filename) + self.policy_selector.load_replay_buffer(filename) + + def test(self): + self.ppo_policy.test() + self.dddqn_policy.test() + self.policy_selector.test() + + def reset(self, env: RailEnv): + self.ppo_policy.reset(env) + self.dddqn_policy.reset(env) + self.policy_selector.reset(env) + + def clone(self): + multi_descision_agent = MultiDecisionAgent( + self.state_size, + self.action_size, + self.in_parameters + ) + multi_descision_agent.ppo_policy = self.ppo_policy.clone() + multi_descision_agent.dddqn_policy = self.dddqn_policy.clone() + multi_descision_agent.policy_selector = self.policy_selector.clone() + return multi_descision_agent diff --git a/reinforcement_learning/policy.py b/reinforcement_learning/policy.py index 9b883d1..fe28cbc 100644 --- a/reinforcement_learning/policy.py +++ b/reinforcement_learning/policy.py @@ -1,6 +1,14 @@ from flatland.envs.rail_env import RailEnv +class DummyMemory: + def __init__(self): + self.memory = [] + + def __len__(self): + return 0 + + class Policy: def step(self, handle, state, action, reward, next_state, done): raise NotImplementedError @@ -38,14 +46,17 @@ class Policy: def clone(self): return self + class HeuristicPolicy(Policy): def __init__(self): super(HeuristicPolicy).__init__() + class LearningPolicy(Policy): def __init__(self): super(LearningPolicy).__init__() + class HybridPolicy(Policy): def __init__(self): super(HybridPolicy).__init__() diff --git a/run.py b/run.py index 6578fee..8d97053 100644 --- a/run.py +++ b/run.py @@ -31,7 +31,7 @@ from flatland.evaluators.client import FlatlandRemoteClient from flatland.evaluators.client import TimeoutException from reinforcement_learning.ppo_agent import PPOPolicy -from reinforcement_learning.ppo_deadlockavoidance_agent import MultiDecisionAgent +from reinforcement_learning.deadlockavoidance_with_decision_agent import DeadLockAvoidanceWithDecisionAgent from utils.agent_action_config import get_action_size, map_actions from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent from utils.deadlock_check import check_if_all_blocked @@ -147,7 +147,7 @@ while True: policy = trained_policy if USE_MULTI_DECISION_AGENT: - policy = MultiDecisionAgent(local_env, state_size, action_size, trained_policy) + policy = DeadLockAvoidanceWithDecisionAgent(local_env, state_size, action_size, trained_policy) policy.reset(local_env) observation = tree_observation.get_many(list(range(nb_agents))) diff --git a/utils/dead_lock_avoidance_agent.py b/utils/dead_lock_avoidance_agent.py index 1f0030c..ed3a3f7 100644 --- a/utils/dead_lock_avoidance_agent.py +++ b/utils/dead_lock_avoidance_agent.py @@ -67,12 +67,7 @@ class DeadlockAvoidanceShortestDistanceWalker(ShortestDistanceWalker): self.full_shortest_distance_agent_map[(handle, position[0], position[1])] = 1 -class DummyMemory: - def __init__(self): - self.memory = [] - def __len__(self): - return 0 class DeadLockAvoidanceAgent(HeuristicPolicy): diff --git a/utils/deadlock_check.py b/utils/deadlock_check.py index d787c8c..4df6731 100644 --- a/utils/deadlock_check.py +++ b/utils/deadlock_check.py @@ -17,6 +17,15 @@ def get_agent_positions(env): return agent_positions +def get_agent_targets(env): + agent_targets = [] + for agent_handle in env.get_agent_handles(): + agent = env.agents[agent_handle] + if agent.status == RailAgentStatus.ACTIVE: + agent_targets.append(agent.target) + return agent_targets + + def check_for_deadlock(handle, env, agent_positions, check_position=None, check_direction=None): agent = env.agents[handle] if agent.status == RailAgentStatus.DONE or agent.status == RailAgentStatus.DONE_REMOVED: -- GitLab