Skip to content
Snippets Groups Projects
Commit 5e89c7d2 authored by Erik Nygren's avatar Erik Nygren :bullettrain_front:
Browse files

fixed all examples in example folder

parent c277314c
No related branches found
No related tags found
No related merge requests found
......@@ -6,7 +6,7 @@ import numpy as np
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import complex_rail_generator
from flatland.envs.schedule_generators import complex_schedule_generator
from flatland.envs.observations import TreeObsForRailEnv
def run_benchmark():
"""Run benchmark on a small number of agents in complex rail environment."""
......@@ -17,6 +17,7 @@ def run_benchmark():
env = RailEnv(width=15, height=15,
rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=20, min_dist=12),
schedule_generator=complex_schedule_generator(),
obs_builder_object=TreeObsForRailEnv(max_depth=2),
number_of_agents=5)
env.reset()
......@@ -42,9 +43,6 @@ def run_benchmark():
# Reset environment
obs, info = env.reset()
for a in range(env.get_num_agents()):
norm = max(1, max_lt(obs[a], np.inf))
obs[a] = np.clip(np.array(obs[a]) / norm, -1, 1)
# Run episode
for step in range(100):
......@@ -56,9 +54,6 @@ def run_benchmark():
# Environment step
next_obs, all_rewards, done, _ = env.step(action_dict)
for a in range(env.get_num_agents()):
norm = max(1, max_lt(next_obs[a], np.inf))
next_obs[a] = np.clip(np.array(next_obs[a]) / norm, -1, 1)
if done['__all__']:
break
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment