Commit cb4df49c authored by Egli Adrian (IT-SCI-API-PFI)

test

parent 91c77d25
@@ -176,7 +176,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
     policy = PPOPolicy(state_size, get_action_size())
     if False:
         policy = DeadLockAvoidanceAgent(train_env, get_action_size())
-    if False:
+    if True:
         policy = MultiDecisionAgent(train_env, state_size, get_action_size(), policy)

     # Load existing policy
...
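Note: the only change in this hunk flips the MultiDecisionAgent guard from `if False:` to `if True:`, so the freshly constructed PPOPolicy is now wrapped by a MultiDecisionAgent before training starts. The sketch below illustrates that wrapping pattern under assumed, simplified interfaces; aside from the delegation idea visible in the diff, every name and body here is hypothetical and not this repository's actual code.

    # Minimal sketch of the wrapper pattern toggled above (hypothetical
    # simplified interfaces, not the repository's real classes).
    class LearnedPolicy:
        def act(self, handle, state, eps=0.0):
            return 2  # stand-in for a PPO forward pass

    class DeadLockAvoidancePolicy:
        def act(self, handle, state, eps=0.0):
            return 4  # stand-in for rule-based deadlock avoidance

    class DecisionAgent:
        """Routes each agent (handle) to the learned policy or the fallback."""
        def __init__(self, learned, fallback):
            self.learned = learned
            self.fallback = fallback

        def act(self, handle, state, eps=0.0):
            if self._agent_must_decide(state):
                return self.learned.act(handle, state, eps)
            return self.fallback.act(handle, state, eps)

        def _agent_must_decide(self, state):
            return True  # stand-in for the real switching criterion

    agent = DecisionAgent(LearnedPolicy(), DeadLockAvoidancePolicy())
    print(agent.act(handle=0, state=None))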
@@ -107,8 +107,8 @@ class PPOPolicy(Policy):
         self.weight_loss = 0.25
         self.weight_entropy = 0.01

-        self.buffer_size = 2_000
-        self.batch_size = 64
+        self.buffer_size = 32_000
+        self.batch_size = 1024
         self.buffer_min_size = 0
         self.use_replay_buffer = True
         self.device = device
...
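Note: this hunk is a pure hyperparameter change: the PPO replay buffer grows from 2,000 to 32,000 transitions and the training minibatch from 64 to 1,024 samples (use_replay_buffer is already True). Below is a minimal sketch of what these two knobs typically control, assuming a uniform-sampling buffer; the repository's actual buffer implementation may differ.

    # Minimal sketch of the replay-buffer behaviour these parameters control
    # (assumed semantics, not the repository's actual implementation).
    import random
    from collections import deque

    class ReplayBuffer:
        def __init__(self, buffer_size=32_000, batch_size=1024, buffer_min_size=0):
            self.memory = deque(maxlen=buffer_size)  # oldest transitions are evicted
            self.batch_size = batch_size
            self.buffer_min_size = buffer_min_size

        def add(self, transition):
            self.memory.append(transition)

        def can_sample(self):
            return len(self.memory) >= max(self.batch_size, self.buffer_min_size)

        def sample(self):
            return random.sample(self.memory, self.batch_size)

    buffer = ReplayBuffer()
    for step in range(40_000):
        buffer.add((step, 0, 0.0, step + 1, False))  # (s, a, r, s', done)
    if buffer.can_sample():
        batch = buffer.sample()  # 1024 uniformly drawn transitions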
@@ -187,7 +187,6 @@ class PPOPolicy(Policy):
                                        reward_list, state_next_list,
                                        done_list, prob_a_list)

-
         # convert data to torch tensors
         states, actions, rewards, states_next, dones, prob_actions = \
             torch.tensor(state_list, dtype=torch.float).to(self.device), \
...
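Note: the hunk above only deletes a blank line; the surrounding context converts the transition lists collected during rollout into torch tensors on the configured device. A self-contained sketch of that conversion idiom, with made-up toy data:

    import torch

    # Toy stand-ins for the collected rollout lists (hypothetical shapes).
    state_list = [[0.0, 1.0], [1.0, 0.0]]
    reward_list = [0.5, -0.1]
    done_list = [False, True]
    device = torch.device("cpu")

    # Same idiom as the hunk above: Python lists -> torch tensors on device.
    states = torch.tensor(state_list, dtype=torch.float).to(device)
    rewards = torch.tensor(reward_list, dtype=torch.float).to(device)
    dones = torch.tensor(done_list, dtype=torch.float).to(device)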
@@ -39,7 +39,7 @@ class MultiDecisionAgent(Policy):
             act = self.dead_lock_avoidance_agent.act(handle, state, -1.0)
             return map_rail_env_action(act)
         # Agent is still at target cell
-        return RailEnvActions.DO_NOTHING
+        return map_rail_env_action(RailEnvActions.DO_NOTHING)

     def save(self, filename):
         self.dead_lock_avoidance_agent.save(filename)
...
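Note: this is the substantive fix of the commit: the early-return path for an agent already at its target now goes through map_rail_env_action like every other return in the method, instead of leaking a raw RailEnvActions value. The snippet below illustrates why that matters, under the assumption (not confirmed by this diff) that the policy acts in a reduced action space; the map_rail_env_action body here is hypothetical.

    from enum import IntEnum

    class RailEnvActions(IntEnum):
        DO_NOTHING = 0
        MOVE_LEFT = 1
        MOVE_FORWARD = 2
        MOVE_RIGHT = 3
        STOP_MOVING = 4

    def map_rail_env_action(action):
        # Hypothetical mapping: fold DO_NOTHING into STOP_MOVING and shift the
        # remaining actions down by one, giving a 4-action policy space.
        if action == RailEnvActions.DO_NOTHING:
            action = RailEnvActions.STOP_MOVING
        return int(action) - 1

    # Mapped, DO_NOTHING lands on the policy's stop action ...
    assert map_rail_env_action(RailEnvActions.DO_NOTHING) == 3
    # ... whereas the raw enum value (0) would collide with the mapped
    # MOVE_LEFT in the reduced space, silently mixing the two action spaces.
    assert int(RailEnvActions.DO_NOTHING) == map_rail_env_action(RailEnvActions.MOVE_LEFT)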