From f17ec7e06e307a2fe69b53a4f131fa6ecbe327b3 Mon Sep 17 00:00:00 2001 From: Erik Nygren <erik.nygren@sbb.ch> Date: Sun, 1 Sep 2019 07:35:06 -0400 Subject: [PATCH] minor bugfix in level generator --- examples/flatland_2_0_example.py | 15 +++++++++------ flatland/envs/rail_generators.py | 6 ++++-- flatland/envs/schedule_generators.py | 4 ++-- 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index a99009a5..1a18cb40 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -1,3 +1,5 @@ +import time + import numpy as np from flatland.envs.observations import TreeObsForRailEnv @@ -30,18 +32,18 @@ speed_ration_map = {1.: 0.25, # Fast passenger train env = RailEnv(width=50, height=50, - rail_generator=sparse_rail_generator(num_cities=20, # Number of cities in map (where train stations are) - num_intersections=5, # Number of intersections (no start / target) - num_trainstations=15, # Number of possible start/targets on map + rail_generator=sparse_rail_generator(num_cities=25, # Number of cities in map (where train stations are) + num_intersections=0, # Number of intersections (no start / target) + num_trainstations=0, # Number of possible start/targets on map min_node_dist=3, # Minimal distance of nodes node_radius=2, # Proximity of stations to city center - num_neighb=4, # Number of connections to other cities/intersections + num_neighb=3, # Number of connections to other cities/intersections seed=15, # Random seed realistic_mode=True, - enhance_intersection=True + enhance_intersection=False ), schedule_generator=sparse_schedule_generator(speed_ration_map), - number_of_agents=10, + number_of_agents=0, stochastic_data=stochastic_data, # Malfunction data generator obs_builder_object=TreeObservation) @@ -112,6 +114,7 @@ for step in range(500): next_obs, all_rewards, done, _ = env.step(action_dict) env_renderer.render_env(show=True, show_observations=False, show_predictions=False) frame_step += 1 + time.sleep(10.1) # Update replay buffer and train agent for a in range(env.get_num_agents()): agent.step((obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a])) diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index b69d80a0..93e7ce55 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -675,9 +675,11 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 # Place train stations close to the node # We currently place them uniformly distirbuted among all cities + built_num_trainstation = 0 + train_stations = [[] for i in range(num_cities)] + if num_cities > 1: - train_stations = [[] for i in range(num_cities)] - built_num_trainstation = 0 + for station in range(num_trainstations): spot_found = True trainstation_node = int(station / num_trainstations * num_cities) diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py index a3a6dc1e..7a116f9e 100644 --- a/flatland/envs/schedule_generators.py +++ b/flatland/envs/schedule_generators.py @@ -110,7 +110,7 @@ def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> else: speeds = [1.0] * len(agents_position) - return agents_position, agents_direction, agents_target, speeds + return agents_position, agents_direction, agents_target, speeds, None return generator @@ -203,7 +203,7 @@ def random_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> np.random.choice(len(valid_starting_directions), 1)[0]] agents_speed = speed_initialization_helper(num_agents, speed_ratio_map) - return agents_position, agents_direction, agents_target, agents_speed + return agents_position, agents_direction, agents_target, agents_speed, None return generator -- GitLab