diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py
index b9a819aacb59a432a0b8d7f14c0f208dff84fe3e..cce7ecc183e7c5af601ac6e790b95aaa8ff33ba2 100755
--- a/reinforcement_learning/multi_agent_training.py
+++ b/reinforcement_learning/multi_agent_training.py
@@ -176,7 +176,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
     policy = PPOPolicy(state_size, get_action_size())
     if False:
         policy = DeadLockAvoidanceAgent(train_env, get_action_size())
-    if False:
+    if True:
         policy = MultiDecisionAgent(train_env, state_size, get_action_size(), policy)
 
     # Load existing policy
diff --git a/reinforcement_learning/ppo_agent.py b/reinforcement_learning/ppo_agent.py
index 43de9f7b08dcf6c49c4139415170171c97b83da3..5c0fc08c51524372d4e8597d708145cb28107d34 100644
--- a/reinforcement_learning/ppo_agent.py
+++ b/reinforcement_learning/ppo_agent.py
@@ -107,8 +107,8 @@ class PPOPolicy(Policy):
         self.weight_loss = 0.25
         self.weight_entropy = 0.01
 
-        self.buffer_size = 2_000
-        self.batch_size = 64
+        self.buffer_size = 32_000
+        self.batch_size = 1024
         self.buffer_min_size = 0
         self.use_replay_buffer = True
         self.device = device
@@ -187,7 +187,6 @@ class PPOPolicy(Policy):
                                reward_list, state_next_list, done_list, prob_a_list)
 
-        # convert data to torch tensors
         states, actions, rewards, states_next, dones, prob_actions = \
             torch.tensor(state_list, dtype=torch.float).to(self.device), \
diff --git a/reinforcement_learning/ppo_deadlockavoidance_agent.py b/reinforcement_learning/ppo_deadlockavoidance_agent.py
index 737634c02b61db2c7ddb80424f0976d4ce9ec701..6e8880cf0a769c4dc17d0b5793510e079460ac51 100644
--- a/reinforcement_learning/ppo_deadlockavoidance_agent.py
+++ b/reinforcement_learning/ppo_deadlockavoidance_agent.py
@@ -39,7 +39,7 @@ class MultiDecisionAgent(Policy):
             act = self.dead_lock_avoidance_agent.act(handle, state, -1.0)
             return map_rail_env_action(act)
         # Agent is still at target cell
-        return RailEnvActions.DO_NOTHING
+        return map_rail_env_action(RailEnvActions.DO_NOTHING)
 
     def save(self, filename):
         self.dead_lock_avoidance_agent.save(filename)