diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py
index ec496404ea1586d6621747a06729cc90323c72e9..b566d5a8b37872931bc86509c266dd28dc64f800 100755
--- a/reinforcement_learning/multi_agent_training.py
+++ b/reinforcement_learning/multi_agent_training.py
@@ -170,19 +170,20 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
     completion_window = deque(maxlen=checkpoint_interval)
 
     # Double Dueling DQN policy
-    policy = None
-    if False:
+    if train_params.policy == "DDDQN":
         policy = DDDQNPolicy(state_size, get_action_size(), train_params)
-    if True:
+    elif train_params.policy == "PPO":
         policy = PPOPolicy(state_size, get_action_size(), use_replay_buffer=False, in_parameters=train_params)
-    if False:
+    elif train_params.policy == "DeadLockAvoidance":
         policy = DeadLockAvoidanceAgent(train_env, get_action_size())
-    if False:
+    elif train_params.policy == "DeadLockAvoidanceWithDecision":
         # inter_policy = PPOPolicy(state_size, get_action_size(), use_replay_buffer=False, in_parameters=train_params)
         inter_policy = DDDQNPolicy(state_size, get_action_size(), train_params)
         policy = DeadLockAvoidanceWithDecisionAgent(train_env, state_size, get_action_size(), inter_policy)
-    if False:
+    elif train_params.policy == "MultiDecision":
         policy = MultiDecisionAgent(state_size, get_action_size(), train_params)
+    else:
+        policy = PPOPolicy(state_size, get_action_size(), use_replay_buffer=False, in_parameters=train_params)
 
     # make sure that at least one policy is set
     if policy is None:
@@ -532,6 +533,9 @@ if __name__ == "__main__":
     parser.add_argument("--use_fast_tree_observation", help="use FastTreeObs instead of stock TreeObs",
                         action='store_true')
     parser.add_argument("--max_depth", help="max depth", default=2, type=int)
+    parser.add_argument("--policy",
+                        help="policy name [DDDQN, PPO, DeadLockAvoidance, DeadLockAvoidanceWithDecision, MultiDecision]",
+                        default="ppo")
 
     training_params = parser.parse_args()
     env_params = [