diff --git a/flatland/core/grid/grid4_astar.py b/flatland/core/grid/grid4_astar.py index feb72313f21b9ecc989688d63ba02ccf3a458107..a55b96045fc0f3726d46da66672494aa8c52b182 100644 --- a/flatland/core/grid/grid4_astar.py +++ b/flatland/core/grid/grid4_astar.py @@ -103,4 +103,5 @@ def a_star(rail_trans, rail_array, start, end): # no full path found if len(open_nodes) == 0: + print("could not make path") return [] diff --git a/flatland/core/transition_map.py b/flatland/core/transition_map.py index 018c8cd5b326a9246c393121153d6f5478a12f8b..dbb68a739cc50d94d83bd5c2d1127b36c982df5c 100644 --- a/flatland/core/transition_map.py +++ b/flatland/core/transition_map.py @@ -403,6 +403,62 @@ class GridTransitionMap(TransitionMap): return True + def fix_transitions(self, rcPos): + """ + Fixes broken transitions + """ + gDir2dRC = self.transitions.gDir2dRC # [[-1,0] = N, [0,1]=E, etc] + grcPos = array(rcPos) + grcMax = self.grid.shape + + # loop over available outbound directions (indices) for rcPos + incomping_connections = np.zeros(4) + for iDirOut in np.arange(4): + gdRC = gDir2dRC[iDirOut] # row,col increment + gPos2 = grcPos + gdRC # next cell in that direction + + # Check the adjacent cell is within bounds + # if not, then this transition is invalid! + if np.any(gPos2 < 0): + return False + if np.any(gPos2 >= grcMax): + return False + + # Get the transitions out of gPos2, using iDirOut as the inbound direction + # if there are no available transitions, ie (0,0,0,0), then rcPos is invalid + connected = 0 + for orientation in range(4): + connected += self.get_transition((gPos2[0], gPos2[1], orientation), mirror(iDirOut)) + if connected > 0: + incomping_connections[iDirOut] = 1 + + number_of_incoming = np.sum(incomping_connections) + # Only one incoming direction --> Straight line + if number_of_incoming == 1: + + for direction in range(4): + if incomping_connections[direction] > 0: + self.set_transition((rcPos[0], rcPos[1], mirror(direction)), direction, 1) + # Connect all incoming connections + if number_of_incoming == 2: + connect_directions = np.argwhere(incomping_connections > 0) + self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[0])), connect_directions[1], 1) + self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[1])), connect_directions[0], 1) + + # Find feasible connection fro three entries + if number_of_incoming == 3: + hole = np.argwhere(incomping_connections < 1)[0][0] + connect_directions = [(hole + 1) % 4, (hole + 2) % 4, (hole + 3) % 4] + self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[0])), connect_directions[1], 1) + self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[0])), connect_directions[2], 1) + self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[1])), connect_directions[0], 1) + self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[2])), connect_directions[0], 1) + # Make a cross + if number_of_incoming == 4: + for direction in range(4): + self.set_transition((grcPos[0], grcPos[1], direction), direction, 1) + return True + def mirror(dir): return (dir + 2) % 4 diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py index 0c1b7d15406e054959b1af35e67e1060c1c3e47b..b984be6f1cce69cb2d6320a52a602de2cc88d0da 100644 --- a/flatland/envs/generators.py +++ b/flatland/envs/generators.py @@ -9,7 +9,7 @@ from flatland.core.grid.grid_utils import distance_on_rail from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.core.transition_map import GridTransitionMap from flatland.envs.agent_utils import EnvAgentStatic -from flatland.envs.grid4_generators_utils import connect_rail, connect_from_nodes, connect_nodes +from flatland.envs.grid4_generators_utils import connect_rail, connect_nodes, connect_from_nodes from flatland.envs.grid4_generators_utils import get_rnd_agents_pos_tgt_dir_on_rail @@ -863,7 +863,7 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation if tries > 100: warnings.warn("Could not set nodes, please change initial parameters!!!!") break - + print(node_positions) # Chose node connection available_nodes_full = np.arange(num_cities + num_intersections) available_cities = np.arange(num_cities) @@ -881,7 +881,7 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation delete_idx = np.where(available_cities == current_node) available_cities = np.delete(available_cities, delete_idx, 0) - elif len(available_intersections) > 0: + elif len(available_intersections) > 0 and len(available_cities) > 0: available_nodes = available_cities delete_idx = np.where(available_intersections == current_node) available_intersections = np.delete(available_intersections, delete_idx, 0) @@ -894,8 +894,6 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation available_nodes = available_nodes[np.argsort(node_dist)] # Set number of neighboring nodes - - print(current_node, allowed_connections) if len(available_nodes) >= allowed_connections: connected_neighb_idx = available_nodes[:allowed_connections] else: @@ -903,6 +901,7 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation if current_node == 0: allowed_connections -= 1 + # Connect to the neighboring nodes for neighb in connected_neighb_idx: if neighb not in node_stack: @@ -938,12 +937,14 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation train_stations[trainstation_node].append((station_x, station_y)) # Connect train station to the correct node - connect_from_nodes(rail_trans, rail_array, node_positions[trainstation_node], (station_x, station_y)) - + connection = connect_from_nodes(rail_trans, rail_array, node_positions[trainstation_node], + (station_x, station_y)) + # Check if connection was made + if len(connection) == 0: + train_stations[trainstation_node].pop(-1) # Fix all nodes with illegal transition maps for current_node in node_positions: - if not grid_map.cell_neighbours_valid(current_node): - grid_map.fix_neighbours(current_node) + grid_map.fix_transitions(current_node) # Generate start and target node directory for all agents. # Assure that start and target are not in the same node diff --git a/flatland/envs/grid4_generators_utils.py b/flatland/envs/grid4_generators_utils.py index fbb1f03df3913b1a56343e8912e698bf02094090..bdf49f87d4c33d899e5e0c53b7ee50df190a8c7f 100644 --- a/flatland/envs/grid4_generators_utils.py +++ b/flatland/envs/grid4_generators_utils.py @@ -76,8 +76,8 @@ def connect_nodes(rail_trans, rail_array, start, end): if index == 0: if new_trans == 0: # end-point - # need to flip direction because of how end points are defined - new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) + # don't set any transition at node yet + new_trans = 0 else: # into existing rail new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) @@ -93,7 +93,9 @@ def connect_nodes(rail_trans, rail_array, start, end): new_trans_e = rail_array[end_pos] if new_trans_e == 0: # end-point - new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1) + # don't set any transition at node yet + + new_trans_e = 0 else: # into existing rail new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1) diff --git a/tests/test_flatland_env_sparse_rail_generator.py b/tests/test_flatland_env_sparse_rail_generator.py index 1b274bcae21bcd89cc86f5d10b4ae2603ff915b4..77b4c4af435f753396e10ceaba7c0b4b3ca8b2b3 100644 --- a/tests/test_flatland_env_sparse_rail_generator.py +++ b/tests/test_flatland_env_sparse_rail_generator.py @@ -1,5 +1,3 @@ -import time - import numpy as np from flatland.envs.generators import sparse_rail_generator, realistic_rail_generator @@ -23,11 +21,11 @@ def test_realistic_rail_generator(): def test_sparse_rail_generator(): - env = RailEnv(width=50, - height=50, - rail_generator=sparse_rail_generator(num_cities=3, # Number of cities in map + env = RailEnv(width=20, + height=20, + rail_generator=sparse_rail_generator(num_cities=5, # Number of cities in map num_intersections=2, # Number of interesections in map - num_trainstations=15, # Number of possible start/targets on map + num_trainstations=20, # Number of possible start/targets on map min_node_dist=6, # Minimal distance of nodes node_radius=3, # Proximity of stations to city center num_neighb=2, # Number of connections to other cities @@ -38,4 +36,3 @@ def test_sparse_rail_generator(): # reset to initialize agents_static env_renderer = RenderTool(env, gl="PILSVG", ) env_renderer.render_env(show=True, show_observations=True, show_predictions=False) - time.sleep(10)