From 474d13e494379368936af55b6a1006670e9bfa65 Mon Sep 17 00:00:00 2001 From: MLErik <baerenjesus@gmail.com> Date: Tue, 20 Aug 2019 16:11:34 -0400 Subject: [PATCH] removed constraint on start location --- flatland/core/transition_map.py | 1 - flatland/envs/generators.py | 22 +++++++++++++++++++--- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/flatland/core/transition_map.py b/flatland/core/transition_map.py index 048593d9..2aff675d 100644 --- a/flatland/core/transition_map.py +++ b/flatland/core/transition_map.py @@ -437,7 +437,6 @@ class GridTransitionMap(TransitionMap): number_of_incoming = np.sum(incomping_connections) # Only one incoming direction --> Straight line if number_of_incoming == 1: - for direction in range(4): if incomping_connections[direction] > 0: self.set_transition((rcPos[0], rcPos[1], mirror(direction)), direction, 1) diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py index bf9679f9..37779d22 100644 --- a/flatland/envs/generators.py +++ b/flatland/envs/generators.py @@ -1032,6 +1032,7 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation connect_nodes(rail_trans, rail_array, node_positions[current_node], node_positions[neighb]) node_stack.pop(0) + # Place train stations close to the node # We currently place them uniformly distirbuted among all cities if num_cities > 1: @@ -1046,6 +1047,7 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation station_y = np.clip(node_positions[trainstation_node][1] + np.random.randint(-node_radius, node_radius), 0, width - 1) + tries = 0 while (station_x, station_y) in train_stations or (station_x, station_y) == node_positions[ trainstation_node] or rail_array[(station_x, station_y)] != 0: station_x = np.clip( @@ -1056,6 +1058,10 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation node_positions[trainstation_node][1] + np.random.randint(-node_radius, node_radius), 0, width - 1) + tries += 1 + if tries > 100: + warnings.warn("Could not set trainstations, please change initial parameters!!!!") + break train_stations[trainstation_node].append((station_x, station_y)) # Connect train station to the correct node @@ -1122,8 +1128,10 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation tries += 1 # Test again with new start node if no pair is found (This code needs to be improved) if tries > 10: - break start_node = np.random.choice(avail_start_nodes) + if tries > 100: + warnings.warn("Could not set trainstations, please change initial parameters!!!!") + break node_available_start[start_node] -= 1 node_available_target[target_node] -= 1 @@ -1141,18 +1149,26 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation current_target_node = agent_start_targets_nodes[agent_idx][1] target_station_idx = np.random.randint(len(train_stations[current_target_node])) target = train_stations[current_target_node][target_station_idx] + tries = 0 while (target[0], target[1]) in agents_target: target_station_idx = np.random.randint(len(train_stations[current_target_node])) target = train_stations[current_target_node][target_station_idx] - + tries += 1 + if tries > 100: + warnings.warn("Could not set target position, please change initial parameters!!!!") + break agents_target.append((target[0], target[1])) # Set start for agent current_start_node = agent_start_targets_nodes[agent_idx][0] start_station_idx = np.random.randint(len(train_stations[current_start_node])) start = train_stations[current_start_node][start_station_idx] - + tries = 0 while (start[0], start[1]) in agents_position: + tries += 1 + if tries > 100: + warnings.warn("Could not set start position, please change initial parameters!!!!") + break start_station_idx = np.random.randint(len(train_stations[current_start_node])) start = train_stations[current_start_node][start_station_idx] -- GitLab