diff --git a/reinforcement_learning/dddqn_policy.py b/reinforcement_learning/dddqn_policy.py
index 9864ca6ed1c812eb418b5c673fe0aec73b33d79a..864c6a78dd293b16aedee6fb97b7de422f8134d4 100644
--- a/reinforcement_learning/dddqn_policy.py
+++ b/reinforcement_learning/dddqn_policy.py
@@ -9,11 +9,11 @@ import torch.nn.functional as F
 import torch.optim as optim
 
 from reinforcement_learning.model import DuelingQNetwork
-from reinforcement_learning.policy import Policy
+from reinforcement_learning.policy import Policy, LearningPolicy
 from reinforcement_learning.replay_buffer import ReplayBuffer
 
 
-class DDDQNPolicy(Policy):
+class DDDQNPolicy(LearningPolicy):
     """Dueling Double DQN policy"""
 
     def __init__(self, state_size, action_size, in_parameters, evaluation_mode=False):
diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py
index cce7ecc183e7c5af601ac6e790b95aaa8ff33ba2..b219ace5133d911600edca85da6dc13fc9590753 100755
--- a/reinforcement_learning/multi_agent_training.py
+++ b/reinforcement_learning/multi_agent_training.py
@@ -519,11 +519,11 @@ def eval_policy(env, tree_observation, policy, train_params, obs_params):
 if __name__ == "__main__":
     parser = ArgumentParser()
     parser.add_argument("-n", "--n_episodes", help="number of episodes to run", default=12000, type=int)
-    parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=3,
+    parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=2,
                         type=int)
-    parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=1,
+    parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=2,
                         type=int)
-    parser.add_argument("--n_evaluation_episodes", help="number of evaluation episodes", default=1, type=int)
+    parser.add_argument("--n_evaluation_episodes", help="number of evaluation episodes", default=10, type=int)
     parser.add_argument("--checkpoint_interval", help="checkpoint interval", default=100, type=int)
     parser.add_argument("--eps_start", help="max exploration", default=0.1, type=float)
     parser.add_argument("--eps_end", help="min exploration", default=0.005, type=float)
diff --git a/reinforcement_learning/policy.py b/reinforcement_learning/policy.py
index 5b118aee15253d7dfb86c04925ea8a058abdbf2d..9b883d192036ae904de915d5d0804a4a951af092 100644
--- a/reinforcement_learning/policy.py
+++ b/reinforcement_learning/policy.py
@@ -37,3 +37,15 @@
 
     def clone(self):
         return self
+
+class HeuristicPolicy(Policy):
+    def __init__(self):
+        super(HeuristicPolicy).__init__()
+
+class LearningPolicy(Policy):
+    def __init__(self):
+        super(LearningPolicy).__init__()
+
+class HybridPolicy(Policy):
+    def __init__(self):
+        super(HybridPolicy).__init__()
diff --git a/reinforcement_learning/ppo_agent.py b/reinforcement_learning/ppo_agent.py
index 5c0fc08c51524372d4e8597d708145cb28107d34..51f0f7187f7d0894c97d535e68987b79e644779a 100644
--- a/reinforcement_learning/ppo_agent.py
+++ b/reinforcement_learning/ppo_agent.py
@@ -8,7 +8,7 @@ import torch.optim as optim
 from torch.distributions import Categorical
 
 # Hyperparameters
-from reinforcement_learning.policy import Policy
+from reinforcement_learning.policy import LearningPolicy
 from reinforcement_learning.replay_buffer import ReplayBuffer
 
 device = torch.device("cpu")  # "cuda:0" if torch.cuda.is_available() else "cpu")
@@ -92,10 +92,10 @@ class ActorCriticModel(nn.Module):
     def load(self, filename):
         print("load policy from file", filename)
         self.actor = self._load(self.actor, filename + ".actor")
-        self.critic = self._load(self.critic, filename + ".critic")
+        self.critic = self._load(self.critic, filename + ".value")
 
 
-class PPOPolicy(Policy):
+class PPOPolicy(LearningPolicy):
     def __init__(self, state_size, action_size):
         print(">> PPOPolicy")
         super(PPOPolicy, self).__init__()
diff --git a/reinforcement_learning/ppo_deadlockavoidance_agent.py b/reinforcement_learning/ppo_deadlockavoidance_agent.py
index 6e8880cf0a769c4dc17d0b5793510e079460ac51..f8917482fb83040c2f2a56f5ede1280a7f0481fc 100644
--- a/reinforcement_learning/ppo_deadlockavoidance_agent.py
+++ b/reinforcement_learning/ppo_deadlockavoidance_agent.py
@@ -1,13 +1,13 @@
 from flatland.envs.agent_utils import RailAgentStatus
 from flatland.envs.rail_env import RailEnv, RailEnvActions
 
-from reinforcement_learning.policy import Policy
+from reinforcement_learning.policy import HybridPolicy
 from utils.agent_action_config import map_rail_env_action
 from utils.agent_can_choose_helper import AgentCanChooseHelper
 from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent
 
 
-class MultiDecisionAgent(Policy):
+class MultiDecisionAgent(HybridPolicy):
 
     def __init__(self, env: RailEnv, state_size, action_size, learning_agent):
         self.env = env
@@ -33,7 +33,7 @@ class MultiDecisionAgent(Policy):
         if agent.status < RailAgentStatus.DONE:
             agents_on_switch, agents_near_to_switch, _, _ = \
                 self.agent_can_choose_helper.check_agent_decision(position, direction)
-            if agents_on_switch or agents_near_to_switch:
+            if agents_on_switch:
                 return self.learning_agent.act(handle, state, eps)
             else:
                 act = self.dead_lock_avoidance_agent.act(handle, state, -1.0)
diff --git a/utils/dead_lock_avoidance_agent.py b/utils/dead_lock_avoidance_agent.py
index cad3b749b43bb2fa203e861a1a90ff42114570f8..1f0030c0774b702efc35e885f92b5ce4902cc0c1 100644
--- a/utils/dead_lock_avoidance_agent.py
+++ b/utils/dead_lock_avoidance_agent.py
@@ -6,7 +6,7 @@ from flatland.core.env_observation_builder import DummyObservationBuilder
 from flatland.envs.agent_utils import RailAgentStatus
 from flatland.envs.rail_env import RailEnv, RailEnvActions, fast_count_nonzero
 
-from reinforcement_learning.policy import Policy
+from reinforcement_learning.policy import HeuristicPolicy
 from utils.agent_action_config import map_rail_env_action
 from utils.shortest_distance_walker import ShortestDistanceWalker
 
@@ -75,7 +75,7 @@ class DummyMemory:
         return 0
 
 
-class DeadLockAvoidanceAgent(Policy):
+class DeadLockAvoidanceAgent(HeuristicPolicy):
     def __init__(self, env: RailEnv, action_size, show_debug_plot=False):
         self.env = env
         self.memory = DummyMemory()