diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index dfa336372ab2a1f3ad3078cee8effee209a53ddb..6bc6cdc4d41a9951b9fecfdd8871300d87466147 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -33,15 +33,15 @@ speed_ration_map = {1.: 0.25, # Fast passenger train env = RailEnv(width=50, height=50, rail_generator=sparse_rail_generator(num_cities=9, # Number of cities in map (where train stations are) - min_node_dist=10, # Minimal distance of nodes + min_node_dist=12, # Minimal distance of nodes node_radius=4, # Proximity of stations to city center seed=0, # Random seed grid_mode=False, max_inter_city_rails=2, - tracks_in_city=4, + tracks_in_city=50, ), schedule_generator=sparse_schedule_generator(), - number_of_agents=100, + number_of_agents=50, stochastic_data=stochastic_data, # Malfunction data generator obs_builder_object=GlobalObsForRailEnv()) diff --git a/flatland/core/transition_map.py b/flatland/core/transition_map.py index 64e55347aad9f3d29e5ccaa1962e5c4bd391cffb..995e7d67ca87f306be7c87dbba685f46cde240d2 100644 --- a/flatland/core/transition_map.py +++ b/flatland/core/transition_map.py @@ -516,7 +516,6 @@ class GridTransitionMap(TransitionMap): double_slip = cells[5] three_way_transitions = [simple_switch_east_south, simple_switch_west_south] # loop over available outbound directions (indices) for rcPos - self.set_transitions(rcPos, 0) incoming_connections = np.zeros(4) for iDirOut in np.arange(4): @@ -539,19 +538,28 @@ class GridTransitionMap(TransitionMap): incoming_connections[iDirOut] = 1 number_of_incoming = np.sum(incoming_connections) - # Only one incoming direction --> Straight line + # Only one incoming direction --> Straight line set deadend if number_of_incoming == 1: - for direction in range(4): - if incoming_connections[direction] > 0: - self.set_transition((rcPos[0], rcPos[1], mirror(direction)), direction, 1) + if self.get_full_transitions(*rcPos) == 0: + self.set_transitions(rcPos, 0) + else: + self.set_transitions(rcPos, 0) + + for direction in range(4): + if incoming_connections[direction] > 0: + self.set_transition((rcPos[0], rcPos[1], mirror(direction)), direction, 1) # Connect all incoming connections if number_of_incoming == 2: + self.set_transitions(rcPos, 0) + connect_directions = np.argwhere(incoming_connections > 0) self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[0])), connect_directions[1], 1) self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[1])), connect_directions[0], 1) # Find feasible connection for three entries if number_of_incoming == 3: + self.set_transitions(rcPos, 0) + transition = np.random.choice(three_way_transitions, 1) hole = np.argwhere(incoming_connections < 1)[0][0] transition = transitions.rotate_transition(transition, int(hole * 90)) diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index 5704d26d2b94c691095b98b3d86b2865213e3771..a4d34a6bfe207f36077bedf3dec578e499cb703b 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -659,8 +659,8 @@ def sparse_rail_generator(num_cities=5, min_node_dist=20, node_radius=2, def _generate_node_connection_points(node_positions, node_size, tracks_in_city=2): connection_points = [] connection_info = [] - if tracks_in_city > 2 * node_size + 1: - tracks_in_city = 2 * node_size + 1 + if tracks_in_city > 2 * node_size - 1: + tracks_in_city = 2 * node_size - 1 for node_position in node_positions: @@ -811,32 +811,6 @@ def sparse_rail_generator(num_cities=5, min_node_dist=20, node_radius=2, train_stations[current_city].append(possible_location) return train_stations, built_num_trainstations - def _fix_transitions(grid_map): - """ - Function to fix all transition elements in environment - """ - # Fix all nodes with illegal transition maps - empty_to_fix = [] - rails_to_fix = [] - height, width = np.shape(grid_map.grid) - for r in range(height): - for c in range(width): - rc_pos = (r, c) - check = grid_map.cell_neighbours_valid(rc_pos, True) - if not check: - if grid_map.grid[rc_pos] == 0: - empty_to_fix.append(rc_pos) - else: - rails_to_fix.append(rc_pos) - - # Fix empty cells first to avoid cutting the network - for cell in empty_to_fix: - grid_map.fix_transitions(cell) - - # Fix all other cells - for cell in rails_to_fix: - grid_map.fix_transitions(cell) - def _generate_start_target_pairs(num_agents, nb_nodes, train_stations): # Generate start and target node directory for all agents. @@ -876,6 +850,32 @@ def sparse_rail_generator(num_cities=5, min_node_dist=20, node_radius=2, num_agents -= 1 return agent_start_targets_nodes, num_agents + def _fix_transitions(grid_map): + """ + Function to fix all transition elements in environment + """ + # Fix all nodes with illegal transition maps + empty_to_fix = [] + rails_to_fix = [] + height, width = np.shape(grid_map.grid) + for r in range(height): + for c in range(width): + rc_pos = (r, c) + check = grid_map.cell_neighbours_valid(rc_pos, True) + if not check: + if grid_map.grid[rc_pos] == 0: + empty_to_fix.append(rc_pos) + else: + rails_to_fix.append(rc_pos) + + # Fix empty cells first to avoid cutting the network + for cell in empty_to_fix: + grid_map.fix_transitions(cell) + + # Fix all other cells + for cell in rails_to_fix: + grid_map.fix_transitions(cell) + def _closest_neigh_in_direction(current_node, direction, node_positions): # Sort available neighbors according to their distance.