From 5e89c7d28768ffdecdbabbeea97d57bb8ff80c19 Mon Sep 17 00:00:00 2001 From: MLErik <baerenjesus@gmail.com> Date: Mon, 28 Oct 2019 16:48:31 -0400 Subject: [PATCH] fixed all examples in example folder --- examples/complex_rail_benchmark.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/examples/complex_rail_benchmark.py b/examples/complex_rail_benchmark.py index ecbbe8b4..12f996fb 100644 --- a/examples/complex_rail_benchmark.py +++ b/examples/complex_rail_benchmark.py @@ -6,7 +6,7 @@ import numpy as np from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import complex_rail_generator from flatland.envs.schedule_generators import complex_schedule_generator - +from flatland.envs.observations import TreeObsForRailEnv def run_benchmark(): """Run benchmark on a small number of agents in complex rail environment.""" @@ -17,6 +17,7 @@ def run_benchmark(): env = RailEnv(width=15, height=15, rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=20, min_dist=12), schedule_generator=complex_schedule_generator(), + obs_builder_object=TreeObsForRailEnv(max_depth=2), number_of_agents=5) env.reset() @@ -42,9 +43,6 @@ def run_benchmark(): # Reset environment obs, info = env.reset() - for a in range(env.get_num_agents()): - norm = max(1, max_lt(obs[a], np.inf)) - obs[a] = np.clip(np.array(obs[a]) / norm, -1, 1) # Run episode for step in range(100): @@ -56,9 +54,6 @@ def run_benchmark(): # Environment step next_obs, all_rewards, done, _ = env.step(action_dict) - for a in range(env.get_num_agents()): - norm = max(1, max_lt(next_obs[a], np.inf)) - next_obs[a] = np.clip(np.array(next_obs[a]) / norm, -1, 1) if done['__all__']: break -- GitLab