diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py
index ec496404ea1586d6621747a06729cc90323c72e9..b566d5a8b37872931bc86509c266dd28dc64f800 100755
--- a/reinforcement_learning/multi_agent_training.py
+++ b/reinforcement_learning/multi_agent_training.py
@@ -170,19 +170,20 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
     completion_window = deque(maxlen=checkpoint_interval)
 
     # Double Dueling DQN policy
-    policy = None
-    if False:
+    if train_params.policy == "DDDQN":
         policy = DDDQNPolicy(state_size, get_action_size(), train_params)
-    if True:
+    elif train_params.policy == "PPO":
         policy = PPOPolicy(state_size, get_action_size(), use_replay_buffer=False, in_parameters=train_params)
-    if False:
+    elif train_params.policy == "DeadLockAvoidance":
         policy = DeadLockAvoidanceAgent(train_env, get_action_size())
-    if False:
+    elif train_params.policy == "DeadLockAvoidanceWithDecision":
         # inter_policy = PPOPolicy(state_size, get_action_size(), use_replay_buffer=False, in_parameters=train_params)
         inter_policy = DDDQNPolicy(state_size, get_action_size(), train_params)
         policy = DeadLockAvoidanceWithDecisionAgent(train_env, state_size, get_action_size(), inter_policy)
-    if False:
+    elif train_params.policy == "MultiDecision":
         policy = MultiDecisionAgent(state_size, get_action_size(), train_params)
+    else:
+        policy = PPOPolicy(state_size, get_action_size(), use_replay_buffer=False, in_parameters=train_params)
 
     # make sure that at least one policy is set
     if policy is None:
@@ -532,6 +533,9 @@ if __name__ == "__main__":
     parser.add_argument("--use_fast_tree_observation", help="use FastTreeObs instead of stock TreeObs",
                         action='store_true')
     parser.add_argument("--max_depth", help="max depth", default=2, type=int)
+    parser.add_argument("--policy",
+                        help="policy name [DDDQN, PPO, DeadLockAvoidance, DeadLockAvoidanceWithDecision, MultiDecision]",
+                        default="ppo")
     training_params = parser.parse_args()
 
     env_params = [
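
A minimal sketch of the dispatch this diff introduces, using placeholder strings instead of the repo's real policy constructors (DDDQNPolicy, PPOPolicy, etc. are stand-ins here). It illustrates that the string match is case-sensitive and that any unrecognized name, including the lowercase default "ppo", falls through to the else branch and still yields the PPO policy.

    # Hypothetical illustration; names and return values are placeholders, not the repo's API.
    def select_policy(name):
        if name == "DDDQN":
            return "DDDQNPolicy"
        elif name == "PPO":
            return "PPOPolicy"
        elif name == "DeadLockAvoidance":
            return "DeadLockAvoidanceAgent"
        elif name == "DeadLockAvoidanceWithDecision":
            return "DeadLockAvoidanceWithDecisionAgent"
        elif name == "MultiDecision":
            return "MultiDecisionAgent"
        else:
            # fallback branch added by the diff; also catches the default "ppo"
            return "PPOPolicy"

    print(select_policy("DDDQN"))  # -> "DDDQNPolicy"
    print(select_policy("ppo"))    # -> "PPOPolicy" via the else branch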