From 3babf29d06293f6ff9995c3e0e256ce6e48821eb Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Sun, 6 Oct 2019 20:15:30 -0400
Subject: [PATCH] Removed "bug" in reward handling. Attention: it is currently
 cheaper for an agent to wait if we accumulate rewards between steps!

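For context, a toy illustration of the warning above. The numbers are made up
and this is not flatland's actual reward shaping: if the rewards collected
between two decision points were summed into a single replay transition, an
agent that moves (and therefore spans several environment steps per decision)
would accumulate more step penalty per transition than an agent that simply
waits, so waiting would look cheaper to the learner.

    # Toy numbers only; flatland's real reward function differs.
    step_penalty = -1.0            # assumed penalty per environment step
    steps_per_decision_moving = 3  # e.g. a 1/3-speed train decides every 3 steps
    steps_per_decision_waiting = 1

    # Reward credited to one replay transition if rewards were accumulated
    # between decision points:
    moving_transition = step_penalty * steps_per_decision_moving    # -3.0
    waiting_transition = step_penalty * steps_per_decision_waiting  # -1.0
    assert waiting_transition > moving_transition  # waiting looks "cheaper"
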
---
 torch_training/render_agent_behavior.py |  6 +++---
 torch_training/training_navigation.py   | 28 +++++++++----------------
 2 files changed, 13 insertions(+), 21 deletions(-)

diff --git a/torch_training/render_agent_behavior.py b/torch_training/render_agent_behavior.py
index 93f9f12..d599bcf 100644
--- a/torch_training/render_agent_behavior.py
+++ b/torch_training/render_agent_behavior.py
@@ -48,9 +48,9 @@ stochastic_data = {'prop_malfunction': 0.0,  # Percentage of defective agents
 TreeObservation = TreeObsForRailEnv(max_depth=2)
 
 # Different agent types (trains) with different speeds.
-speed_ration_map = {1.: 0.,  # Fast passenger train
+speed_ration_map = {1.: 1.,  # Fast passenger train
                     1. / 2.: 0.0,  # Fast freight train
-                    1. / 3.: 1.0,  # Slow commuter train
+                    1. / 3.: 0.0,  # Slow commuter train
                     1. / 4.: 0.0}  # Slow freight train
 
 env = RailEnv(width=x_dim,
@@ -95,7 +95,7 @@ action_prob = [0] * action_size
 agent_obs = [None] * env.get_num_agents()
 agent_next_obs = [None] * env.get_num_agents()
 agent = Agent(state_size, action_size)
-with path(torch_training.Nets, "navigator_checkpoint15000.pth") as file_in:
+with path(torch_training.Nets, "navigator_checkpoint1000.pth") as file_in:
     agent.qnetwork_local.load_state_dict(torch.load(file_in))
 
 record_images = False
diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py
index 229c804..252cf16 100644
--- a/torch_training/training_navigation.py
+++ b/torch_training/training_navigation.py
@@ -51,8 +51,8 @@ def main(argv):
     TreeObservation = TreeObsForRailEnv(max_depth=2)
 
     # Different agent types (trains) with different speeds.
-    speed_ration_map = {1.: 1.,  # Fast passenger train
-                        1. / 2.: 0.0,  # Fast freight train
+    speed_ration_map = {1.: 0.,  # Fast passenger train
+                        1. / 2.: 1.0,  # Fast freight train
                         1. / 3.: 0.0,  # Slow commuter train
                         1. / 4.: 0.0}  # Slow freight train
 
@@ -106,9 +106,8 @@ def main(argv):
     agent_next_obs = [None] * env.get_num_agents()
     agent_obs_buffer = [None] * env.get_num_agents()
     agent_action_buffer = [2] * env.get_num_agents()
-    agent_done_buffer = [False] * env.get_num_agents()
     cummulated_reward = np.zeros(env.get_num_agents())
-
+    update_values = False
     # Now we load a Double dueling DQN agent
     agent = Agent(state_size, action_size)
 
@@ -131,39 +130,32 @@ def main(argv):
             # Action
             for a in range(env.get_num_agents()):
                 if info['action_required'][a]:
+                    update_values = True
                     action = agent.act(agent_obs[a], eps=eps)
                     action_prob[action] += 1
                 else:
+                    update_values = False
                     action = 0
+                    action_prob[action] += 1
                 action_dict.update({a: action})
 
             # Environment step
             next_obs, all_rewards, done, info = env.step(action_dict)
-            # Build agent specific observations and normalize
-            for a in range(env.get_num_agents()):
-                # Penalize waiting in order to get agent to move
-                if env.agents[a].status == 0:
-                    all_rewards[a] -= 1
-                if info['action_required'][a]:
-                    agent_next_obs[a] = normalize_observation(next_obs[a], tree_depth, observation_radius=10)
-                cummulated_reward[a] += all_rewards[a]
-
             # Update replay buffer and train agent
             for a in range(env.get_num_agents()):
-                if (info['action_required'][a] and env.agents[a].status != 3) or env.agents[a].status == 2:
 
+                if update_values or done[a]:
                     agent.step(agent_obs_buffer[a], agent_action_buffer[a], all_rewards[a],
-                               agent_obs[a], agent_done_buffer[a])
+                               agent_obs[a], done[a])
                     cummulated_reward[a] = 0.
-                if info['action_required'][a]:
+
                     agent_obs_buffer[a] = agent_obs[a].copy()
                     agent_action_buffer[a] = action_dict[a]
-                    agent_done_buffer[a] = done[a]
+                agent_obs[a] = normalize_observation(next_obs[a], tree_depth, observation_radius=10)
 
                 score += all_rewards[a] / env.get_num_agents()
 
             # Copy observation
-            agent_obs = agent_next_obs.copy()
             if done['__all__']:
                 env_done = 1
                 break
-- 
GitLab
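
Reviewer note: below is a condensed, self-contained sketch of the episode loop
that training_navigation.py uses after this patch. The run_episode wrapper and
its parameters are an illustration only; agent.act/agent.step, env.step,
env.reset and normalize_observation are the interfaces already used in the
diff (flatland 2.x return shapes assumed), and every other name is an
assumption, not part of the patch.

def run_episode(env, agent, normalize_observation, tree_depth, max_steps, eps):
    """Run one episode, storing one replay transition per agent decision."""
    obs, info = env.reset(True, True)
    n_agents = env.get_num_agents()
    agent_obs = [normalize_observation(obs[a], tree_depth, observation_radius=10)
                 for a in range(n_agents)]
    agent_obs_buffer = [o.copy() for o in agent_obs]
    agent_action_buffer = [2] * n_agents  # default action, as in the script
    score = 0.0
    update_values = False

    for _ in range(max_steps):
        # Choose an action only where the environment asks for one.
        action_dict = {}
        for a in range(n_agents):
            if info['action_required'][a]:
                update_values = True
                action_dict[a] = agent.act(agent_obs[a], eps=eps)
            else:
                # Agent is between decisions (e.g. still traversing a cell).
                update_values = False
                action_dict[a] = 0

        next_obs, all_rewards, done, info = env.step(action_dict)

        for a in range(n_agents):
            # Store a transition only at a decision point or on episode end;
            # the reward passed on is the immediate one, not an accumulated
            # sum. (update_values is a single flag, so it reflects the last
            # agent handled in the action loop above.)
            if update_values or done[a]:
                agent.step(agent_obs_buffer[a], agent_action_buffer[a],
                           all_rewards[a], agent_obs[a], done[a])
                agent_obs_buffer[a] = agent_obs[a].copy()
                agent_action_buffer[a] = action_dict[a]
            agent_obs[a] = normalize_observation(next_obs[a], tree_depth,
                                                 observation_radius=10)
            score += all_rewards[a] / n_agents

        if done['__all__']:
            break
    return score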