From caada6f1b1f45259fca5b752864433e84564eec4 Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Sat, 5 Oct 2019 11:25:19 -0400
Subject: [PATCH] Update to new env.reset() behavior returning (obs, info)

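RailEnv.reset() now returns a tuple of (observations, info) instead of the
observations alone, so every call site unpacks both values. The info dict
carries per-agent entries such as info['action_required'], which
training_navigation.py now uses directly instead of the local
register_action_state bookkeeping.

A minimal sketch of the call-site change applied throughout; the `agent`,
`eps` and `action_dict` names stand in for the surrounding training code,
and observation normalization is elided:

    # Old API (for comparison): reset() returned only the observations
    # obs = env.reset(True, True)

    # New API: reset() also returns an info dict with per-agent entries
    obs, info = env.reset(True, True)

    action_dict = dict()
    for handle in range(env.get_num_agents()):
        # Only query the policy when the env flags that a decision is needed
        if info['action_required'][handle]:
            action_dict[handle] = agent.act(obs[handle], eps=eps)
        else:
            action_dict[handle] = 0
    next_obs, all_rewards, done, info = env.step(action_dict)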
---
 scoring/utils/misc_utils.py                   |  5 ++--
 sequential_agent/run_test.py                  |  4 +--
 torch_training/Getting_Started_Training.md    |  2 +-
 torch_training/Multi_Agent_Training_Intro.md  |  2 +-
 torch_training/multi_agent_inference.py       | 13 +++++----
 torch_training/multi_agent_training.py        |  2 +-
 .../multi_agent_two_time_step_training.py     |  2 +-
 torch_training/render_agent_behavior.py       |  8 +++---
 torch_training/training_navigation.py         | 27 +++++++------------
 utils/misc_utils.py                           |  6 ++---
 10 files changed, 31 insertions(+), 40 deletions(-)

diff --git a/scoring/utils/misc_utils.py b/scoring/utils/misc_utils.py
index dee5f47..6f10af9 100644
--- a/scoring/utils/misc_utils.py
+++ b/scoring/utils/misc_utils.py
@@ -2,7 +2,6 @@ import random
 import time
 
 import numpy as np
-
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
@@ -66,7 +65,7 @@ def run_test(parameters, agent, observation_builder=None, observation_wrapper=No
                       number_of_agents=1,
                       )
 
-        obs = env.reset()
+        obs, info = env.reset()
 
         if observation_wrapper is not None:
             for a in range(env.get_num_agents()):
@@ -181,7 +180,7 @@ def run_test_sequential(parameters, agent, test_nr=0, tree_depth=3):
                       number_of_agents=1,
                       )
 
-        obs = env.reset()
+        obs, info = env.reset()
         done = env.dones
         # Run episode
         trial_score = 0
diff --git a/sequential_agent/run_test.py b/sequential_agent/run_test.py
index a8c0bbe..92e8145 100644
--- a/sequential_agent/run_test.py
+++ b/sequential_agent/run_test.py
@@ -1,11 +1,11 @@
 import numpy as np
-
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.envs.rail_generators import complex_rail_generator
 from flatland.envs.schedule_generators import complex_schedule_generator
 from flatland.utils.rendertools import RenderTool
+
 from sequential_agent.simple_order_agent import OrderedAgent
 
 np.random.seed(2)
@@ -49,7 +49,7 @@ action_dict = dict()
 for trials in range(1, n_trials + 1):
 
     # Reset environment
-    obs = env.reset(True, True)
+    obs, info = env.reset(True, True)
     done = env.dones
     env_renderer.reset()
     frame_step = 0
diff --git a/torch_training/Getting_Started_Training.md b/torch_training/Getting_Started_Training.md
index 8610bfd..cbf4a3c 100644
--- a/torch_training/Getting_Started_Training.md
+++ b/torch_training/Getting_Started_Training.md
@@ -150,7 +150,7 @@ We now use the normalized `agent_obs` for our training loop:
 for trials in range(1, n_trials + 1):
 
     # Reset environment
-    obs = env.reset(True, True)
+    obs, info = env.reset(True, True)
     if not Training:
         env_renderer.set_new_rail()
 
diff --git a/torch_training/Multi_Agent_Training_Intro.md b/torch_training/Multi_Agent_Training_Intro.md
index d4eefae..69f89aa 100644
--- a/torch_training/Multi_Agent_Training_Intro.md
+++ b/torch_training/Multi_Agent_Training_Intro.md
@@ -174,7 +174,7 @@ We now use the normalized `agent_obs` for our training loop:
             agent_next_obs = [None] * env.get_num_agents()
 
         # Reset environment
-        obs = env.reset(True, True)
+        obs, info = env.reset(True, True)
 
         # Setup placeholder for finals observation of a single agent. This is necessary because agents terminate at
         # different times during an episode
diff --git a/torch_training/multi_agent_inference.py b/torch_training/multi_agent_inference.py
index 580886b..b376623 100644
--- a/torch_training/multi_agent_inference.py
+++ b/torch_training/multi_agent_inference.py
@@ -3,16 +3,15 @@ from collections import deque
 
 import numpy as np
 import torch
-from importlib_resources import path
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
-
-import torch_training.Nets
 from flatland.envs.rail_env import RailEnv
-from flatland.envs.rail_generators import rail_from_file, sparse_rail_generator
-from flatland.envs.schedule_generators import schedule_from_file, sparse_schedule_generator
-
+from flatland.envs.rail_generators import sparse_rail_generator
+from flatland.envs.schedule_generators import sparse_schedule_generator
 from flatland.utils.rendertools import RenderTool
+from importlib_resources import path
+
+import torch_training.Nets
 from torch_training.dueling_double_dqn import Agent
 from utils.observation_utils import normalize_observation
 
@@ -97,7 +96,7 @@ frame_step = 0
 for trials in range(1, n_trials + 1):
 
     # Reset environment
-    obs = env.reset(True, True)
+    obs, info = env.reset(True, True)
 
     env_renderer.reset()
 
diff --git a/torch_training/multi_agent_training.py b/torch_training/multi_agent_training.py
index e8ed93f..ec8ac96 100644
--- a/torch_training/multi_agent_training.py
+++ b/torch_training/multi_agent_training.py
@@ -162,7 +162,7 @@ def main(argv):
             agent_next_obs = [None] * env.get_num_agents()
 
         # Reset environment
-        obs = env.reset(True, True)
+        obs, info = env.reset(True, True)
 
         # Setup placeholder for finals observation of a single agent. This is necessary because agents terminate at
         # different times during an episode
diff --git a/torch_training/multi_agent_two_time_step_training.py b/torch_training/multi_agent_two_time_step_training.py
index d02e4b2..466ddf5 100644
--- a/torch_training/multi_agent_two_time_step_training.py
+++ b/torch_training/multi_agent_two_time_step_training.py
@@ -121,7 +121,7 @@ def main(argv):
             agent_next_obs = [None] * env.get_num_agents()
 
         # Reset environment
-        obs = env.reset(True, True)
+        obs, info = env.reset(True, True)
 
         # Setup placeholder for finals observation of a single agent. This is necessary because agents terminate at
         # different times during an episode
diff --git a/torch_training/render_agent_behavior.py b/torch_training/render_agent_behavior.py
index 2649a23..969b7e9 100644
--- a/torch_training/render_agent_behavior.py
+++ b/torch_training/render_agent_behavior.py
@@ -3,15 +3,15 @@ from collections import deque
 
 import numpy as np
 import torch
-from importlib_resources import path
-
-import torch_training.Nets
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.envs.rail_generators import sparse_rail_generator
 from flatland.envs.schedule_generators import sparse_schedule_generator
 from flatland.utils.rendertools import RenderTool
+from importlib_resources import path
+
+import torch_training.Nets
 from torch_training.dueling_double_dqn import Agent
 from utils.observation_utils import normalize_observation
 
@@ -111,7 +111,7 @@ frame_step = 0
 for trials in range(1, n_trials + 1):
 
     # Reset environment
-    obs = env.reset(True, True)
+    obs, info = env.reset(True, True)
     env_renderer.reset()
     # Build agent specific observations
     for a in range(env.get_num_agents()):
diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py
index 607206e..4f82c52 100644
--- a/torch_training/training_navigation.py
+++ b/torch_training/training_navigation.py
@@ -133,13 +133,11 @@ def main(argv):
             # Action
             for a in range(env.get_num_agents()):
                 if info['action_required'][a]:
-                    register_action_state[a] = True
                     action = agent.act(agent_obs[a], eps=eps)
                     action_prob[action] += 1
                     if step == 0:
                         agent_action_buffer[a] = action
                 else:
-                    register_action_state[a] = False
                     action = 0
                 action_dict.update({a: action})
 
@@ -151,24 +149,21 @@ def main(argv):
                 # Penalize waiting in order to get agent to move
                 if env.agents[a].status == 0:
                     all_rewards[a] -= 1
+
                 agent_next_obs[a] = normalize_observation(next_obs[a], tree_depth, observation_radius=10)
                 cummulated_reward[a] += all_rewards[a]
 
             # Update replay buffer and train agent
             for a in range(env.get_num_agents()):
-                if done[a]:
-                    final_obs[a] = agent_obs_buffer[a]
-                    final_obs_next[a] = agent_next_obs[a].copy()
-                    final_action_dict.update({a: agent_action_buffer[a]})
-                if not done[a]:
-                    if agent_obs_buffer[a] is not None and register_action_state[a]:
-                        agent_delayed_next = agent_obs[a].copy()
-                        agent.step(agent_obs_buffer[a], agent_action_buffer[a], all_rewards[a],
-                                   agent_delayed_next, done[a])
-                        cummulated_reward[a] = 0.
-                    if register_action_state[a]:
-                        agent_obs_buffer[a] = agent_obs[a].copy()
-                        agent_action_buffer[a] = action_dict[a]
+                if (agent_obs_buffer[a] is not None and info['action_required'][a] and env.agents[a].status != 3) or \
+                        env.agents[a].status == 2:  # status 2: DONE, 3: DONE_REMOVED
+                    agent_delayed_next = agent_obs[a].copy()
+                    agent.step(agent_obs_buffer[a], agent_action_buffer[a], all_rewards[a],
+                               agent_delayed_next, done[a])
+                    cummulated_reward[a] = 0.
+                if info['action_required'][a]:
+                    agent_obs_buffer[a] = agent_obs[a].copy()
+                    agent_action_buffer[a] = action_dict[a]
 
                 score += all_rewards[a] / env.get_num_agents()
 
@@ -176,8 +171,6 @@ def main(argv):
             agent_obs = agent_next_obs.copy()
             if done['__all__']:
                 env_done = 1
-                for a in range(env.get_num_agents()):
-                    agent.step(final_obs[a], final_action_dict[a], all_rewards[a], final_obs_next[a], done[a])
                 break
 
         # Epsilon decay
diff --git a/utils/misc_utils.py b/utils/misc_utils.py
index 4702c82..e4962ca 100644
--- a/utils/misc_utils.py
+++ b/utils/misc_utils.py
@@ -3,12 +3,12 @@ import time
 from collections import deque
 
 import numpy as np
-from line_profiler import LineProfiler
-
 from flatland.envs.observations import GlobalObsForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.envs.rail_generators import complex_rail_generator
 from flatland.envs.schedule_generators import complex_schedule_generator
+from line_profiler import LineProfiler
+
 from utils.observation_utils import norm_obs_clip, split_tree_into_feature_groups
 
 
@@ -102,7 +102,7 @@ def run_test(parameters, agent, test_nr=0, tree_depth=3):
         # Reset the env
 
         lp_reset(True, True)
-        obs = env.reset(True, True)
+        obs, info = env.reset(True, True)
         for a in range(env.get_num_agents()):
             data, distance, agent_data = split_tree_into_feature_groups(obs[a], tree_depth)
             data = norm_obs_clip(data)
-- 
GitLab