diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index 8b5a01cce5984dbc5316900a21f79dc57c474a30..e901da5bf9331a4041d267e8692ab408624a1a11 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -7,7 +7,7 @@ from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import sparse_rail_generator from flatland.envs.schedule_generators import sparse_schedule_generator -from flatland.utils.rendertools import RenderTool +from flatland.utils.rendertools import RenderTool, AgentRenderVariant np.random.seed(1) @@ -33,20 +33,23 @@ speed_ration_map = {1.: 0.25, # Fast passenger train env = RailEnv(width=50, height=50, rail_generator=sparse_rail_generator(num_cities=9, # Number of cities in map (where train stations are) - num_trainstations=45, # Number of possible start/targets on map min_node_dist=10, # Minimal distance of nodes node_radius=4, # Proximity of stations to city center - seed=15, # Random seed - grid_mode=False, + seed=0, # Random seed + grid_mode=True, max_connection_points_per_side=2, max_nr_connection_directions=2 ), schedule_generator=sparse_schedule_generator(), - number_of_agents=15, + number_of_agents=5, stochastic_data=stochastic_data, # Malfunction data generator obs_builder_object=GlobalObsForRailEnv()) -env_renderer = RenderTool(env, gl="PILSVG", ) +env_renderer = RenderTool(env, gl="PILSVG", + agent_render_variant=AgentRenderVariant.AGENT_SHOWS_OPTIONS_AND_BOX, + show_debug=True, + screen_height=1000, + screen_width=1000) # Import your own Agent or use RLlib to train agents on Flatland diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index 5626c65aa00ef35fa6230f08794294cd01281aff..cf163e909941ee44b4da0af71c24534a7d1535d4 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -532,7 +532,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11) -> RailGener return generator -def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, node_radius=2, +def sparse_rail_generator(num_cities=5, min_node_dist=20, node_radius=2, grid_mode=False, max_connection_points_per_side=4, max_nr_connection_directions=2, seed=0) -> RailGenerator: @@ -554,10 +554,6 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n def generator(width, height, num_agents, num_resets=0) -> RailGeneratorProduct: - if num_agents > num_trainstations: - num_agents = num_trainstations - warnings.warn("sparse_rail_generator: num_agents > nr_start_goal, changing num_agents") - rail_trans = RailEnvTransitions() grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans) rail_array = grid_map.grid @@ -592,8 +588,7 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n _build_inner_cities(node_positions, connection_points, rail_trans, grid_map) # Populate cities - train_stations, built_num_trainstation = _set_trainstation_positions(node_positions, city_cells, - num_trainstations, grid_map) + train_stations, built_num_trainstation = _set_trainstation_positions(node_positions, grid_map) # Adjust the number of agents if you could not build enough trainstations if num_agents > built_num_trainstation: @@ -603,8 +598,7 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n # Fix all transition elements _fix_transitions(grid_map) - # Generate start target paris - print(train_stations) + # Generate start target pairs agent_start_targets_nodes, num_agents = _generate_start_target_pairs(num_agents, nb_nodes, train_stations) return grid_map, {'agents_hints': { @@ -778,7 +772,7 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n return - def _set_trainstation_positions(node_positions, city_cells, num_trainstations, grid_map): + def _set_trainstation_positions(node_positions, grid_map): """ :param node_positions: @@ -787,9 +781,7 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n """ nb_nodes = len(node_positions) train_stations = [[] for i in range(nb_nodes)] - num_cities = len(node_positions) built_num_trainstations = 0 - stations_per_city = int(num_trainstations / num_cities) for current_city in range(len(node_positions)): for possible_location in _city_cells(node_positions[current_city], node_radius - 1): cell_type = grid_map.get_full_transitions(*possible_location) @@ -917,7 +909,6 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n city_boarder = [] for x in range(-radius, radius + 1): for y in range(-radius, radius + 1): - print(x, y, radius) if abs(x) == radius or abs(y) == radius: city_boarder.append((center[0] + x, center[1] + y)) return city_boarder diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py index ae1cfefd2d7922f32255dc035ac055c537cc7b21..f41db2b0c453adfb578a6de489f780071287dc1b 100644 --- a/flatland/envs/schedule_generators.py +++ b/flatland/envs/schedule_generators.py @@ -57,6 +57,7 @@ def complex_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> ScheduleGenerator: + def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None): train_stations = hints['train_stations'] agent_start_targets_nodes = hints['agent_start_targets_nodes'] @@ -102,7 +103,7 @@ def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> # Orient the agent correctly for orientation in range(4): transitions = rail.get_transitions(start[0], start[1], orientation) - if any(transitions) > 0: + if any(transitions) > 0 and rail.check_path_exists(start, orientation, target): agents_direction.append(orientation) break