Commit 3babf29d authored by Erik Nygren

removed "bug" with reward. Attention, currently it is cheaper for an agent to...

removed "bug" with reward. Attention, currently it is cheaper for an agent to wait if we cummulate rewards between the different state!
parent df0b0ef1
@@ -48,9 +48,9 @@ stochastic_data = {'prop_malfunction': 0.0,  # Percentage of defective agents
 TreeObservation = TreeObsForRailEnv(max_depth=2)

 # Different agent types (trains) with different speeds.
-speed_ration_map = {1.: 0.,  # Fast passenger train
+speed_ration_map = {1.: 1.,  # Fast passenger train
                     1. / 2.: 0.0,  # Fast freight train
-                    1. / 3.: 1.0,  # Slow commuter train
+                    1. / 3.: 0.0,  # Slow commuter train
                     1. / 4.: 0.0}  # Slow freight train

 env = RailEnv(width=x_dim,
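For context on the hunk above: the values of speed_ration_map are fractions of the agent population, so this change makes every agent in this script a fast passenger train (speed 1.0) instead of a slow commuter train (speed 1/3). A minimal sketch of how such a map can be turned into per-agent speeds, written in plain numpy and independent of flatland's schedule-generator API (the function name and seed are illustrative only):

import numpy as np

def sample_speeds(ratio_map, n_agents, seed=42):
    """Draw one speed per agent; the map's values are population fractions."""
    rng = np.random.default_rng(seed)
    speeds = np.array(list(ratio_map.keys()))
    probs = np.array(list(ratio_map.values()), dtype=float)
    probs /= probs.sum()                      # guard against maps that don't sum to 1
    return rng.choice(speeds, size=n_agents, p=probs)

speed_ration_map = {1.: 1.,       # Fast passenger train
                    1. / 2.: 0.,  # Fast freight train
                    1. / 3.: 0.,  # Slow commuter train
                    1. / 4.: 0.}  # Slow freight train

print(sample_speeds(speed_ration_map, n_agents=5))   # -> [1. 1. 1. 1. 1.]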
@@ -95,7 +95,7 @@ action_prob = [0] * action_size
 agent_obs = [None] * env.get_num_agents()
 agent_next_obs = [None] * env.get_num_agents()
 agent = Agent(state_size, action_size)
-with path(torch_training.Nets, "navigator_checkpoint15000.pth") as file_in:
+with path(torch_training.Nets, "navigator_checkpoint1000.pth") as file_in:
     agent.qnetwork_local.load_state_dict(torch.load(file_in))

 record_images = False
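The evaluation script above now loads navigator_checkpoint1000.pth instead of navigator_checkpoint15000.pth. A hedged sketch of the loading pattern, assuming the checkpoint ships inside the torch_training.Nets package; the stand-in network, its layer sizes, and the map_location argument are additions for illustration, not part of the script:

import torch
import torch.nn as nn
from importlib.resources import path   # older scripts may use the importlib_resources backport

import torch_training.Nets              # package directory that ships the .pth checkpoints

# Stand-in for agent.qnetwork_local; the real network is the repo's dueling DQN,
# and the dimensions below are placeholders, not the script's actual sizes.
qnetwork_local = nn.Sequential(nn.Linear(231, 128), nn.ReLU(), nn.Linear(128, 5))

with path(torch_training.Nets, "navigator_checkpoint1000.pth") as file_in:
    # map_location="cpu" lets the checkpoint load on machines without a GPU
    state_dict = torch.load(file_in, map_location="cpu")

qnetwork_local.load_state_dict(state_dict)  # succeeds only if the architectures match

Passing map_location="cpu" keeps the load working without CUDA; load_state_dict will still raise if the checkpoint was saved from a network with different shapes.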
@@ -51,8 +51,8 @@ def main(argv):
 TreeObservation = TreeObsForRailEnv(max_depth=2)

 # Different agent types (trains) with different speeds.
-speed_ration_map = {1.: 1.,  # Fast passenger train
-                    1. / 2.: 0.0,  # Fast freight train
+speed_ration_map = {1.: 0.,  # Fast passenger train
+                    1. / 2.: 1.0,  # Fast freight train
                     1. / 3.: 0.0,  # Slow commuter train
                     1. / 4.: 0.0}  # Slow freight train
@@ -106,9 +106,8 @@ def main(argv):
 agent_next_obs = [None] * env.get_num_agents()
 agent_obs_buffer = [None] * env.get_num_agents()
 agent_action_buffer = [2] * env.get_num_agents()
-agent_done_buffer = [False] * env.get_num_agents()
 cummulated_reward = np.zeros(env.get_num_agents())
+update_values = False
 # Now we load a Double dueling DQN agent
 agent = Agent(state_size, action_size)
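Together with the existing agent_obs_buffer and agent_action_buffer, the new update_values flag records whether the current step produced a fresh decision, so the next hunk can push one (state, action, reward, next_state, done) transition per decision point; this is also why the per-agent agent_done_buffer is dropped. A stripped-down, single-agent sketch of that bookkeeping with toy stand-ins for the environment and the replay buffer (all names and numbers here are illustrative):

import random

# Toy stand-ins: the point is only the pattern of storing one transition per
# decision point, not the DQN or the real RailEnv.
memory = []                               # replay buffer
obs_buffer, action_buffer = None, None    # observation/action at the last decision

obs = 0.0
for step in range(10):
    action_required = (step % 3 == 0)     # e.g. a slow train decides every 3rd step
    if action_required:
        update_values = True
        action = random.choice([0, 1, 2, 3, 4])
    else:
        update_values = False
        action = 0                        # keep going while the last action plays out

    next_obs, reward, done = obs + 1.0, -1.0, step == 9   # fake environment step

    if update_values or done:
        if obs_buffer is not None:
            # one stored transition per decision, not per environment step
            memory.append((obs_buffer, action_buffer, reward, obs, done))
        obs_buffer, action_buffer = obs, action
    obs = next_obs

print(f"{len(memory)} transitions stored over 10 environment steps")

The None guard exists only because this sketch starts with an empty buffer; it is not part of the pattern itself.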
@@ -131,39 +130,32 @@ def main(argv):
 # Action
 for a in range(env.get_num_agents()):
     if info['action_required'][a]:
+        update_values = True
         action = agent.act(agent_obs[a], eps=eps)
         action_prob[action] += 1
     else:
+        update_values = False
         action = 0
-        action_prob[action] += 1
     action_dict.update({a: action})

 # Environment step
 next_obs, all_rewards, done, info = env.step(action_dict)

-# Build agent specific observations and normalize
-for a in range(env.get_num_agents()):
-    # Penalize waiting in order to get agent to move
-    if env.agents[a].status == 0:
-        all_rewards[a] -= 1
-    if info['action_required'][a]:
-        agent_next_obs[a] = normalize_observation(next_obs[a], tree_depth, observation_radius=10)
-    cummulated_reward[a] += all_rewards[a]
-
 # Update replay buffer and train agent
 for a in range(env.get_num_agents()):
-    if (info['action_required'][a] and env.agents[a].status != 3) or env.agents[a].status == 2:
+    if update_values or done[a]:
         agent.step(agent_obs_buffer[a], agent_action_buffer[a], all_rewards[a],
-                   agent_obs[a], agent_done_buffer[a])
+                   agent_obs[a], done[a])
         cummulated_reward[a] = 0.
-    if info['action_required'][a]:
+
         agent_obs_buffer[a] = agent_obs[a].copy()
         agent_action_buffer[a] = action_dict[a]
-        agent_done_buffer[a] = done[a]
+    agent_obs[a] = normalize_observation(next_obs[a], tree_depth, observation_radius=10)
+
     score += all_rewards[a] / env.get_num_agents()

 # Copy observation
-agent_obs = agent_next_obs.copy()

 if done['__all__']:
     env_done = 1
     break
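The warning in the commit message is about how rewards reach the learner when an agent does not need a new action at every environment step. The code above now passes only the reward observed at the decision step to agent.step, while the removed cummulated_reward bookkeeping had been summing the per-step rewards collected in between (the variable remains but is only ever reset). A small arithmetic sketch of the difference, assuming a per-step penalty of -1 as in the default RailEnv reward shaping (the function and numbers are illustrative, not taken from the repo):

# Illustrative numbers only: a per-step penalty of -1 and an agent whose next
# decision comes k environment steps after the previous one.
STEP_PENALTY = -1.0

def reward_per_transition(k, cumulate):
    """Reward attached to one stored (s, a, r, s') transition spanning k steps."""
    return k * STEP_PENALTY if cumulate else STEP_PENALTY

for k in (1, 2, 4):      # deciding every step vs. every 2nd or 4th step
    print(k,
          reward_per_transition(k, cumulate=False),   # always -1, regardless of k
          reward_per_transition(k, cumulate=True))    # -k: longer gaps cost more

With only the decision-step reward, every stored transition costs the same -1 no matter how many real time steps it spans, so long waiting or slow phases are under-penalised at the level of stored transitions; summing the intermediate rewards restores that proportionality but changes the scale of the learning targets. The commit message's "Attention" points at this interaction between waiting and reward accumulation without resolving it here.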