From 3babf29d06293f6ff9995c3e0e256ce6e48821eb Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Sun, 6 Oct 2019 20:15:30 -0400
Subject: [PATCH] Removed "bug" with reward. Attention: currently it is
 cheaper for an agent to wait if we accumulate rewards between the
 different states!

---
 torch_training/render_agent_behavior.py |  6 +++---
 torch_training/training_navigation.py   | 28 +++++++++----------------
 2 files changed, 13 insertions(+), 21 deletions(-)

diff --git a/torch_training/render_agent_behavior.py b/torch_training/render_agent_behavior.py
index 93f9f12..d599bcf 100644
--- a/torch_training/render_agent_behavior.py
+++ b/torch_training/render_agent_behavior.py
@@ -48,9 +48,9 @@ stochastic_data = {'prop_malfunction': 0.0,  # Percentage of defective agents
 TreeObservation = TreeObsForRailEnv(max_depth=2)
 
 # Different agent types (trains) with different speeds.
-speed_ration_map = {1.: 0.,  # Fast passenger train
+speed_ration_map = {1.: 1.,  # Fast passenger train
                     1. / 2.: 0.0,  # Fast freight train
-                    1. / 3.: 1.0,  # Slow commuter train
+                    1. / 3.: 0.0,  # Slow commuter train
                     1. / 4.: 0.0}  # Slow freight train
 
 env = RailEnv(width=x_dim,
@@ -95,7 +95,7 @@ action_prob = [0] * action_size
 agent_obs = [None] * env.get_num_agents()
 agent_next_obs = [None] * env.get_num_agents()
 agent = Agent(state_size, action_size)
-with path(torch_training.Nets, "navigator_checkpoint15000.pth") as file_in:
+with path(torch_training.Nets, "navigator_checkpoint1000.pth") as file_in:
     agent.qnetwork_local.load_state_dict(torch.load(file_in))
 
 record_images = False
diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py
index 229c804..252cf16 100644
--- a/torch_training/training_navigation.py
+++ b/torch_training/training_navigation.py
@@ -51,8 +51,8 @@ def main(argv):
     TreeObservation = TreeObsForRailEnv(max_depth=2)
 
     # Different agent types (trains) with different speeds.
-    speed_ration_map = {1.: 1.,  # Fast passenger train
-                        1. / 2.: 0.0,  # Fast freight train
+    speed_ration_map = {1.: 0.,  # Fast passenger train
+                        1. / 2.: 1.0,  # Fast freight train
                         1. / 3.: 0.0,  # Slow commuter train
                         1. / 4.: 0.0}  # Slow freight train
 
@@ -106,9 +106,8 @@ def main(argv):
     agent_next_obs = [None] * env.get_num_agents()
     agent_obs_buffer = [None] * env.get_num_agents()
     agent_action_buffer = [2] * env.get_num_agents()
-    agent_done_buffer = [False] * env.get_num_agents()
     cummulated_reward = np.zeros(env.get_num_agents())
-
+    update_values = False
     # Now we load a Double dueling DQN agent
     agent = Agent(state_size, action_size)
 
@@ -131,39 +130,32 @@ def main(argv):
         # Action
         for a in range(env.get_num_agents()):
             if info['action_required'][a]:
+                update_values = True
                 action = agent.act(agent_obs[a], eps=eps)
                 action_prob[action] += 1
             else:
+                update_values = False
                 action = 0
+                action_prob[action] += 1
             action_dict.update({a: action})
 
         # Environment step
         next_obs, all_rewards, done, info = env.step(action_dict)
-        # Build agent specific observations and normalize
-        for a in range(env.get_num_agents()):
-            # Penalize waiting in order to get agent to move
-            if env.agents[a].status == 0:
-                all_rewards[a] -= 1
-            if info['action_required'][a]:
-                agent_next_obs[a] = normalize_observation(next_obs[a], tree_depth, observation_radius=10)
-            cummulated_reward[a] += all_rewards[a]
-
         # Update replay buffer and train agent
         for a in range(env.get_num_agents()):
-            if (info['action_required'][a] and env.agents[a].status != 3) or env.agents[a].status == 2:
+            if update_values or done[a]:
                 agent.step(agent_obs_buffer[a], agent_action_buffer[a], all_rewards[a],
-                           agent_obs[a], agent_done_buffer[a])
+                           agent_obs[a], done[a])
                 cummulated_reward[a] = 0.
-            if info['action_required'][a]:
+
                 agent_obs_buffer[a] = agent_obs[a].copy()
                 agent_action_buffer[a] = action_dict[a]
-                agent_done_buffer[a] = done[a]
+
+            agent_obs[a] = normalize_observation(next_obs[a], tree_depth, observation_radius=10)
             score += all_rewards[a] / env.get_num_agents()
 
         # Copy observation
-        agent_obs = agent_next_obs.copy()
         if done['__all__']:
             env_done = 1
             break
-- 
GitLab
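
Note on the training_navigation.py hunks: the patch changes the training loop so that replay transitions are written only at decision points, i.e. when an agent was actually asked for an action (update_values) or when it finished (done[a]); the previously buffered observation/action pair then becomes the start of the stored transition. Below is a minimal sketch of that pattern, pulled out of the patch for clarity. It is an illustration, not the repository code: env, agent, and normalize_observation are the names used in the diff, while the run_training_step wrapper and the per-agent update_values list are assumptions of this sketch (the patch itself uses a single boolean flag that each agent overwrites in turn), and bookkeeping such as score and action_prob is omitted.

def run_training_step(env, agent, info, eps, tree_depth,
                      agent_obs, agent_obs_buffer, agent_action_buffer,
                      update_values, normalize_observation):
    # One environment step with the buffered-transition update of the patch.
    action_dict = {}
    for a in range(env.get_num_agents()):
        if info['action_required'][a]:
            # Decision point: query the policy and mark agent `a` as
            # contributing a replay transition after this step.
            update_values[a] = True
            action_dict[a] = agent.act(agent_obs[a], eps=eps)
        else:
            # Agent cannot act (e.g. still traversing a cell): send a no-op
            # and skip the replay-buffer update for this agent.
            update_values[a] = False
            action_dict[a] = 0

    next_obs, all_rewards, done, info = env.step(action_dict)

    for a in range(env.get_num_agents()):
        if update_values[a] or done[a]:
            # Train on the transition from the last decision point to now;
            # the buffered observation/action is the transition's start.
            agent.step(agent_obs_buffer[a], agent_action_buffer[a],
                       all_rewards[a], agent_obs[a], done[a])
            agent_obs_buffer[a] = agent_obs[a].copy()
            agent_action_buffer[a] = action_dict[a]
        # Always refresh the agent's current (normalized) observation.
        agent_obs[a] = normalize_observation(next_obs[a], tree_depth,
                                             observation_radius=10)

    return info, done

Driven in a loop until done['__all__'] is true, this reproduces the intent of the patch: environment steps on which a fractional-speed train cannot act produce no replay samples, and each stored transition spans the interval from one decision point to the next, carrying only the reward of the final step rather than a sum accumulated in between.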