From ef5ad8e95f540f5d92078715da9e699070ecc197 Mon Sep 17 00:00:00 2001 From: MLErik <baerenjesus@gmail.com> Date: Wed, 25 Sep 2019 18:17:26 -0400 Subject: [PATCH] updated tests to reflect changes to sparse_rail_generator --- flatland/envs/rail_generators.py | 13 ++++++------- tests/test_flatland_envs_sparse_rail_generator.py | 3 +-- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index 266d2bec..de780402 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -628,8 +628,6 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 else: connected_neighb_idx = available_nodes - print(current_node, connected_neighb_idx) - # Connect to the neighboring nodes for neighb in connected_neighb_idx: if neighb not in node_stack and neighb in open_nodes: @@ -660,13 +658,13 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 # Place train stations close to the node # We currently place them uniformly distributed among all cities built_num_trainstation = 0 - train_stations = [[] for i in range(_num_cities)] - if _num_cities > 1: + train_stations = [[] for i in range(nb_nodes)] + if nb_nodes > 1: for station in range(num_trainstations): spot_found = True reduced_node_radius = node_radius - 1 - trainstation_node = int(station / num_trainstations * _num_cities) + trainstation_node = int(station / num_trainstations * nb_nodes) station_x = np.clip( node_positions[trainstation_node][0] + np.random.randint(-reduced_node_radius, reduced_node_radius), @@ -778,12 +776,12 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 # Slot availability in node node_available_start = [] node_available_target = [] - for node_idx in range(_num_cities): + for node_idx in range(nb_nodes): node_available_start.append(len(train_stations[node_idx])) node_available_target.append(len(train_stations[node_idx])) # Assign agents to slots - for agent_idx in range(num_agents): + for agent_idx in range(nb_nodes): avail_start_nodes = [idx for idx, val in enumerate(node_available_start) if val > 0] avail_target_nodes = [idx for idx, val in enumerate(node_available_target) if val > 0] start_node = np.random.choice(avail_start_nodes) @@ -905,4 +903,5 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 def argsort(seq): # http://stackoverflow.com/questions/3071415/efficient-method-to-calculate-the-rank-vector-of-a-list-in-python return sorted(range(len(seq)), key=seq.__getitem__) + return generator diff --git a/tests/test_flatland_envs_sparse_rail_generator.py b/tests/test_flatland_envs_sparse_rail_generator.py index e1647524..fd6e7b88 100644 --- a/tests/test_flatland_envs_sparse_rail_generator.py +++ b/tests/test_flatland_envs_sparse_rail_generator.py @@ -743,8 +743,7 @@ def test_sparse_rail_generator_deterministic(): num_neighb=3, # Number of connections to other cities/intersections seed=215545, # Random seed - grid_mode=True, - enhance_intersection=False + grid_mode=True ), schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=1, -- GitLab