From 76b9f2cf86da51862722a9b1b9972df4af3a2e6a Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Sat, 6 Jul 2019 11:01:35 -0400
Subject: [PATCH] added in line profiler utility for benchmarking env

---
 parameters.txt      |  2 +-
 utils/misc_utils.py | 10 +++++++++-
 2 files changed, 10 insertions(+), 2 deletions(-)

diff --git a/parameters.txt b/parameters.txt
index aa83e8d..80ad8c2 100644
--- a/parameters.txt
+++ b/parameters.txt
@@ -1,4 +1,4 @@
-{'Test_0':[10,10,1,3],
+{'Test_0':[100,100,5,3],
 'Test_1':[10,10,3,4321],
 'Test_2':[10,10,5,123],
 'Test_3':[50,50,5,21],
diff --git a/utils/misc_utils.py b/utils/misc_utils.py
index 097450b..d4c6ef8 100644
--- a/utils/misc_utils.py
+++ b/utils/misc_utils.py
@@ -7,6 +7,7 @@ from flatland.envs.generators import complex_rail_generator
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
+from line_profiler import LineProfiler
 
 from utils.observation_utils import norm_obs_clip, split_tree
 
@@ -64,6 +65,7 @@ class RandomAgent:
 
 def run_test(parameters, agent, test_nr=0, tree_depth=3):
     # Parameter initialization
+    lp = LineProfiler()
     features_per_node = 9
     start_time_scoring = time.time()
     action_dict = dict()
@@ -90,10 +92,15 @@ def run_test(parameters, agent, test_nr=0, tree_depth=3):
                                                        predictor=ShortestPathPredictorForRailEnv()),
                   number_of_agents=parameters[2])
     max_steps = int(3 * (env.height + env.width))
+    lp_step = lp(env.step)
+    lp_reset = lp(env.reset)
+
     agent_obs = [None] * env.get_num_agents()
     printProgressBar(0, nr_trials_per_test, prefix='Progress:', suffix='Complete', length=20)
     for trial in range(nr_trials_per_test):
         # Reset the env
+
+        lp_reset(True, True)
         obs = env.reset(True, True)
         for a in range(env.get_num_agents()):
             data, distance, agent_data = split_tree(tree=np.array(obs[a]), num_features_per_node=9,
@@ -118,7 +125,7 @@ def run_test(parameters, agent, test_nr=0, tree_depth=3):
                 action_dict.update({a: action})
 
             # Environment step
-            next_obs, all_rewards, done, _ = env.step(action_dict)
+            next_obs, all_rewards, done, _ = lp_step(action_dict)
 
             for a in range(env.get_num_agents()):
                 data, distance, agent_data = split_tree(tree=np.array(next_obs[a]),
@@ -140,4 +147,5 @@ def run_test(parameters, agent, test_nr=0, tree_depth=3):
         printProgressBar(trial + 1, nr_trials_per_test, prefix='Progress:', suffix='Complete', length=20)
     end_time_scoring = time.time()
     tot_test_time = end_time_scoring - start_time_scoring
+    lp.print_stats()
     return test_scores, test_dones, tot_test_time
-- 
GitLab