From ed050703921efdab033a5898d2993be2a87a0e8c Mon Sep 17 00:00:00 2001
From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch>
Date: Sat, 19 Dec 2020 13:01:22 +0100
Subject: [PATCH] refactored

---
 reinforcement_learning/dddqn_policy.py         |  4 ++--
 reinforcement_learning/multi_agent_training.py |  6 +++---
 reinforcement_learning/policy.py               | 12 ++++++++++++
 reinforcement_learning/ppo_agent.py            |  6 +++---
 .../ppo_deadlockavoidance_agent.py             |  6 +++---
 utils/dead_lock_avoidance_agent.py             |  4 ++--
 6 files changed, 25 insertions(+), 13 deletions(-)

diff --git a/reinforcement_learning/dddqn_policy.py b/reinforcement_learning/dddqn_policy.py
index 9864ca6..864c6a7 100644
--- a/reinforcement_learning/dddqn_policy.py
+++ b/reinforcement_learning/dddqn_policy.py
@@ -9,11 +9,11 @@ import torch.nn.functional as F
 import torch.optim as optim
 
 from reinforcement_learning.model import DuelingQNetwork
-from reinforcement_learning.policy import Policy
+from reinforcement_learning.policy import Policy, LearningPolicy
 from reinforcement_learning.replay_buffer import ReplayBuffer
 
 
-class DDDQNPolicy(Policy):
+class DDDQNPolicy(LearningPolicy):
     """Dueling Double DQN policy"""
 
     def __init__(self, state_size, action_size, in_parameters, evaluation_mode=False):
diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py
index cce7ecc..b219ace 100755
--- a/reinforcement_learning/multi_agent_training.py
+++ b/reinforcement_learning/multi_agent_training.py
@@ -519,11 +519,11 @@ def eval_policy(env, tree_observation, policy, train_params, obs_params):
 if __name__ == "__main__":
     parser = ArgumentParser()
     parser.add_argument("-n", "--n_episodes", help="number of episodes to run", default=12000, type=int)
-    parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=3,
+    parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=2,
                         type=int)
-    parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=1,
+    parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=2,
                         type=int)
-    parser.add_argument("--n_evaluation_episodes", help="number of evaluation episodes", default=1, type=int)
+    parser.add_argument("--n_evaluation_episodes", help="number of evaluation episodes", default=10, type=int)
     parser.add_argument("--checkpoint_interval", help="checkpoint interval", default=100, type=int)
     parser.add_argument("--eps_start", help="max exploration", default=0.1, type=float)
     parser.add_argument("--eps_end", help="min exploration", default=0.005, type=float)
diff --git a/reinforcement_learning/policy.py b/reinforcement_learning/policy.py
index 5b118ae..9b883d1 100644
--- a/reinforcement_learning/policy.py
+++ b/reinforcement_learning/policy.py
@@ -37,3 +37,15 @@ class Policy:
 
     def clone(self):
         return self
+
+class HeuristicPolicy(Policy):
+    def __init__(self):
+        super(HeuristicPolicy).__init__()
+
+class LearningPolicy(Policy):
+    def __init__(self):
+        super(LearningPolicy).__init__()
+
+class HybridPolicy(Policy):
+    def __init__(self):
+        super(HybridPolicy).__init__()
diff --git a/reinforcement_learning/ppo_agent.py b/reinforcement_learning/ppo_agent.py
index 5c0fc08..51f0f71 100644
--- a/reinforcement_learning/ppo_agent.py
+++ b/reinforcement_learning/ppo_agent.py
@@ -8,7 +8,7 @@ import torch.optim as optim
 from torch.distributions import Categorical
 
 # Hyperparameters
-from reinforcement_learning.policy import Policy
+from reinforcement_learning.policy import LearningPolicy
 from reinforcement_learning.replay_buffer import ReplayBuffer
 
 device = torch.device("cpu")  # "cuda:0" if torch.cuda.is_available() else "cpu")
@@ -92,10 +92,10 @@ class ActorCriticModel(nn.Module):
     def load(self, filename):
         print("load policy from file", filename)
         self.actor = self._load(self.actor, filename + ".actor")
-        self.critic = self._load(self.critic, filename + ".critic")
+        self.critic = self._load(self.critic, filename + ".value")
 
 
-class PPOPolicy(Policy):
+class PPOPolicy(LearningPolicy):
     def __init__(self, state_size, action_size):
         print(">> PPOPolicy")
         super(PPOPolicy, self).__init__()
diff --git a/reinforcement_learning/ppo_deadlockavoidance_agent.py b/reinforcement_learning/ppo_deadlockavoidance_agent.py
index 6e8880c..f891748 100644
--- a/reinforcement_learning/ppo_deadlockavoidance_agent.py
+++ b/reinforcement_learning/ppo_deadlockavoidance_agent.py
@@ -1,13 +1,13 @@
 from flatland.envs.agent_utils import RailAgentStatus
 from flatland.envs.rail_env import RailEnv, RailEnvActions
 
-from reinforcement_learning.policy import Policy
+from reinforcement_learning.policy import HybridPolicy
 from utils.agent_action_config import map_rail_env_action
 from utils.agent_can_choose_helper import AgentCanChooseHelper
 from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent
 
 
-class MultiDecisionAgent(Policy):
+class MultiDecisionAgent(HybridPolicy):
 
     def __init__(self, env: RailEnv, state_size, action_size, learning_agent):
         self.env = env
@@ -33,7 +33,7 @@ class MultiDecisionAgent(Policy):
             if agent.status < RailAgentStatus.DONE:
                 agents_on_switch, agents_near_to_switch, _, _ = \
                     self.agent_can_choose_helper.check_agent_decision(position, direction)
-                if agents_on_switch or agents_near_to_switch:
+                if agents_on_switch:
                     return self.learning_agent.act(handle, state, eps)
                 else:
                     act = self.dead_lock_avoidance_agent.act(handle, state, -1.0)
diff --git a/utils/dead_lock_avoidance_agent.py b/utils/dead_lock_avoidance_agent.py
index cad3b74..1f0030c 100644
--- a/utils/dead_lock_avoidance_agent.py
+++ b/utils/dead_lock_avoidance_agent.py
@@ -6,7 +6,7 @@ from flatland.core.env_observation_builder import DummyObservationBuilder
 from flatland.envs.agent_utils import RailAgentStatus
 from flatland.envs.rail_env import RailEnv, RailEnvActions, fast_count_nonzero
 
-from reinforcement_learning.policy import Policy
+from reinforcement_learning.policy import HeuristicPolicy
 from utils.agent_action_config import map_rail_env_action
 from utils.shortest_distance_walker import ShortestDistanceWalker
 
@@ -75,7 +75,7 @@ class DummyMemory:
         return 0
 
 
-class DeadLockAvoidanceAgent(Policy):
+class DeadLockAvoidanceAgent(HeuristicPolicy):
     def __init__(self, env: RailEnv, action_size, show_debug_plot=False):
         self.env = env
         self.memory = DummyMemory()
-- 
GitLab
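
Note (not part of the commit): a minimal sketch of the policy hierarchy this patch introduces. Only the subclass names and the base class `Policy` come from the diff; the method signatures on `Policy` are assumptions inferred from the calls visible elsewhere in the patch (`act(handle, state, eps)`, `load(filename)`, `clone()`), and `super().__init__()` is used here instead of the patch's `super(ClassName).__init__()` spelling, which does not actually invoke the parent initializer.

```python
class Policy:
    """Common interface assumed for all agents in this repository (sketch)."""

    def act(self, handle, state, eps=0.0):
        raise NotImplementedError

    def step(self, handle, state, action, reward, next_state, done):
        raise NotImplementedError

    def save(self, filename):
        pass

    def load(self, filename):
        pass

    def clone(self):
        return self


class HeuristicPolicy(Policy):
    """Marker base for hand-crafted policies, e.g. DeadLockAvoidanceAgent."""

    def __init__(self):
        super().__init__()


class LearningPolicy(Policy):
    """Marker base for trainable policies, e.g. DDDQNPolicy and PPOPolicy."""

    def __init__(self):
        super().__init__()


class HybridPolicy(Policy):
    """Marker base for policies mixing both, e.g. MultiDecisionAgent."""

    def __init__(self):
        super().__init__()
```

The three subclasses add no behaviour; they only tag each agent as heuristic, learning, or hybrid so that callers (and the type hierarchy) can distinguish how an agent produces its actions, matching the reparenting done in the files above.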