From dc950fd6875fe46ca71bad05f279d008fb535ffd Mon Sep 17 00:00:00 2001 From: MLErik <baerenjesus@gmail.com> Date: Mon, 30 Sep 2019 20:14:46 -0400 Subject: [PATCH] fixed track numbering function --- flatland/envs/rail_generators.py | 29 +++++++++++++++++----------- flatland/envs/schedule_generators.py | 1 - 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index ac898d43..c2f4f438 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -768,18 +768,20 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, for tmp_out_connection_point in connection_points[current_node][direction]: # Find closest connection point min_connection_dist = np.inf - all_neighb_connection_points = [item for sublist in connection_points[neighb_idx] for item in - sublist] - - for tmp_in_connection_point in all_neighb_connection_points: - tmp_dist = distance_on_rail(tmp_out_connection_point, tmp_in_connection_point, + for dir in range(4): + current_points = connection_points[neighb_idx][dir] + for tmp_in_connection_point in current_points: + tmp_dist = distance_on_rail(tmp_out_connection_point, tmp_in_connection_point, metric="Manhattan") - if tmp_dist < min_connection_dist: - min_connection_dist = tmp_dist - neighb_connection_point = tmp_in_connection_point + if tmp_dist < min_connection_dist: + min_connection_dist = tmp_dist + neighb_connection_point = tmp_in_connection_point + neighbour_direction = dir new_line = connect_cities(rail_trans, grid_map, tmp_out_connection_point, neighb_connection_point, city_cells) G.add_edge(current_node, neighb_idx, direction=direction, length=len(new_line)) + G.add_edge(neighb_idx, current_node, direction=neighbour_direction, length=len(new_line)) + all_paths.extend(new_line) direction += 1 @@ -1003,8 +1005,13 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, :return: """ if city_orientation % 2 == 0: - return np.abs(city_position[1] - position[1]) % 2 + if city_position[1] - position[1] < 0: + return np.abs(city_position[1] - position[1]) % 2 + else: + return (np.abs(city_position[1] - position[1]) + 1) % 2 else: - return np.abs(city_position[0] - position[0]) % 2 - + if city_position[0] - position[0] < 0: + return np.abs(city_position[0] - position[0]) % 2 + else: + return (np.abs(city_position[0] - position[0]) + 1) % 2 return generator diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py index cd2f3996..3a378387 100644 --- a/flatland/envs/schedule_generators.py +++ b/flatland/envs/schedule_generators.py @@ -88,7 +88,6 @@ def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> track_to_use = 0 else: track_to_use = 1 - for i in range(len(train_stations[current_start_node])): if train_stations[current_start_node][i][1] == track_to_use: start_station_idx = i -- GitLab