Skip to content
Snippets Groups Projects
Commit 47be1c5f authored by maljx's avatar maljx
Browse files

updated complex generator to return start,goal,start dir

parent da35282a
No related branches found
No related tags found
No related merge requests found
......@@ -95,21 +95,10 @@ def main(render=True, delay=0.0):
random.seed(1)
np.random.seed(1)
# Example generate a rail given a manual specification,
# a map of tuples (cell_type, rotation)
# transition_probability = [0.5, # empty cell - Case 0
# 1.0, # Case 1 - straight
# 1.0, # Case 2 - simple switch
# 0.3, # Case 3 - diamond crossing
# 0.5, # Case 4 - single slip
# 0.5, # Case 5 - double slip
# 0.2, # Case 6 - symmetrical
# 0.0] # Case 7 - dead end
# Example generate a random rail
env = RailEnv(width=15, height=15,
rail_generator=complex_rail_generator(nr_start_goal=15, min_dist=5),
number_of_agents=1)
number_of_agents=5)
if render:
env_renderer = RenderTool(env, gl="QT")
......
......@@ -75,7 +75,6 @@ def validate_new_transition(rail_trans, rail_array, prev_pos, current_pos, new_p
# check if matches existing layout
new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
# new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1)
# rail_trans.print(new_trans)
else:
# set the forward path
new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
......@@ -91,20 +90,11 @@ def validate_new_transition(rail_trans, rail_array, prev_pos, current_pos, new_p
# check if matches existing layout
new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1)
# new_trans_e = rail_trans.set_transition(new_trans_e, mirror(new_dir), mirror(new_dir), 1)
# print("end:", end_pos, current_pos)
# rail_trans.print(new_trans_e)
# print("========> end trans")
# rail_trans.print(new_trans_e)
if not rail_trans.is_valid(new_trans_e):
# print("end failed", end_pos, current_pos)
return False
# else:
# print("end ok!", end_pos, current_pos)
# is transition is valid?
# print("=======> trans")
# rail_trans.print(new_trans)
return rail_trans.is_valid(new_trans)
......@@ -141,10 +131,6 @@ def a_star(rail_trans, rail_array, start, end):
open_list.pop(current_index)
closed_list.append(current_node)
# print("a*:", current_node.pos)
# for cn in closed_list:
# print("closed:", cn.pos)
# found the goal
if current_node == end_node:
path = []
......@@ -169,14 +155,8 @@ def a_star(rail_trans, rail_array, start, end):
node_pos[1] < 0:
continue
# validate positions
# debug: avoid all current rails
# if rail_array.item(node_pos) != 0:
# continue
# validate positions
if not validate_new_transition(rail_trans, rail_array, prev_pos, current_node.pos, node_pos, end_node.pos):
# print("A*: transition invalid")
continue
# create new node
......@@ -216,7 +196,7 @@ def a_star(rail_trans, rail_array, start, end):
path.append(current.pos)
current = current.parent
# return reversed path
print("partial:", start, end, path[::-1])
# print("partial:", start, end, path[::-1])
return path[::-1]
......@@ -226,9 +206,8 @@ def connect_rail(rail_trans, rail_array, start, end):
"""
# in the worst case we will need to do a A* search, so we might as well set that up
path = a_star(rail_trans, rail_array, start, end)
# print("connecting path", path)
if len(path) < 2:
return
return []
current_dir = get_direction(path[0], path[1])
end_pos = path[-1]
for index in range(len(path) - 1):
......@@ -246,7 +225,6 @@ def connect_rail(rail_trans, rail_array, start, end):
# into existing rail
new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
# new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1)
pass
else:
# set the forward path
new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
......@@ -267,6 +245,7 @@ def connect_rail(rail_trans, rail_array, start, end):
rail_array[end_pos] = new_trans_e
current_dir = new_dir
return path
def distance_on_rail(pos1, pos2):
......
......@@ -5,7 +5,8 @@ import numpy as np
from flatland.core.transitions import Grid8Transitions, RailEnvTransitions
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.env_utils import distance_on_rail, connect_rail, get_rnd_agents_pos_tgt_dir_on_rail
from flatland.envs.env_utils import distance_on_rail, connect_rail, get_direction, mirror
from flatland.envs.env_utils import get_rnd_agents_pos_tgt_dir_on_rail
def complex_rail_generator(nr_start_goal=1, min_dist=2, max_dist=99999, seed=0):
......@@ -25,7 +26,9 @@ def complex_rail_generator(nr_start_goal=1, min_dist=2, max_dist=99999, seed=0):
def generator(width, height, agents_handles, num_resets=0):
rail_trans = RailEnvTransitions()
rail_array = np.zeros(shape=(width, height), dtype=np.uint16)
grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans)
rail_array = grid_map.grid
rail_array.fill(0)
np.random.seed(seed + num_resets)
......@@ -69,8 +72,11 @@ def complex_rail_generator(nr_start_goal=1, min_dist=2, max_dist=99999, seed=0):
#
start_goal = []
for _ in range(nr_start_goal):
sanity_max = 9000
start_dir = []
nr_created = 0
created_sanity = 0
sanity_max = 9000
while nr_created < nr_start_goal and created_sanity < sanity_max:
for _ in range(sanity_max):
start = (np.random.randint(0, width), np.random.randint(0, height))
goal = (np.random.randint(0, height), np.random.randint(0, height))
......@@ -98,20 +104,26 @@ def complex_rail_generator(nr_start_goal=1, min_dist=2, max_dist=99999, seed=0):
if check_all_dist(sg_new):
break
start_goal.append([start, goal])
connect_rail(rail_trans, rail_array, start, goal)
print("Created #", len(start_goal), "pairs")
# print(start_goal)
new_path = connect_rail(rail_trans, rail_array, start, goal)
if len(new_path) >= 2:
nr_created += 1
# print(":::: path: ", new_path)
start_goal.append([start, goal])
start_dir.append(mirror(get_direction(new_path[0], new_path[1])))
else:
# after too many failures we will give up
# print("failed...")
created_sanity += 1
return_rail = GridTransitionMap(width=width, height=height, transitions=rail_trans)
return_rail.grid = rail_array
# print("Created #", len(start_goal), "pairs")
# print(start_goal)
# TODO: return agents_position, agents_direction and agents_target!
# NOTE: the initial direction must be such that the target can be reached.
# See env_utils.get_rnd_agents_pos_tgt_dir_on_rail() for hints, if required.
agents_position = [sg[0] for sg in start_goal]
agents_target = [sg[1] for sg in start_goal]
agents_direction = start_dir
return return_rail, [], [], []
return grid_map, agents_position, agents_direction, agents_target
return generator
......
......@@ -553,7 +553,7 @@ class RenderTool(object):
bDeadEnd = nbits == 1
if not bCellValid:
print("invalid:", r, c)
# print("invalid:", r, c)
self.gl.scatter(*xyCentre, color="r", s=50)
for orientation in range(4): # ori is where we're heading
......@@ -659,7 +659,7 @@ class RenderTool(object):
self.gl.endFrame()
t2 = time.time()
print(t2 - t1, "seconds")
# print(t2 - t1, "seconds")
if show:
self.gl.show(block=False)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment