diff --git a/examples/complex_rail_benchmark.py b/examples/complex_rail_benchmark.py index ecbbe8b4582fdb2047c52ac01c5aa3d9a330d30a..12f996fbb1d7232e5611358785b23a4c5a39b676 100644 --- a/examples/complex_rail_benchmark.py +++ b/examples/complex_rail_benchmark.py @@ -6,7 +6,7 @@ import numpy as np from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import complex_rail_generator from flatland.envs.schedule_generators import complex_schedule_generator - +from flatland.envs.observations import TreeObsForRailEnv def run_benchmark(): """Run benchmark on a small number of agents in complex rail environment.""" @@ -17,6 +17,7 @@ def run_benchmark(): env = RailEnv(width=15, height=15, rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=20, min_dist=12), schedule_generator=complex_schedule_generator(), + obs_builder_object=TreeObsForRailEnv(max_depth=2), number_of_agents=5) env.reset() @@ -42,9 +43,6 @@ def run_benchmark(): # Reset environment obs, info = env.reset() - for a in range(env.get_num_agents()): - norm = max(1, max_lt(obs[a], np.inf)) - obs[a] = np.clip(np.array(obs[a]) / norm, -1, 1) # Run episode for step in range(100): @@ -56,9 +54,6 @@ def run_benchmark(): # Environment step next_obs, all_rewards, done, _ = env.step(action_dict) - for a in range(env.get_num_agents()): - norm = max(1, max_lt(next_obs[a], np.inf)) - next_obs[a] = np.clip(np.array(next_obs[a]) / norm, -1, 1) if done['__all__']: break