Commit 769f25ec authored by Egli Adrian (IT-SCI-API-PFI)

small fix in object

parent 87273288
@@ -257,7 +257,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
         # Reset environment
         reset_timer.start()
         number_of_agents = int(min(n_agents, 1 + np.floor(episode_idx / 200)))
-        train_env_params.n_agents = 1 # episode_idx % number_of_agents + 1
+        train_env_params.n_agents = episode_idx % number_of_agents + 1
         train_env = create_rail_env(train_env_params, tree_observation)
         obs, info = train_env.reset(regenerate_rail=True, regenerate_schedule=True)
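The old line pinned every episode to a single agent; the new line cycles the per-episode agent count through a cap that grows every 200 episodes. A minimal standalone sketch of that schedule (the value of n_agents is assumed here purely for illustration):

import numpy as np

n_agents = 5  # assumed cap for illustration; the real value comes from the training parameters
for episode_idx in [0, 201, 202, 403, 805, 1200]:
    # cap grows by one every 200 episodes, up to n_agents
    number_of_agents = int(min(n_agents, 1 + np.floor(episode_idx / 200)))
    # the environment is then built with a count that cycles from 1 up to the cap
    env_n_agents = episode_idx % number_of_agents + 1
    print(episode_idx, number_of_agents, env_n_agents)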
@@ -314,7 +314,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
             next_obs, all_rewards, done, info = train_env.step(action_dict)
 
             # Reward shaping .Dead-lock .NotMoving .NotStarted
-            if False:
+            if True:
                 agent_positions = get_agent_positions(train_env)
                 for agent_handle in train_env.get_agent_handles():
                     agent = train_env.agents[agent_handle]
@@ -12,11 +12,11 @@ from reinforcement_learning.policy import Policy
 LEARNING_RATE = 0.1e-4
 GAMMA = 0.98
-LMBDA = 0.9
-EPS_CLIP = 0.1
+LAMBDA = 0.9
+SURROGATE_EPS_CLIP = 0.01
 K_EPOCH = 3
 
-device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+device = torch.device("cpu")#"cuda:0" if torch.cuda.is_available() else "cpu")
 print("device:", device)
@@ -215,7 +215,7 @@ class PPOAgent(Policy):
         advantage_list = []
         advantage_value = 0.0
         for difference_to_expected_value_t in difference_to_expected_value_deltas[::-1]:
-            advantage_value = LMBDA * advantage_value + difference_to_expected_value_t[0]
+            advantage_value = LAMBDA * advantage_value + difference_to_expected_value_t[0]
             advantage_list.append([advantage_value])
         advantage_list.reverse()
         advantages = torch.tensor(advantage_list, dtype=torch.float)
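The loop above walks the per-step deltas backwards in time and accumulates a discounted sum, GAE-style; note that only LAMBDA multiplies the running value, whereas the textbook generalized-advantage estimator uses gamma * lambda as the decay factor. A self-contained sketch of the same recursion, with made-up delta values (not data from this repository):

import torch

LAMBDA = 0.9  # as defined at module level in the agent file

# Hypothetical per-step TD errors (delta_t), oldest first.
deltas = [[0.5], [-0.2], [1.0], [0.1]]

advantage_list = []
advantage_value = 0.0
# Walk the deltas backwards: A_t = LAMBDA * A_{t+1} + delta_t
for delta_t in deltas[::-1]:
    advantage_value = LAMBDA * advantage_value + delta_t[0]
    advantage_list.append([advantage_value])
advantage_list.reverse()
advantages = torch.tensor(advantage_list, dtype=torch.float)
print(advantages)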
@@ -227,9 +227,11 @@ class PPOAgent(Policy):
         # Normal Policy Gradient objective
         surrogate_objective = ratios * advantages
         # clipped version of Normal Policy Gradient objective
-        clipped_surrogate_objective = torch.clamp(ratios * advantages, 1 - EPS_CLIP, 1 + EPS_CLIP)
+        clipped_surrogate_objective = torch.clamp(ratios * advantages,
+                                                  1 - SURROGATE_EPS_CLIP,
+                                                  1 + SURROGATE_EPS_CLIP)
         # create value loss function
-        value_loss = F.mse_loss(self.value_network(states),
+        value_loss = F.smooth_l1_loss(self.value_network(states),
                                 estimated_target_value.detach())
         # create final loss function
         loss = -torch.min(surrogate_objective, clipped_surrogate_objective) + value_loss
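Two remarks on this last hunk, hedged since only this excerpt is visible: torch.clamp is applied here to the product ratios * advantages, whereas the reference PPO objective (Schulman et al., 2017) clamps the probability ratio before multiplying by the advantage; and F.smooth_l1_loss replaces the mean-squared value loss with a Huber-style loss that is less sensitive to outlier targets. The sketch below, using made-up tensors, only illustrates the difference between the two clipping variants; it is not a statement about which one the author intends:

import torch

SURROGATE_EPS_CLIP = 0.01  # matches the constant introduced in this commit

# Hypothetical ratios pi_new / pi_old and advantages, for illustration only.
ratios = torch.tensor([[1.2], [0.8], [1.05]])
advantages = torch.tensor([[0.5], [-0.3], [0.1]])

# Variant used in this diff: clamp the product of ratio and advantage.
clipped_product = torch.clamp(ratios * advantages,
                              1 - SURROGATE_EPS_CLIP,
                              1 + SURROGATE_EPS_CLIP)

# Reference PPO formulation: clamp the ratio, then multiply by the advantage.
clipped_ratio = torch.clamp(ratios, 1 - SURROGATE_EPS_CLIP, 1 + SURROGATE_EPS_CLIP) * advantages

policy_loss_diff_variant = -torch.min(ratios * advantages, clipped_product)
policy_loss_reference = -torch.min(ratios * advantages, clipped_ratio)
print(policy_loss_diff_variant.mean(), policy_loss_reference.mean())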