From 34df9cab1dfad20a3a26ad5dacdc5a3015a02837 Mon Sep 17 00:00:00 2001 From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch> Date: Tue, 17 Sep 2019 07:44:34 +0200 Subject: [PATCH] refactoring and clean up --- flatland/envs/grid4_generators_utils.py | 151 ++---------------- ...est_flatland_envs_sparse_rail_generator.py | 5 +- tests/test_flatland_utils_rendertools.py | 4 +- 3 files changed, 21 insertions(+), 139 deletions(-) diff --git a/flatland/envs/grid4_generators_utils.py b/flatland/envs/grid4_generators_utils.py index 996bd73a..ab0d1a94 100644 --- a/flatland/envs/grid4_generators_utils.py +++ b/flatland/envs/grid4_generators_utils.py @@ -9,7 +9,8 @@ from flatland.core.grid.grid4_astar import a_star from flatland.core.grid.grid4_utils import get_direction, mirror -def connect_rail(rail_trans, rail_array, start, end): +def connect_basic_operation(rail_trans, rail_array, start, end, + flip_start_node_trans=False, flip_end_node_trans=False): """ Creates a new path [start,end] in rail_array, based on rail_trans. """ @@ -28,8 +29,11 @@ def connect_rail(rail_trans, rail_array, start, end): if index == 0: if new_trans == 0: # end-point - # need to flip direction because of how end points are defined - new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1) + if flip_start_node_trans: + # need to flip direction because of how end points are defined + new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1) + else: + new_trans = 0 else: # into existing rail new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) @@ -45,7 +49,10 @@ def connect_rail(rail_trans, rail_array, start, end): new_trans_e = rail_array[end_pos] if new_trans_e == 0: # end-point - new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, mirror(new_dir), 1) + if flip_end_node_trans: + new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, mirror(new_dir), 1) + else: + new_trans_e = 0 else: # into existing rail new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1) @@ -55,141 +62,17 @@ def connect_rail(rail_trans, rail_array, start, end): return path -def connect_nodes(rail_trans, rail_array, start, end): - """ - Creates a new path [start,end] in rail_array, based on rail_trans. - """ - # in the worst case we will need to do a A* search, so we might as well set that up - path = a_star(rail_trans, rail_array, start, end) - if len(path) < 2: - return [] - current_dir = get_direction(path[0], path[1]) - end_pos = path[-1] - for index in range(len(path) - 1): - current_pos = path[index] - new_pos = path[index + 1] - new_dir = get_direction(current_pos, new_pos) - - new_trans = rail_array[current_pos] - if index == 0: - if new_trans == 0: - # end-point - # don't set any transition at node yet - new_trans = 0 - else: - # into existing rail - new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) - else: - # set the forward path - new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) - # set the backwards path - new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1) - rail_array[current_pos] = new_trans - - if new_pos == end_pos: - # setup end pos setup - new_trans_e = rail_array[end_pos] - if new_trans_e == 0: - # end-point - # don't set any transition at node yet +def connect_rail(rail_trans, rail_array, start, end): + return connect_basic_operation(rail_trans, rail_array, start, end, True, True) - new_trans_e = 0 - else: - # into existing rail - new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1) - rail_array[end_pos] = new_trans_e - current_dir = new_dir - return path +def connect_nodes(rail_trans, rail_array, start, end): + return connect_basic_operation(rail_trans, rail_array, start, end, False, False) def connect_from_nodes(rail_trans, rail_array, start, end): - """ - Creates a new path [start,end] in rail_array, based on rail_trans. - """ - # in the worst case we will need to do a A* search, so we might as well set that up - path = a_star(rail_trans, rail_array, start, end) - if len(path) < 2: - return [] - current_dir = get_direction(path[0], path[1]) - end_pos = path[-1] - for index in range(len(path) - 1): - current_pos = path[index] - new_pos = path[index + 1] - new_dir = get_direction(current_pos, new_pos) - - new_trans = rail_array[current_pos] - if index == 0: - if new_trans == 0: - # end-point - # need to flip direction because of how end points are defined - new_trans = 0 - else: - # into existing rail - new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) - else: - # set the forward path - new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) - # set the backwards path - new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1) - rail_array[current_pos] = new_trans - - if new_pos == end_pos: - # setup end pos setup - new_trans_e = rail_array[end_pos] - if new_trans_e == 0: - # end-point - new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, mirror(new_dir), 1) - else: - # into existing rail - new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1) - rail_array[end_pos] = new_trans_e - - current_dir = new_dir - return path + return connect_basic_operation(rail_trans, rail_array, start, end, False, True) def connect_to_nodes(rail_trans, rail_array, start, end): - """ - Creates a new path [start,end] in rail_array, based on rail_trans. - """ - # in the worst case we will need to do a A* search, so we might as well set that up - path = a_star(rail_trans, rail_array, start, end) - if len(path) < 2: - return [] - current_dir = get_direction(path[0], path[1]) - end_pos = path[-1] - for index in range(len(path) - 1): - current_pos = path[index] - new_pos = path[index + 1] - new_dir = get_direction(current_pos, new_pos) - - new_trans = rail_array[current_pos] - if index == 0: - if new_trans == 0: - # end-point - # need to flip direction because of how end points are defined - new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1) - else: - # into existing rail - new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) - else: - # set the forward path - new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) - # set the backwards path - new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1) - rail_array[current_pos] = new_trans - - if new_pos == end_pos: - # setup end pos setup - new_trans_e = rail_array[end_pos] - if new_trans_e == 0: - # end-point - new_trans_e = 0 - else: - # into existing rail - new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1) - rail_array[end_pos] = new_trans_e - - current_dir = new_dir - return path + return connect_basic_operation(rail_trans, rail_array, start, end, True, False) diff --git a/tests/test_flatland_envs_sparse_rail_generator.py b/tests/test_flatland_envs_sparse_rail_generator.py index a0e2b995..a2992a7a 100644 --- a/tests/test_flatland_envs_sparse_rail_generator.py +++ b/tests/test_flatland_envs_sparse_rail_generator.py @@ -27,6 +27,7 @@ def test_sparse_rail_generator(): env_renderer.render_env(show=True, show_observations=True, show_predictions=False) env_renderer.gl.save_image("./sparse_generator_false.png") # TODO test assertions! + env_renderer.close_window() def test_rail_env_action_required_info(): @@ -108,7 +109,7 @@ def test_rail_env_action_required_info(): if done_always_action['__all__']: break - + env_renderer.close_window() def test_rail_env_malfunction_speed_info(): np.random.seed(0) @@ -158,7 +159,7 @@ def test_rail_env_malfunction_speed_info(): if done['__all__']: break - + env_renderer.close_window() def test_sparse_generator_with_too_man_cities_does_not_break_down(): np.random.seed(0) diff --git a/tests/test_flatland_utils_rendertools.py b/tests/test_flatland_utils_rendertools.py index 8248c675..853b025f 100644 --- a/tests/test_flatland_utils_rendertools.py +++ b/tests/test_flatland_utils_rendertools.py @@ -45,14 +45,12 @@ def test_render_env(save_new_images=False): oEnv.rail.load_transition_map('env_data.tests', "test1.npy") oRT = rt.RenderTool(oEnv, gl="PILSVG") oRT.render_env(show=False) - checkFrozenImage(oRT, "basic-env.npz", resave=save_new_images) oRT = rt.RenderTool(oEnv, gl="PIL") oRT.render_env() checkFrozenImage(oRT, "basic-env-PIL.npz", resave=save_new_images) - - + def main(): if len(sys.argv) == 2 and sys.argv[1] == "save": test_render_env(save_new_images=True) -- GitLab