Commit 2a3839db authored by Egli Adrian (IT-SCI-API-PFI)

policy as argument

parent dac3eee5
@@ -170,19 +170,20 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
     completion_window = deque(maxlen=checkpoint_interval)
 
     # Double Dueling DQN policy
-    policy = None
-    if False:
+    if train_params.policy == "DDDQN":
         policy = DDDQNPolicy(state_size, get_action_size(), train_params)
-    if True:
+    elif train_params.policy == "PPO":
         policy = PPOPolicy(state_size, get_action_size(), use_replay_buffer=False, in_parameters=train_params)
-    if False:
+    elif train_params.policy == "DeadLockAvoidance":
         policy = DeadLockAvoidanceAgent(train_env, get_action_size())
-    if False:
+    elif train_params.policy == "DeadLockAvoidanceWithDecision":
         # inter_policy = PPOPolicy(state_size, get_action_size(), use_replay_buffer=False, in_parameters=train_params)
         inter_policy = DDDQNPolicy(state_size, get_action_size(), train_params)
         policy = DeadLockAvoidanceWithDecisionAgent(train_env, state_size, get_action_size(), inter_policy)
-    if False:
+    elif train_params.policy == "MultiDecision":
         policy = MultiDecisionAgent(state_size, get_action_size(), train_params)
+    else:
+        policy = PPOPolicy(state_size, get_action_size(), use_replay_buffer=False, in_parameters=train_params)
 
     # make sure that at least one policy is set
     if policy is None:

@@ -532,6 +533,9 @@ if __name__ == "__main__":
     parser.add_argument("--use_fast_tree_observation", help="use FastTreeObs instead of stock TreeObs",
                         action='store_true')
     parser.add_argument("--max_depth", help="max depth", default=2, type=int)
+    parser.add_argument("--policy",
+                        help="policy name [DDDQN, PPO, DeadLockAvoidance, DeadLockAvoidanceWithDecision, MultiDecision]",
+                        default="ppo")
 
     training_params = parser.parse_args()
     env_params = [
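With this change the policy is selected at run time through the new --policy flag instead of by flipping hard-coded if False:/if True: guards in the source. The sketch below distills the same argparse-plus-dispatch pattern into a self-contained script; the stub classes and the create_policy helper are illustrative stand-ins for this example, not the repository's actual DDDQNPolicy/PPOPolicy implementations or constructor signatures.

import argparse

# Stand-in policy classes for illustration only; the real code constructs
# DDDQNPolicy, PPOPolicy, DeadLockAvoidanceAgent, etc. with environment-
# specific arguments such as state_size and train_params.
class PPOPolicy:
    def __init__(self, state_size, action_size):
        self.name = f"PPO(state={state_size}, actions={action_size})"

class DDDQNPolicy:
    def __init__(self, state_size, action_size):
        self.name = f"DDDQN(state={state_size}, actions={action_size})"

def create_policy(policy_name, state_size, action_size):
    # Map the --policy string to a policy object; unknown names fall
    # through to PPO, mirroring the else branch in the commit.
    if policy_name == "DDDQN":
        return DDDQNPolicy(state_size, action_size)
    elif policy_name == "PPO":
        return PPOPolicy(state_size, action_size)
    else:
        return PPOPolicy(state_size, action_size)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--policy", help="policy name [DDDQN, PPO]", default="ppo")
    args = parser.parse_args()
    policy = create_policy(args.policy, state_size=128, action_size=5)
    print(policy.name)

The training script can then be launched with, e.g., --policy DeadLockAvoidance rather than by editing the source. Note that the committed default, "ppo", is lower case while the dispatch compares against "PPO"; the comparison is case-sensitive, so the default reaches the final else branch, which happens to construct the same PPOPolicy, leaving the default behaviour unchanged.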