diff --git a/flatland/envs/grid4_generators_utils.py b/flatland/envs/grid4_generators_utils.py index 996bd73ad9de598eb162a937c135681675119ad3..ab0d1a94cdd7cfbdc2b29afcae52461b30ef0e44 100644 --- a/flatland/envs/grid4_generators_utils.py +++ b/flatland/envs/grid4_generators_utils.py @@ -9,7 +9,8 @@ from flatland.core.grid.grid4_astar import a_star from flatland.core.grid.grid4_utils import get_direction, mirror -def connect_rail(rail_trans, rail_array, start, end): +def connect_basic_operation(rail_trans, rail_array, start, end, + flip_start_node_trans=False, flip_end_node_trans=False): """ Creates a new path [start,end] in rail_array, based on rail_trans. """ @@ -28,8 +29,11 @@ def connect_rail(rail_trans, rail_array, start, end): if index == 0: if new_trans == 0: # end-point - # need to flip direction because of how end points are defined - new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1) + if flip_start_node_trans: + # need to flip direction because of how end points are defined + new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1) + else: + new_trans = 0 else: # into existing rail new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) @@ -45,7 +49,10 @@ def connect_rail(rail_trans, rail_array, start, end): new_trans_e = rail_array[end_pos] if new_trans_e == 0: # end-point - new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, mirror(new_dir), 1) + if flip_end_node_trans: + new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, mirror(new_dir), 1) + else: + new_trans_e = 0 else: # into existing rail new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1) @@ -55,141 +62,17 @@ def connect_rail(rail_trans, rail_array, start, end): return path -def connect_nodes(rail_trans, rail_array, start, end): - """ - Creates a new path [start,end] in rail_array, based on rail_trans. - """ - # in the worst case we will need to do a A* search, so we might as well set that up - path = a_star(rail_trans, rail_array, start, end) - if len(path) < 2: - return [] - current_dir = get_direction(path[0], path[1]) - end_pos = path[-1] - for index in range(len(path) - 1): - current_pos = path[index] - new_pos = path[index + 1] - new_dir = get_direction(current_pos, new_pos) - - new_trans = rail_array[current_pos] - if index == 0: - if new_trans == 0: - # end-point - # don't set any transition at node yet - new_trans = 0 - else: - # into existing rail - new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) - else: - # set the forward path - new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) - # set the backwards path - new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1) - rail_array[current_pos] = new_trans - - if new_pos == end_pos: - # setup end pos setup - new_trans_e = rail_array[end_pos] - if new_trans_e == 0: - # end-point - # don't set any transition at node yet +def connect_rail(rail_trans, rail_array, start, end): + return connect_basic_operation(rail_trans, rail_array, start, end, True, True) - new_trans_e = 0 - else: - # into existing rail - new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1) - rail_array[end_pos] = new_trans_e - current_dir = new_dir - return path +def connect_nodes(rail_trans, rail_array, start, end): + return connect_basic_operation(rail_trans, rail_array, start, end, False, False) def connect_from_nodes(rail_trans, rail_array, start, end): - """ - Creates a new path [start,end] in rail_array, based on rail_trans. - """ - # in the worst case we will need to do a A* search, so we might as well set that up - path = a_star(rail_trans, rail_array, start, end) - if len(path) < 2: - return [] - current_dir = get_direction(path[0], path[1]) - end_pos = path[-1] - for index in range(len(path) - 1): - current_pos = path[index] - new_pos = path[index + 1] - new_dir = get_direction(current_pos, new_pos) - - new_trans = rail_array[current_pos] - if index == 0: - if new_trans == 0: - # end-point - # need to flip direction because of how end points are defined - new_trans = 0 - else: - # into existing rail - new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) - else: - # set the forward path - new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) - # set the backwards path - new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1) - rail_array[current_pos] = new_trans - - if new_pos == end_pos: - # setup end pos setup - new_trans_e = rail_array[end_pos] - if new_trans_e == 0: - # end-point - new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, mirror(new_dir), 1) - else: - # into existing rail - new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1) - rail_array[end_pos] = new_trans_e - - current_dir = new_dir - return path + return connect_basic_operation(rail_trans, rail_array, start, end, False, True) def connect_to_nodes(rail_trans, rail_array, start, end): - """ - Creates a new path [start,end] in rail_array, based on rail_trans. - """ - # in the worst case we will need to do a A* search, so we might as well set that up - path = a_star(rail_trans, rail_array, start, end) - if len(path) < 2: - return [] - current_dir = get_direction(path[0], path[1]) - end_pos = path[-1] - for index in range(len(path) - 1): - current_pos = path[index] - new_pos = path[index + 1] - new_dir = get_direction(current_pos, new_pos) - - new_trans = rail_array[current_pos] - if index == 0: - if new_trans == 0: - # end-point - # need to flip direction because of how end points are defined - new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1) - else: - # into existing rail - new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) - else: - # set the forward path - new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) - # set the backwards path - new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1) - rail_array[current_pos] = new_trans - - if new_pos == end_pos: - # setup end pos setup - new_trans_e = rail_array[end_pos] - if new_trans_e == 0: - # end-point - new_trans_e = 0 - else: - # into existing rail - new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1) - rail_array[end_pos] = new_trans_e - - current_dir = new_dir - return path + return connect_basic_operation(rail_trans, rail_array, start, end, True, False) diff --git a/tests/test_flatland_envs_sparse_rail_generator.py b/tests/test_flatland_envs_sparse_rail_generator.py index a0e2b995b35bc2d6984bf6274170a28c18cada70..a2992a7ae4559f17ae03296c3c39a4b69ee10dee 100644 --- a/tests/test_flatland_envs_sparse_rail_generator.py +++ b/tests/test_flatland_envs_sparse_rail_generator.py @@ -27,6 +27,7 @@ def test_sparse_rail_generator(): env_renderer.render_env(show=True, show_observations=True, show_predictions=False) env_renderer.gl.save_image("./sparse_generator_false.png") # TODO test assertions! + env_renderer.close_window() def test_rail_env_action_required_info(): @@ -108,7 +109,7 @@ def test_rail_env_action_required_info(): if done_always_action['__all__']: break - + env_renderer.close_window() def test_rail_env_malfunction_speed_info(): np.random.seed(0) @@ -158,7 +159,7 @@ def test_rail_env_malfunction_speed_info(): if done['__all__']: break - + env_renderer.close_window() def test_sparse_generator_with_too_man_cities_does_not_break_down(): np.random.seed(0) diff --git a/tests/test_flatland_utils_rendertools.py b/tests/test_flatland_utils_rendertools.py index 8248c675995fc5c906e82d8650a5b619e7b038f2..853b025f2ebd39949453f35e0d053e519163237c 100644 --- a/tests/test_flatland_utils_rendertools.py +++ b/tests/test_flatland_utils_rendertools.py @@ -45,14 +45,12 @@ def test_render_env(save_new_images=False): oEnv.rail.load_transition_map('env_data.tests', "test1.npy") oRT = rt.RenderTool(oEnv, gl="PILSVG") oRT.render_env(show=False) - checkFrozenImage(oRT, "basic-env.npz", resave=save_new_images) oRT = rt.RenderTool(oEnv, gl="PIL") oRT.render_env() checkFrozenImage(oRT, "basic-env-PIL.npz", resave=save_new_images) - - + def main(): if len(sys.argv) == 2 and sys.argv[1] == "save": test_render_env(save_new_images=True)