diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index a4e24f9d73192407d671294848603e0a880df4f0..45e823b78374d8a0b4ac130a8c707e874177ddda 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -6,7 +6,7 @@ from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import sparse_rail_generator -from flatland.envs.schedule_generators import random_schedule_generator +from flatland.envs.schedule_generators import sparse_schedule_generator from flatland.utils.rendertools import RenderTool np.random.seed(1) @@ -34,16 +34,16 @@ env = RailEnv(width=50, height=50, rail_generator=sparse_rail_generator(num_cities=9, # Number of cities in map (where train stations are) num_intersections=0, # Number of intersections (no start / target) - num_trainstations=50, # Number of possible start/targets on map + num_trainstations=100, # Number of possible start/targets on map min_node_dist=3, # Minimal distance of nodes - node_radius=5, # Proximity of stations to city center + node_radius=4, # Proximity of stations to city center num_neighb=3, # Number of connections to other cities/intersections seed=15, # Random seed grid_mode=True, enhance_intersection=False ), - schedule_generator=random_schedule_generator(), - number_of_agents=0, + schedule_generator=sparse_schedule_generator(), + number_of_agents=50, stochastic_data=stochastic_data, # Malfunction data generator obs_builder_object=GlobalObsForRailEnv()) diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index d1acc464efdc72962bfd94403d5de7156c76d838..a780cfc9b260672e34d5e9c19d4b5cd57bd4db5e 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -600,7 +600,6 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 # Set up connection points connection_points = _generate_node_connection_points(node_positions, node_radius, max_nr_connection_points=4) - print(connection_points) # Start at some node current_node = np.random.randint(len(available_nodes_full)) node_stack = [current_node] @@ -654,7 +653,6 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 for tmp_out_connection_point in connection_points[current_node]: tmp_dist_to_node = distance_on_rail(tmp_out_connection_point, node_positions[neighb]) # Check if this connection node is on the city side facing the neighbour - print("Current node", current_node, "Neigh", neighb, "Distance", tmp_dist_to_node, dist_from_center) if tmp_dist_to_node < dist_from_center - 1: min_connection_dist = np.inf @@ -714,7 +712,7 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 # Connect train station to random nodes - rand_corner_nodes = np.random.choice(range(len(connection_points[trainstation_node])), 2, replace=False) + rand_corner_nodes = np.random.choice(range(len(connection_points[trainstation_node])), 3, replace=False) for corner_node_idx in rand_corner_nodes: connection = connect_nodes(rail_trans, grid_map, connection_points[trainstation_node][corner_node_idx], @@ -776,7 +774,7 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 return grid_map, {'agents_hints': { 'num_agents': num_agents, 'agent_start_targets_nodes': agent_start_targets_nodes, - 'train_stations': train_stations_slots + 'train_stations': train_stations }} def _generate_node_positions_not_grid_mode(city_positions, height, intersection_positions, nb_nodes, @@ -788,8 +786,8 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 tries = 0 while to_close: - x_tmp = node_radius + np.random.randint(height - node_radius) - y_tmp = node_radius + np.random.randint(width - node_radius) + x_tmp = node_radius + np.random.randint(height - node_radius - 1) + y_tmp = node_radius + np.random.randint(width - node_radius - 1) to_close = False # Check distance to cities @@ -841,7 +839,6 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 connection_per_direction = n_connection_points // 4 connection_point_vector = [connection_per_direction, connection_per_direction, connection_per_direction, n_connection_points - 3 * connection_per_direction] - print(connection_point_vector) connection_points_coordinates = [] for direction in range(4): rnd_points = np.random.choice(np.arange(-node_size, node_size), size=connection_point_vector[direction],