Commit cb4df49c authored by Egli Adrian (IT-SCI-API-PFI)

test

parent 91c77d25
@@ -176,7 +176,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
     policy = PPOPolicy(state_size, get_action_size())
     if False:
         policy = DeadLockAvoidanceAgent(train_env, get_action_size())
-    if False:
+    if True:
         policy = MultiDecisionAgent(train_env, state_size, get_action_size(), policy)
 
     # Load existing policy
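Note: the toggle flipped in this hunk appears to switch training from the bare PPOPolicy to a MultiDecisionAgent that wraps it. A minimal sketch of that wrapper pattern, assuming only the constructor arguments visible in the diff (the class body and method names below are hypothetical, not the repository's implementation):

    # Sketch only: policy-wrapper pattern implied by
    # MultiDecisionAgent(train_env, state_size, get_action_size(), policy).
    class MultiDecisionAgentSketch:
        def __init__(self, env, state_size, action_size, learned_policy):
            self.env = env
            self.state_size = state_size
            self.action_size = action_size
            self.learned_policy = learned_policy  # e.g. the PPOPolicy built above

        def act(self, handle, state, eps=0.0):
            # Delegate to the wrapped policy by default; the real agent may
            # route some decisions to a rule-based fallback instead.
            return self.learned_policy.act(handle, state, eps)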
@@ -107,8 +107,8 @@ class PPOPolicy(Policy):
         self.weight_loss = 0.25
         self.weight_entropy = 0.01
-        self.buffer_size = 2_000
-        self.batch_size = 64
+        self.buffer_size = 32_000
+        self.batch_size = 1024
         self.buffer_min_size = 0
         self.use_replay_buffer = True
         self.device = device
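This hunk scales the PPO replay memory up from 2_000 to 32_000 transitions and the mini-batch from 64 to 1024 samples. A minimal sketch of how a bounded replay buffer with these parameters is typically filled and sampled, assuming a simple deque-backed buffer rather than the memory class this repository actually uses:

    # Sketch only: generic bounded replay buffer, not the repository's implementation.
    import random
    from collections import deque

    class ReplayBufferSketch:
        def __init__(self, buffer_size=32_000, batch_size=1024, buffer_min_size=0):
            self.memory = deque(maxlen=buffer_size)  # oldest transitions are dropped
            self.batch_size = batch_size
            self.buffer_min_size = buffer_min_size

        def add(self, transition):
            self.memory.append(transition)

        def can_sample(self):
            # Sample only once enough transitions have been collected.
            return len(self.memory) >= max(self.batch_size, self.buffer_min_size)

        def sample(self):
            # Copy to a list before sampling; deque indexing is O(n).
            return random.sample(list(self.memory), self.batch_size)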
@@ -187,7 +187,6 @@ class PPOPolicy(Policy):
                 reward_list, state_next_list,
                 done_list, prob_a_list)
-        # convert data to torch tensors
         states, actions, rewards, states_next, dones, prob_actions = \
             torch.tensor(state_list, dtype=torch.float).to(self.device), \
@@ -39,7 +39,7 @@ class MultiDecisionAgent(Policy):
             act = self.dead_lock_avoidance_agent.act(handle, state, -1.0)
             return map_rail_env_action(act)
         # Agent is still at target cell
-        return RailEnvActions.DO_NOTHING
+        return map_rail_env_action(RailEnvActions.DO_NOTHING)
 
     def save(self, filename):
         self.dead_lock_avoidance_agent.save(filename)
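The fix in this hunk routes DO_NOTHING through map_rail_env_action, presumably so that every value returned by act() lives in the same (possibly reduced) action space the wrapped policies use. A rough sketch of that idea with a hypothetical mapping table; the repository's actual map_rail_env_action is defined elsewhere and may differ:

    # Sketch only: maps a full RailEnvActions value onto a hypothetical reduced
    # action space; the real mapping in this repository may differ.
    from flatland.envs.rail_env import RailEnvActions

    _REDUCED_SPACE = {                    # hypothetical 4-action space
        RailEnvActions.DO_NOTHING: 0,     # e.g. merged with MOVE_FORWARD
        RailEnvActions.MOVE_FORWARD: 0,
        RailEnvActions.MOVE_LEFT: 1,
        RailEnvActions.MOVE_RIGHT: 2,
        RailEnvActions.STOP_MOVING: 3,
    }

    def map_rail_env_action_sketch(action):
        # Returning an index from the reduced space keeps act() consistent,
        # even for DO_NOTHING.
        return _REDUCED_SPACE[action]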