From b911088dd82413f138c3fd4c050a7fb60cd77553 Mon Sep 17 00:00:00 2001
From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch>
Date: Tue, 3 Nov 2020 17:04:23 +0100
Subject: [PATCH] Fix FastTreeObs observation dimension and add per-step policy hooks

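This patch makes FastTreeObs usable end to end:

* utils/fast_tree_obs.py: observation_dim is bumped from 26 to 27 so it
  matches the number of features the builder actually emits.
* reinforcement_learning/policy.py: the Policy base class gains no-op
  start_step() and end_step() hooks. Training, evaluation and the
  submission runner (run.py) now call them around the per-agent action
  loop of every timestep, so stateful policies can set up and tear down
  per-step state in one place.
* reinforcement_learning/multi_agent_training.py: imports are grouped
  (stdlib first, then third-party) and overlong lines are wrapped.

A minimal sketch of how a subclass might use the new hooks; the class
name and the observation buffer are hypothetical illustrations, not
part of this patch:

    import numpy as np

    from reinforcement_learning.policy import Policy


    class StepAwarePolicy(Policy):
        """Hypothetical example: keeps per-timestep state between hooks."""

        def __init__(self):
            self.obs_this_step = []

        def start_step(self):
            # Called once before the per-agent action loop of a timestep.
            self.obs_this_step = []

        def act(self, state, eps=0.0):
            # A real policy would run inference here; this sketch only
            # records the observation and returns action 0 (DO_NOTHING).
            self.obs_this_step.append(np.asarray(state))
            return 0

        def end_step(self):
            # Called once after every agent has chosen its action.
            self.obs_this_step.clear()
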
---
 .../multi_agent_training.py                   | 30 +++++++++++--------
 reinforcement_learning/policy.py              |  8 ++++++++
 run.py                                        |  2 ++
 utils/fast_tree_obs.py                        |  2 +-
 4 files changed, 28 insertions(+), 14 deletions(-)

diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py
index 9c78a72..fd85fdc 100755
--- a/reinforcement_learning/multi_agent_training.py
+++ b/reinforcement_learning/multi_agent_training.py
@@ -1,23 +1,22 @@
-from datetime import datetime
 import os
 import random
 import sys
 from argparse import ArgumentParser, Namespace
+from collections import deque
+from datetime import datetime
 from pathlib import Path
 from pprint import pprint
-import psutil
-from flatland.utils.rendertools import RenderTool
-from torch.utils.tensorboard import SummaryWriter
-import numpy as np
-import torch
-from collections import deque
 
+import numpy as np
+import psutil
+from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters
+from flatland.envs.observations import TreeObsForRailEnv
+from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv, RailEnvActions
 from flatland.envs.rail_generators import sparse_rail_generator
 from flatland.envs.schedule_generators import sparse_schedule_generator
-from flatland.envs.observations import TreeObsForRailEnv
-from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters
-from flatland.envs.predictions import ShortestPathPredictorForRailEnv
+from flatland.utils.rendertools import RenderTool
+from torch.utils.tensorboard import SummaryWriter
 
 base_dir = Path(__file__).resolve().parent.parent
 sys.path.append(str(base_dir))
@@ -250,6 +249,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
         # Run episode
         for step in range(max_steps - 1):
             inference_timer.start()
+            policy.start_step()
             for agent in train_env.get_agent_handles():
                 if info['action_required'][agent]:
                     update_values[agent] = True
@@ -264,6 +264,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
                     update_values[agent] = False
                     action = 0
                 action_dict.update({agent: action})
+            policy.end_step()
             inference_timer.end()
 
             # Environment step
@@ -285,7 +286,8 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
                 if update_values[agent] or done['__all__']:
                     # Only learn from timesteps where something happened
                     learn_timer.start()
-                    policy.step(agent_prev_obs[agent], agent_prev_action[agent], all_rewards[agent], agent_obs[agent], done[agent])
+                    policy.step(agent_prev_obs[agent], agent_prev_action[agent], all_rewards[agent], agent_obs[agent],
+                                done[agent])
                     learn_timer.end()
 
                     agent_prev_obs[agent] = agent_obs[agent].copy()
@@ -434,6 +436,7 @@ def eval_policy(env, tree_observation, policy, train_params, obs_params):
         final_step = 0
 
         for step in range(max_steps - 1):
+            policy.start_step()
             for agent in env.get_agent_handles():
                 if tree_observation.check_is_observation_valid(agent_obs[agent]):
                     agent_obs[agent] = tree_observation.get_normalized_observation(obs[agent], tree_depth=tree_depth,
@@ -444,7 +447,7 @@ def eval_policy(env, tree_observation, policy, train_params, obs_params):
                     if tree_observation.check_is_observation_valid(agent_obs[agent]):
                         action = policy.act(agent_obs[agent], eps=0.0)
                 action_dict.update({agent: action})
-
+            policy.end_step()
             obs, all_rewards, done, info = env.step(action_dict)
 
             for agent in env.get_agent_handles():
@@ -495,7 +498,8 @@ if __name__ == "__main__":
     parser.add_argument("--num_threads", help="number of threads PyTorch can use", default=1, type=int)
     parser.add_argument("--render", help="render 1 episode in 100", action='store_true')
     parser.add_argument("--load_policy", help="policy filename (reference) to load", default="", type=str)
-    parser.add_argument("--use_fast_tree_observation", help="use FastTreeObs instead of stock TreeObs", action='store_true')
+    parser.add_argument("--use_fast_tree_observation", help="use FastTreeObs instead of stock TreeObs",
+                        action='store_true')
     parser.add_argument("--max_depth", help="max depth", default=2, type=int)
 
     training_params = parser.parse_args()
diff --git a/reinforcement_learning/policy.py b/reinforcement_learning/policy.py
index d5c6a3c..b8714d1 100644
--- a/reinforcement_learning/policy.py
+++ b/reinforcement_learning/policy.py
@@ -10,3 +10,11 @@ class Policy:
 
     def load(self, filename):
         raise NotImplementedError
+
+    def start_step(self):
+        # Hook called once per environment step, before the per-agent action loop.
+        pass
+
+    def end_step(self):
+        # Hook called once per environment step, after all agents have chosen actions.
+        pass
diff --git a/run.py b/run.py
index f11c941..09dc04d 100644
--- a/run.py
+++ b/run.py
@@ -114,6 +114,7 @@ while True:
             if not check_if_all_blocked(env=local_env):
                 time_start = time.time()
                 action_dict = {}
+                policy.start_step()
                 for agent in range(nb_agents):
                     if info['action_required'][agent]:
                         if agent in agent_last_obs and np.all(agent_last_obs[agent] == observation[agent]):
@@ -128,6 +129,7 @@ while True:
                         if USE_ACTION_CACHE:
                             agent_last_obs[agent] = observation[agent]
                             agent_last_action[agent] = action
+                policy.end_step()
                 agent_time = time.time() - time_start
                 time_taken_by_controller.append(agent_time)
 
diff --git a/utils/fast_tree_obs.py b/utils/fast_tree_obs.py
index 7e4c934..8fe6caf 100755
--- a/utils/fast_tree_obs.py
+++ b/utils/fast_tree_obs.py
@@ -21,7 +21,7 @@ class FastTreeObs(ObservationBuilder):
 
     def __init__(self, max_depth):
         self.max_depth = max_depth
-        self.observation_dim = 26
+        self.observation_dim = 27
 
     def build_data(self):
         if self.env is not None:
-- 
GitLab