Skip to content
Snippets Groups Projects
Commit 34df9cab authored by Egli Adrian (IT-SCI-API-PFI)'s avatar Egli Adrian (IT-SCI-API-PFI)
Browse files

refactoring and clean up

parent 7dceb333
No related branches found
No related tags found
No related merge requests found
......@@ -9,7 +9,8 @@ from flatland.core.grid.grid4_astar import a_star
from flatland.core.grid.grid4_utils import get_direction, mirror
def connect_rail(rail_trans, rail_array, start, end):
def connect_basic_operation(rail_trans, rail_array, start, end,
flip_start_node_trans=False, flip_end_node_trans=False):
"""
Creates a new path [start,end] in rail_array, based on rail_trans.
"""
......@@ -28,8 +29,11 @@ def connect_rail(rail_trans, rail_array, start, end):
if index == 0:
if new_trans == 0:
# end-point
# need to flip direction because of how end points are defined
new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1)
if flip_start_node_trans:
# need to flip direction because of how end points are defined
new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1)
else:
new_trans = 0
else:
# into existing rail
new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
......@@ -45,7 +49,10 @@ def connect_rail(rail_trans, rail_array, start, end):
new_trans_e = rail_array[end_pos]
if new_trans_e == 0:
# end-point
new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, mirror(new_dir), 1)
if flip_end_node_trans:
new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, mirror(new_dir), 1)
else:
new_trans_e = 0
else:
# into existing rail
new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1)
......@@ -55,141 +62,17 @@ def connect_rail(rail_trans, rail_array, start, end):
return path
def connect_nodes(rail_trans, rail_array, start, end):
"""
Creates a new path [start,end] in rail_array, based on rail_trans.
"""
# in the worst case we will need to do a A* search, so we might as well set that up
path = a_star(rail_trans, rail_array, start, end)
if len(path) < 2:
return []
current_dir = get_direction(path[0], path[1])
end_pos = path[-1]
for index in range(len(path) - 1):
current_pos = path[index]
new_pos = path[index + 1]
new_dir = get_direction(current_pos, new_pos)
new_trans = rail_array[current_pos]
if index == 0:
if new_trans == 0:
# end-point
# don't set any transition at node yet
new_trans = 0
else:
# into existing rail
new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
else:
# set the forward path
new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
# set the backwards path
new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1)
rail_array[current_pos] = new_trans
if new_pos == end_pos:
# setup end pos setup
new_trans_e = rail_array[end_pos]
if new_trans_e == 0:
# end-point
# don't set any transition at node yet
def connect_rail(rail_trans, rail_array, start, end):
return connect_basic_operation(rail_trans, rail_array, start, end, True, True)
new_trans_e = 0
else:
# into existing rail
new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1)
rail_array[end_pos] = new_trans_e
current_dir = new_dir
return path
def connect_nodes(rail_trans, rail_array, start, end):
return connect_basic_operation(rail_trans, rail_array, start, end, False, False)
def connect_from_nodes(rail_trans, rail_array, start, end):
"""
Creates a new path [start,end] in rail_array, based on rail_trans.
"""
# in the worst case we will need to do a A* search, so we might as well set that up
path = a_star(rail_trans, rail_array, start, end)
if len(path) < 2:
return []
current_dir = get_direction(path[0], path[1])
end_pos = path[-1]
for index in range(len(path) - 1):
current_pos = path[index]
new_pos = path[index + 1]
new_dir = get_direction(current_pos, new_pos)
new_trans = rail_array[current_pos]
if index == 0:
if new_trans == 0:
# end-point
# need to flip direction because of how end points are defined
new_trans = 0
else:
# into existing rail
new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
else:
# set the forward path
new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
# set the backwards path
new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1)
rail_array[current_pos] = new_trans
if new_pos == end_pos:
# setup end pos setup
new_trans_e = rail_array[end_pos]
if new_trans_e == 0:
# end-point
new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, mirror(new_dir), 1)
else:
# into existing rail
new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1)
rail_array[end_pos] = new_trans_e
current_dir = new_dir
return path
return connect_basic_operation(rail_trans, rail_array, start, end, False, True)
def connect_to_nodes(rail_trans, rail_array, start, end):
"""
Creates a new path [start,end] in rail_array, based on rail_trans.
"""
# in the worst case we will need to do a A* search, so we might as well set that up
path = a_star(rail_trans, rail_array, start, end)
if len(path) < 2:
return []
current_dir = get_direction(path[0], path[1])
end_pos = path[-1]
for index in range(len(path) - 1):
current_pos = path[index]
new_pos = path[index + 1]
new_dir = get_direction(current_pos, new_pos)
new_trans = rail_array[current_pos]
if index == 0:
if new_trans == 0:
# end-point
# need to flip direction because of how end points are defined
new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1)
else:
# into existing rail
new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
else:
# set the forward path
new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
# set the backwards path
new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1)
rail_array[current_pos] = new_trans
if new_pos == end_pos:
# setup end pos setup
new_trans_e = rail_array[end_pos]
if new_trans_e == 0:
# end-point
new_trans_e = 0
else:
# into existing rail
new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1)
rail_array[end_pos] = new_trans_e
current_dir = new_dir
return path
return connect_basic_operation(rail_trans, rail_array, start, end, True, False)
......@@ -27,6 +27,7 @@ def test_sparse_rail_generator():
env_renderer.render_env(show=True, show_observations=True, show_predictions=False)
env_renderer.gl.save_image("./sparse_generator_false.png")
# TODO test assertions!
env_renderer.close_window()
def test_rail_env_action_required_info():
......@@ -108,7 +109,7 @@ def test_rail_env_action_required_info():
if done_always_action['__all__']:
break
env_renderer.close_window()
def test_rail_env_malfunction_speed_info():
np.random.seed(0)
......@@ -158,7 +159,7 @@ def test_rail_env_malfunction_speed_info():
if done['__all__']:
break
env_renderer.close_window()
def test_sparse_generator_with_too_man_cities_does_not_break_down():
np.random.seed(0)
......
......@@ -45,14 +45,12 @@ def test_render_env(save_new_images=False):
oEnv.rail.load_transition_map('env_data.tests', "test1.npy")
oRT = rt.RenderTool(oEnv, gl="PILSVG")
oRT.render_env(show=False)
checkFrozenImage(oRT, "basic-env.npz", resave=save_new_images)
oRT = rt.RenderTool(oEnv, gl="PIL")
oRT.render_env()
checkFrozenImage(oRT, "basic-env-PIL.npz", resave=save_new_images)
def main():
if len(sys.argv) == 2 and sys.argv[1] == "save":
test_render_env(save_new_images=True)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment