From 6a8f55117cd0d3c975f51808e836a958caf08b23 Mon Sep 17 00:00:00 2001 From: MLErik <baerenjesus@gmail.com> Date: Sat, 17 Aug 2019 11:17:27 -0400 Subject: [PATCH] Fixed typo and formatting --- flatland/envs/generators.py | 29 ++++++++++--------- flatland/envs/grid4_generators_utils.py | 1 + ...test_flatland_env_sparse_rail_generator.py | 15 ++++------ 3 files changed, 21 insertions(+), 24 deletions(-) diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py index a0eb7ca7..50759734 100644 --- a/flatland/envs/generators.py +++ b/flatland/envs/generators.py @@ -543,7 +543,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11): return generator -def realistic_rail_generator(nr_start_goal=1, seed=0): +def realistic_rail_generator(nr_start_goal=1, seed=0): """ Parameters ------- @@ -682,7 +682,7 @@ def realistic_rail_generator(nr_start_goal=1, seed=0): idx_forward = [] idx_backward = [] - idx_target=0 + idx_target = 0 for off_set_loop in range(len(x_offsets)): off_set = x_offsets[off_set_loop] # second track @@ -739,14 +739,17 @@ def realistic_rail_generator(nr_start_goal=1, seed=0): (x_offsets[off_set_loop] - 1, 0), (x_offsets[off_set_loop] - 2, 0)) - for nbr_track_loop in range(max_n_track_seg-1): + for nbr_track_loop in range(max_n_track_seg - 1): if len(data) < 2 * n_track_seg + 1: break x = np.sort(np.random.choice(data, 2 * n_track_seg, False)).astype(int) data = [] for x_loop in range(int(len(x) / 2)): - start = (max(0, min(off_set + nbr_track_loop + 1, height - 1)), max(0, min(x[2 * x_loop], width - 1))) - goal = (max(0, min(off_set + nbr_track_loop + 1, height - 1)), max(0, min(x[2 * x_loop + 1], width - 1))) + start = ( + max(0, min(off_set + nbr_track_loop + 1, height - 1)), max(0, min(x[2 * x_loop], width - 1))) + goal = ( + max(0, min(off_set + nbr_track_loop + 1, height - 1)), + max(0, min(x[2 * x_loop + 1], width - 1))) d = np.arange(x[2 * x_loop] + 1, x[2 * x_loop + 1] - 1, 2) data.extend(d) @@ -767,15 +770,15 @@ def realistic_rail_generator(nr_start_goal=1, seed=0): agents_directions_backward.append(([1, 3][off_set_loop % 2])) idx_backward.append(idx_target) - add_pos = (int((start[0] + goal[0]) / 2), int((2*start[1] + goal[1]) / 3),idx_target) + add_pos = (int((start[0] + goal[0]) / 2), int((2 * start[1] + goal[1]) / 3), idx_target) agents_targets.append(add_pos) - idx_target+=1 + idx_target += 1 agents_position = [] agents_target = [] agents_direction = [] - for a in range(min(len(agents_targets),num_agents)): + for a in range(min(len(agents_targets), num_agents)): t = np.random.choice(range(len(agents_targets))) d = agents_targets[t] agents_targets.pop(t) @@ -789,7 +792,7 @@ def realistic_rail_generator(nr_start_goal=1, seed=0): agents_positions_backward.pop(sel) agents_directions_backward.pop(sel) idx_backward.pop(sel) - agents_position.append((p[0],p[1])) + agents_position.append((p[0], p[1])) agents_direction.append(d) else: if len(idx_forward) > 0: @@ -801,10 +804,9 @@ def realistic_rail_generator(nr_start_goal=1, seed=0): agents_positions_forward.pop(sel) agents_directions_forward.pop(sel) idx_forward.pop(sel) - agents_position.append((p[0],p[1])) + agents_position.append((p[0], p[1])) agents_direction.append(d) - return grid_map, agents_position, agents_direction, agents_target, [1.0] * len(agents_position) return generator @@ -881,7 +883,7 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation for neighb in connected_neighb_idx: if neighb not in node_stack: node_stack.append(neighb) - new_path = connect_nodes(rail_trans, rail_array, node_positions[current_node], node_positions[neighb]) + connect_nodes(rail_trans, rail_array, node_positions[current_node], node_positions[neighb]) node_stack.pop(0) # Place train stations close to the node @@ -908,8 +910,7 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation train_stations[trainstation_node].append((station_x, station_y)) # Connect train station to the correct node - new_path = connect_from_nodes(rail_trans, rail_array, node_positions[trainstation_node], - (station_x, station_y)) + connect_from_nodes(rail_trans, rail_array, node_positions[trainstation_node], (station_x, station_y)) # Fix all nodes with illegal transition maps for current_node in node_positions: diff --git a/flatland/envs/grid4_generators_utils.py b/flatland/envs/grid4_generators_utils.py index 9116adb6..fbb1f03d 100644 --- a/flatland/envs/grid4_generators_utils.py +++ b/flatland/envs/grid4_generators_utils.py @@ -194,6 +194,7 @@ def connect_to_nodes(rail_trans, rail_array, start, end): current_dir = new_dir return path + def get_rnd_agents_pos_tgt_dir_on_rail(rail, num_agents): """ Given a `rail' GridTransitionMap, return a random placement of agents (initial position, direction and target). diff --git a/tests/test_flatland_env_sparse_rail_generator.py b/tests/test_flatland_env_sparse_rail_generator.py index 34e6c2b4..66b7ca3b 100644 --- a/tests/test_flatland_env_sparse_rail_generator.py +++ b/tests/test_flatland_env_sparse_rail_generator.py @@ -1,5 +1,3 @@ -import time - import numpy as np from flatland.envs.generators import sparse_rail_generator, realistic_rail_generator @@ -10,18 +8,18 @@ from flatland.utils.rendertools import RenderTool def test_realistic_rail_generator(): for test_loop in range(20): - num_agents = np.random.randint(10,30) - env = RailEnv(width=np.random.randint(40,80), - height=np.random.randint(10,20), - rail_generator=realistic_rail_generator(nr_start_goal=num_agents+1,seed=test_loop), + num_agents = np.random.randint(10, 30) + env = RailEnv(width=np.random.randint(40, 80), + height=np.random.randint(10, 20), + rail_generator=realistic_rail_generator(nr_start_goal=num_agents + 1, seed=test_loop), number_of_agents=num_agents, obs_builder_object=GlobalObsForRailEnv()) # reset to initialize agents_static env_renderer = RenderTool(env, gl="PILSVG", ) env_renderer.render_env(show=True, show_observations=True, show_predictions=False) - time.sleep(2) env_renderer.close_window() + def test_sparse_rail_generator(): env = RailEnv(width=20, height=50, @@ -38,7 +36,4 @@ def test_sparse_rail_generator(): # reset to initialize agents_static env_renderer = RenderTool(env, gl="PILSVG", ) env_renderer.render_env(show=True, show_observations=True, show_predictions=False) - time.sleep(20) - env_renderer.gl.save_image("flatalnd_2_0.png") - time.sleep(1) -- GitLab