From cb4df49c3ae2036b16bbf7f2f79711d1f86c66c4 Mon Sep 17 00:00:00 2001
From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch>
Date: Fri, 18 Dec 2020 11:16:23 +0100
Subject: [PATCH] test

---
 reinforcement_learning/multi_agent_training.py        | 2 +-
 reinforcement_learning/ppo_agent.py                   | 5 ++---
 reinforcement_learning/ppo_deadlockavoidance_agent.py | 2 +-
 3 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py
index b9a819a..cce7ecc 100755
--- a/reinforcement_learning/multi_agent_training.py
+++ b/reinforcement_learning/multi_agent_training.py
@@ -176,7 +176,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
         policy = PPOPolicy(state_size, get_action_size())
     if False:
         policy = DeadLockAvoidanceAgent(train_env, get_action_size())
-    if False:
+    if True:
         policy = MultiDecisionAgent(train_env, state_size, get_action_size(), policy)
 
     # Load existing policy
diff --git a/reinforcement_learning/ppo_agent.py b/reinforcement_learning/ppo_agent.py
index 43de9f7..5c0fc08 100644
--- a/reinforcement_learning/ppo_agent.py
+++ b/reinforcement_learning/ppo_agent.py
@@ -107,8 +107,8 @@ class PPOPolicy(Policy):
         self.weight_loss = 0.25
         self.weight_entropy = 0.01
 
-        self.buffer_size = 2_000
-        self.batch_size = 64
+        self.buffer_size = 32_000
+        self.batch_size = 1024
         self.buffer_min_size = 0
         self.use_replay_buffer = True
         self.device = device
@@ -187,7 +187,6 @@ class PPOPolicy(Policy):
                                                     reward_list, state_next_list, done_list, prob_a_list)
 
-        # convert data to torch tensors
         states, actions, rewards, states_next, dones, prob_actions = \
             torch.tensor(state_list, dtype=torch.float).to(self.device), \
diff --git a/reinforcement_learning/ppo_deadlockavoidance_agent.py b/reinforcement_learning/ppo_deadlockavoidance_agent.py
index 737634c..6e8880c 100644
--- a/reinforcement_learning/ppo_deadlockavoidance_agent.py
+++ b/reinforcement_learning/ppo_deadlockavoidance_agent.py
@@ -39,7 +39,7 @@ class MultiDecisionAgent(Policy):
             act = self.dead_lock_avoidance_agent.act(handle, state, -1.0)
             return map_rail_env_action(act)
         # Agent is still at target cell
-        return RailEnvActions.DO_NOTHING
+        return map_rail_env_action(RailEnvActions.DO_NOTHING)
 
     def save(self, filename):
         self.dead_lock_avoidance_agent.save(filename)
--
GitLab
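
For context on the ppo_deadlockavoidance_agent.py hunk: the deadlock-avoidance return path already passes its action through map_rail_env_action, and the patch routes the fallback DO_NOTHING through the same mapping, so both paths of act() return values from the same policy-side action space. The sketch below only illustrates that idea; the RailEnvActions values match Flatland's enum, but this map_rail_env_action is a hypothetical stand-in, not the repository's helper.

    # Illustrative sketch only, not the repository code.
    from enum import IntEnum

    class RailEnvActions(IntEnum):
        # Values match Flatland's RailEnvActions enum.
        DO_NOTHING = 0
        MOVE_LEFT = 1
        MOVE_FORWARD = 2
        MOVE_RIGHT = 3
        STOP_MOVING = 4

    def map_rail_env_action(action):
        # Assumed behaviour (hypothetical): collapse DO_NOTHING onto STOP_MOVING
        # and shift the remaining actions into a reduced 4-action space (0..3).
        if action == RailEnvActions.DO_NOTHING:
            action = RailEnvActions.STOP_MOVING
        return int(action) - 1

    def fallback_action():
        # Before the patch the raw enum member was returned here, mixing the full
        # RailEnv action space with the reduced one used by the learning policy.
        return map_rail_env_action(RailEnvActions.DO_NOTHING)

    if __name__ == "__main__":
        print(fallback_action())  # -> 3, the reduced-space index of STOP_MOVING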