From 76b9f2cf86da51862722a9b1b9972df4af3a2e6a Mon Sep 17 00:00:00 2001 From: Erik Nygren <erik.nygren@sbb.ch> Date: Sat, 6 Jul 2019 11:01:35 -0400 Subject: [PATCH] added in line profiler utility for benchmarking env --- parameters.txt | 2 +- utils/misc_utils.py | 10 +++++++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/parameters.txt b/parameters.txt index aa83e8d..80ad8c2 100644 --- a/parameters.txt +++ b/parameters.txt @@ -1,4 +1,4 @@ -{'Test_0':[10,10,1,3], +{'Test_0':[100,100,5,3], 'Test_1':[10,10,3,4321], 'Test_2':[10,10,5,123], 'Test_3':[50,50,5,21], diff --git a/utils/misc_utils.py b/utils/misc_utils.py index 097450b..d4c6ef8 100644 --- a/utils/misc_utils.py +++ b/utils/misc_utils.py @@ -7,6 +7,7 @@ from flatland.envs.generators import complex_rail_generator from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv +from line_profiler import LineProfiler from utils.observation_utils import norm_obs_clip, split_tree @@ -64,6 +65,7 @@ class RandomAgent: def run_test(parameters, agent, test_nr=0, tree_depth=3): # Parameter initialization + lp = LineProfiler() features_per_node = 9 start_time_scoring = time.time() action_dict = dict() @@ -90,10 +92,15 @@ def run_test(parameters, agent, test_nr=0, tree_depth=3): predictor=ShortestPathPredictorForRailEnv()), number_of_agents=parameters[2]) max_steps = int(3 * (env.height + env.width)) + lp_step = lp(env.step) + lp_reset = lp(env.reset) + agent_obs = [None] * env.get_num_agents() printProgressBar(0, nr_trials_per_test, prefix='Progress:', suffix='Complete', length=20) for trial in range(nr_trials_per_test): # Reset the env + + lp_reset(True, True) obs = env.reset(True, True) for a in range(env.get_num_agents()): data, distance, agent_data = split_tree(tree=np.array(obs[a]), num_features_per_node=9, @@ -118,7 +125,7 @@ def run_test(parameters, agent, test_nr=0, tree_depth=3): action_dict.update({a: action}) # Environment step - next_obs, all_rewards, done, _ = env.step(action_dict) + next_obs, all_rewards, done, _ = lp_step(action_dict) for a in range(env.get_num_agents()): data, distance, agent_data = split_tree(tree=np.array(next_obs[a]), @@ -140,4 +147,5 @@ def run_test(parameters, agent, test_nr=0, tree_depth=3): printProgressBar(trial + 1, nr_trials_per_test, prefix='Progress:', suffix='Complete', length=20) end_time_scoring = time.time() tot_test_time = end_time_scoring - start_time_scoring + lp.print_stats() return test_scores, test_dones, tot_test_time -- GitLab