Skip to content
Snippets Groups Projects
Commit 76b9f2cf authored by Erik Nygren
Browse files

Added line profiler utility for benchmarking the environment

parent ce36ddc4
No related branches found
No related tags found
No related merge requests found
{'Test_0':[10,10,1,3], {'Test_0':[100,100,5,3],
'Test_1':[10,10,3,4321], 'Test_1':[10,10,3,4321],
'Test_2':[10,10,5,123], 'Test_2':[10,10,5,123],
'Test_3':[50,50,5,21], 'Test_3':[50,50,5,21],
......
...@@ -7,6 +7,7 @@ from flatland.envs.generators import complex_rail_generator ...@@ -7,6 +7,7 @@ from flatland.envs.generators import complex_rail_generator
from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv from flatland.envs.rail_env import RailEnv
from line_profiler import LineProfiler
from utils.observation_utils import norm_obs_clip, split_tree from utils.observation_utils import norm_obs_clip, split_tree
...@@ -64,6 +65,7 @@ class RandomAgent: ...@@ -64,6 +65,7 @@ class RandomAgent:
def run_test(parameters, agent, test_nr=0, tree_depth=3): def run_test(parameters, agent, test_nr=0, tree_depth=3):
# Parameter initialization # Parameter initialization
lp = LineProfiler()
features_per_node = 9 features_per_node = 9
start_time_scoring = time.time() start_time_scoring = time.time()
action_dict = dict() action_dict = dict()
...@@ -90,10 +92,15 @@ def run_test(parameters, agent, test_nr=0, tree_depth=3): ...@@ -90,10 +92,15 @@ def run_test(parameters, agent, test_nr=0, tree_depth=3):
predictor=ShortestPathPredictorForRailEnv()), predictor=ShortestPathPredictorForRailEnv()),
number_of_agents=parameters[2]) number_of_agents=parameters[2])
max_steps = int(3 * (env.height + env.width)) max_steps = int(3 * (env.height + env.width))
lp_step = lp(env.step)
lp_reset = lp(env.reset)
agent_obs = [None] * env.get_num_agents() agent_obs = [None] * env.get_num_agents()
printProgressBar(0, nr_trials_per_test, prefix='Progress:', suffix='Complete', length=20) printProgressBar(0, nr_trials_per_test, prefix='Progress:', suffix='Complete', length=20)
for trial in range(nr_trials_per_test): for trial in range(nr_trials_per_test):
# Reset the env # Reset the env
lp_reset(True, True)
obs = env.reset(True, True) obs = env.reset(True, True)
for a in range(env.get_num_agents()): for a in range(env.get_num_agents()):
data, distance, agent_data = split_tree(tree=np.array(obs[a]), num_features_per_node=9, data, distance, agent_data = split_tree(tree=np.array(obs[a]), num_features_per_node=9,
...@@ -118,7 +125,7 @@ def run_test(parameters, agent, test_nr=0, tree_depth=3): ...@@ -118,7 +125,7 @@ def run_test(parameters, agent, test_nr=0, tree_depth=3):
action_dict.update({a: action}) action_dict.update({a: action})
# Environment step # Environment step
next_obs, all_rewards, done, _ = env.step(action_dict) next_obs, all_rewards, done, _ = lp_step(action_dict)
for a in range(env.get_num_agents()): for a in range(env.get_num_agents()):
data, distance, agent_data = split_tree(tree=np.array(next_obs[a]), data, distance, agent_data = split_tree(tree=np.array(next_obs[a]),
...@@ -140,4 +147,5 @@ def run_test(parameters, agent, test_nr=0, tree_depth=3): ...@@ -140,4 +147,5 @@ def run_test(parameters, agent, test_nr=0, tree_depth=3):
printProgressBar(trial + 1, nr_trials_per_test, prefix='Progress:', suffix='Complete', length=20) printProgressBar(trial + 1, nr_trials_per_test, prefix='Progress:', suffix='Complete', length=20)
end_time_scoring = time.time() end_time_scoring = time.time()
tot_test_time = end_time_scoring - start_time_scoring tot_test_time = end_time_scoring - start_time_scoring
lp.print_stats()
return test_scores, test_dones, tot_test_time return test_scores, test_dones, tot_test_time
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment