Skip to content
Snippets Groups Projects
Commit 76b9f2cf authored by Erik Nygren's avatar Erik Nygren
Browse files

added in line profiler utility for benchmarking env

parent ce36ddc4
No related branches found
No related tags found
No related merge requests found
{'Test_0':[10,10,1,3],
{'Test_0':[100,100,5,3],
'Test_1':[10,10,3,4321],
'Test_2':[10,10,5,123],
'Test_3':[50,50,5,21],
......
......@@ -7,6 +7,7 @@ from flatland.envs.generators import complex_rail_generator
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from line_profiler import LineProfiler
from utils.observation_utils import norm_obs_clip, split_tree
......@@ -64,6 +65,7 @@ class RandomAgent:
def run_test(parameters, agent, test_nr=0, tree_depth=3):
# Parameter initialization
lp = LineProfiler()
features_per_node = 9
start_time_scoring = time.time()
action_dict = dict()
......@@ -90,10 +92,15 @@ def run_test(parameters, agent, test_nr=0, tree_depth=3):
predictor=ShortestPathPredictorForRailEnv()),
number_of_agents=parameters[2])
max_steps = int(3 * (env.height + env.width))
lp_step = lp(env.step)
lp_reset = lp(env.reset)
agent_obs = [None] * env.get_num_agents()
printProgressBar(0, nr_trials_per_test, prefix='Progress:', suffix='Complete', length=20)
for trial in range(nr_trials_per_test):
# Reset the env
lp_reset(True, True)
obs = env.reset(True, True)
for a in range(env.get_num_agents()):
data, distance, agent_data = split_tree(tree=np.array(obs[a]), num_features_per_node=9,
......@@ -118,7 +125,7 @@ def run_test(parameters, agent, test_nr=0, tree_depth=3):
action_dict.update({a: action})
# Environment step
next_obs, all_rewards, done, _ = env.step(action_dict)
next_obs, all_rewards, done, _ = lp_step(action_dict)
for a in range(env.get_num_agents()):
data, distance, agent_data = split_tree(tree=np.array(next_obs[a]),
......@@ -140,4 +147,5 @@ def run_test(parameters, agent, test_nr=0, tree_depth=3):
printProgressBar(trial + 1, nr_trials_per_test, prefix='Progress:', suffix='Complete', length=20)
end_time_scoring = time.time()
tot_test_time = end_time_scoring - start_time_scoring
lp.print_stats()
return test_scores, test_dones, tot_test_time
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment