From 2c63e82561ae5337b47a244cbf429ca0253734cd Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Sun, 1 Sep 2019 11:17:31 -0400
Subject: [PATCH] Update observation and training to handle multi-speed agents

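Disable stochastic malfunctions and put all probability mass on the
fastest speed class for now. Track per agent whether it sits at a cell
boundary (speed_data['position_fraction'] == 0.) and can register a new
action, and only feed those transitions to the learner. Load the
navigator_checkpoint1200.pth weights in the rendering script.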
---
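Reviewer note (ignored by `git am`): a multi-speed agent can only
register a new action when it sits exactly on a cell boundary; mid-cell
it keeps executing its current move. A minimal sketch of the gating
predicate used in training_navigation.py, with a hypothetical AgentStub
standing in for the real Flatland agent (only the speed_data dict is
taken from the patch):

    from dataclasses import dataclass, field

    @dataclass
    class AgentStub:
        # mirrors the env.agents[a].speed_data dict used in the patch
        speed_data: dict = field(
            default_factory=lambda: {'speed': 0.5, 'position_fraction': 0.0})

    def can_register_action(agent):
        # a new action is only accepted at a cell boundary
        return agent.speed_data['position_fraction'] == 0.

    agent = AgentStub()
    assert can_register_action(agent)           # at a boundary: may act
    agent.speed_data['position_fraction'] = 0.5
    assert not can_register_action(agent)       # mid-cell: keeps moving
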
 torch_training/render_agent_behavior.py | 12 ++++++------
 torch_training/training_navigation.py   | 19 ++++++++++++-------
 2 files changed, 18 insertions(+), 13 deletions(-)

diff --git a/torch_training/render_agent_behavior.py b/torch_training/render_agent_behavior.py
index 8264db6..fc0e067 100644
--- a/torch_training/render_agent_behavior.py
+++ b/torch_training/render_agent_behavior.py
@@ -38,7 +38,7 @@ min_dist = 5
 observation_builder = TreeObsForRailEnv(max_depth=2)
 
 # Use the malfunction generator to break agents from time to time
-stochastic_data = {'prop_malfunction': 0.1,  # Percentage of defective agents
+stochastic_data = {'prop_malfunction': 0.0,  # Percentage of defective agents
                    'malfunction_rate': 30,  # Rate of malfunction occurrence
                    'min_duration': 3,  # Minimal duration of malfunction
                    'max_duration': 20  # Max duration of malfunction
@@ -48,10 +48,10 @@ stochastic_data = {'prop_malfunction': 0.1,  # Percentage of defective agents
 TreeObservation = TreeObsForRailEnv(max_depth=2)
 
 # Different agent types (trains) with different speeds.
-speed_ration_map = {1.: 0.25,  # Fast passenger train
-                    1. / 2.: 0.25,  # Fast freight train
-                    1. / 3.: 0.25,  # Slow commuter train
-                    1. / 4.: 0.25}  # Slow freight train
+speed_ration_map = {1.: 1.,  # Fast passenger train
+                    1. / 2.: 0.0,  # Fast freight train
+                    1. / 3.: 0.0,  # Slow commuter train
+                    1. / 4.: 0.0}  # Slow freight train
 
 env = RailEnv(width=x_dim,
               height=y_dim,
@@ -103,7 +103,7 @@ action_prob = [0] * action_size
 agent_obs = [None] * env.get_num_agents()
 agent_next_obs = [None] * env.get_num_agents()
 agent = Agent(state_size, action_size, "FC", 0)
-with path(torch_training.Nets, "navigator_checkpoint100.pth") as file_in:
+with path(torch_training.Nets, "navigator_checkpoint1200.pth") as file_in:
     agent.qnetwork_local.load_state_dict(torch.load(file_in))
 
 record_images = False
diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py
index d324b7c..25b8c14 100644
--- a/torch_training/training_navigation.py
+++ b/torch_training/training_navigation.py
@@ -37,7 +37,7 @@ def main(argv):
     min_dist = 5
 
     # Use the malfunction generator to break agents from time to time
-    stochastic_data = {'prop_malfunction': 0.1,  # Percentage of defective agents
+    stochastic_data = {'prop_malfunction': 0.0,  # Percentage of defective agents
                        'malfunction_rate': 30,  # Rate of malfunction occurrence
                        'min_duration': 3,  # Minimal duration of malfunction
                        'max_duration': 20  # Max duration of malfunction
@@ -47,10 +47,10 @@ def main(argv):
     TreeObservation = TreeObsForRailEnv(max_depth=2)
 
     # Different agent types (trains) with different speeds.
-    speed_ration_map = {1.: 0.25,  # Fast passenger train
-                        1. / 2.: 0.25,  # Fast freight train
-                        1. / 3.: 0.25,  # Slow commuter train
-                        1. / 4.: 0.25}  # Slow freight train
+    speed_ration_map = {1.: 1.,  # Fast passenger train
+                        1. / 2.: 0.0,  # Fast freight train
+                        1. / 3.: 0.0,  # Slow commuter train
+                        1. / 4.: 0.0}  # Slow freight train
 
     env = RailEnv(width=x_dim,
                   height=y_dim,
@@ -120,7 +120,7 @@ def main(argv):
 
         # Reset environment
         obs = env.reset(True, True)
-
+        register_action_state = np.zeros(env.get_num_agents(), dtype=bool)
         final_obs = agent_obs.copy()
         final_obs_next = agent_next_obs.copy()
 
@@ -138,6 +138,11 @@ def main(argv):
 
             # Action
             for a in range(env.get_num_agents()):
+                # An agent can only register a new action at a cell boundary,
+                # i.e. while speed_data['position_fraction'] == 0.; mid-cell
+                # it keeps executing its current move.
+                register_action_state[a] = env.agents[a].speed_data['position_fraction'] == 0.
+
                 action = agent.act(agent_obs[a], eps=eps)
                 action_prob[action] += 1
                 action_dict.update({a: action})
@@ -155,7 +160,7 @@ def main(argv):
                     final_obs[a] = agent_obs[a].copy()
                     final_obs_next[a] = agent_next_obs[a].copy()
                     final_action_dict.update({a: action_dict[a]})
-                if not done[a]:
+                if not done[a] and register_action_state[a]:
                     agent.step(agent_obs[a], action_dict[a], all_rewards[a], agent_next_obs[a], done[a])
                 score += all_rewards[a] / env.get_num_agents()
 
-- 
GitLab