diff --git a/examples/play_model.py b/examples/play_model.py index a61954ea57bfa63485f188a9b6964649e6515d56..cba087bccfc60ccb10ec4aef1d13c806369303da 100644 --- a/examples/play_model.py +++ b/examples/play_model.py @@ -95,21 +95,10 @@ def main(render=True, delay=0.0): random.seed(1) np.random.seed(1) - # Example generate a rail given a manual specification, - # a map of tuples (cell_type, rotation) - # transition_probability = [0.5, # empty cell - Case 0 - # 1.0, # Case 1 - straight - # 1.0, # Case 2 - simple switch - # 0.3, # Case 3 - diamond crossing - # 0.5, # Case 4 - single slip - # 0.5, # Case 5 - double slip - # 0.2, # Case 6 - symmetrical - # 0.0] # Case 7 - dead end - # Example generate a random rail env = RailEnv(width=15, height=15, rail_generator=complex_rail_generator(nr_start_goal=15, min_dist=5), - number_of_agents=1) + number_of_agents=5) if render: env_renderer = RenderTool(env, gl="QT") diff --git a/flatland/envs/env_utils.py b/flatland/envs/env_utils.py index ac29408051efd4bc85f8120b02f7dfd59b14dee9..1ad6c6de0fa5bc3db8daf0f40b74be5433e69fba 100644 --- a/flatland/envs/env_utils.py +++ b/flatland/envs/env_utils.py @@ -75,7 +75,6 @@ def validate_new_transition(rail_trans, rail_array, prev_pos, current_pos, new_p # check if matches existing layout new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) # new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1) - # rail_trans.print(new_trans) else: # set the forward path new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) @@ -91,20 +90,11 @@ def validate_new_transition(rail_trans, rail_array, prev_pos, current_pos, new_p # check if matches existing layout new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1) # new_trans_e = rail_trans.set_transition(new_trans_e, mirror(new_dir), mirror(new_dir), 1) - # print("end:", end_pos, current_pos) - # rail_trans.print(new_trans_e) - # print("========> end trans") - # rail_trans.print(new_trans_e) if not rail_trans.is_valid(new_trans_e): - # print("end failed", end_pos, current_pos) return False - # else: - # print("end ok!", end_pos, current_pos) # is transition is valid? - # print("=======> trans") - # rail_trans.print(new_trans) return rail_trans.is_valid(new_trans) @@ -141,10 +131,6 @@ def a_star(rail_trans, rail_array, start, end): open_list.pop(current_index) closed_list.append(current_node) - # print("a*:", current_node.pos) - # for cn in closed_list: - # print("closed:", cn.pos) - # found the goal if current_node == end_node: path = [] @@ -169,14 +155,8 @@ def a_star(rail_trans, rail_array, start, end): node_pos[1] < 0: continue - # validate positions - # debug: avoid all current rails - # if rail_array.item(node_pos) != 0: - # continue - # validate positions if not validate_new_transition(rail_trans, rail_array, prev_pos, current_node.pos, node_pos, end_node.pos): - # print("A*: transition invalid") continue # create new node @@ -216,7 +196,7 @@ def a_star(rail_trans, rail_array, start, end): path.append(current.pos) current = current.parent # return reversed path - print("partial:", start, end, path[::-1]) + # print("partial:", start, end, path[::-1]) return path[::-1] @@ -226,9 +206,8 @@ def connect_rail(rail_trans, rail_array, start, end): """ # in the worst case we will need to do a A* search, so we might as well set that up path = a_star(rail_trans, rail_array, start, end) - # print("connecting path", path) if len(path) < 2: - return + return [] current_dir = get_direction(path[0], path[1]) end_pos = path[-1] for index in range(len(path) - 1): @@ -246,7 +225,6 @@ def connect_rail(rail_trans, rail_array, start, end): # into existing rail new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) # new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1) - pass else: # set the forward path new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) @@ -267,6 +245,7 @@ def connect_rail(rail_trans, rail_array, start, end): rail_array[end_pos] = new_trans_e current_dir = new_dir + return path def distance_on_rail(pos1, pos2): diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py index de75b70b62145a911d64effa3dbad162c928289c..29f8c6f4151a35b55f99f8478383195384b7a9e7 100644 --- a/flatland/envs/generators.py +++ b/flatland/envs/generators.py @@ -5,7 +5,8 @@ import numpy as np from flatland.core.transitions import Grid8Transitions, RailEnvTransitions from flatland.core.transition_map import GridTransitionMap -from flatland.envs.env_utils import distance_on_rail, connect_rail, get_rnd_agents_pos_tgt_dir_on_rail +from flatland.envs.env_utils import distance_on_rail, connect_rail, get_direction, mirror +from flatland.envs.env_utils import get_rnd_agents_pos_tgt_dir_on_rail def complex_rail_generator(nr_start_goal=1, min_dist=2, max_dist=99999, seed=0): @@ -25,7 +26,9 @@ def complex_rail_generator(nr_start_goal=1, min_dist=2, max_dist=99999, seed=0): def generator(width, height, agents_handles, num_resets=0): rail_trans = RailEnvTransitions() - rail_array = np.zeros(shape=(width, height), dtype=np.uint16) + grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans) + rail_array = grid_map.grid + rail_array.fill(0) np.random.seed(seed + num_resets) @@ -69,8 +72,11 @@ def complex_rail_generator(nr_start_goal=1, min_dist=2, max_dist=99999, seed=0): # start_goal = [] - for _ in range(nr_start_goal): - sanity_max = 9000 + start_dir = [] + nr_created = 0 + created_sanity = 0 + sanity_max = 9000 + while nr_created < nr_start_goal and created_sanity < sanity_max: for _ in range(sanity_max): start = (np.random.randint(0, width), np.random.randint(0, height)) goal = (np.random.randint(0, height), np.random.randint(0, height)) @@ -98,20 +104,26 @@ def complex_rail_generator(nr_start_goal=1, min_dist=2, max_dist=99999, seed=0): if check_all_dist(sg_new): break - start_goal.append([start, goal]) - connect_rail(rail_trans, rail_array, start, goal) - print("Created #", len(start_goal), "pairs") - # print(start_goal) + new_path = connect_rail(rail_trans, rail_array, start, goal) + if len(new_path) >= 2: + nr_created += 1 + # print(":::: path: ", new_path) + start_goal.append([start, goal]) + start_dir.append(mirror(get_direction(new_path[0], new_path[1]))) + else: + # after too many failures we will give up + # print("failed...") + created_sanity += 1 - return_rail = GridTransitionMap(width=width, height=height, transitions=rail_trans) - return_rail.grid = rail_array + # print("Created #", len(start_goal), "pairs") + # print(start_goal) - # TODO: return agents_position, agents_direction and agents_target! - # NOTE: the initial direction must be such that the target can be reached. - # See env_utils.get_rnd_agents_pos_tgt_dir_on_rail() for hints, if required. + agents_position = [sg[0] for sg in start_goal] + agents_target = [sg[1] for sg in start_goal] + agents_direction = start_dir - return return_rail, [], [], [] + return grid_map, agents_position, agents_direction, agents_target return generator diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py index c54b2fed0a1a9420b33b33183da7f471ace6c2ec..aba58fb2d86163242890eee8d4dc21c57405eac2 100644 --- a/flatland/utils/rendertools.py +++ b/flatland/utils/rendertools.py @@ -553,7 +553,7 @@ class RenderTool(object): bDeadEnd = nbits == 1 if not bCellValid: - print("invalid:", r, c) + # print("invalid:", r, c) self.gl.scatter(*xyCentre, color="r", s=50) for orientation in range(4): # ori is where we're heading @@ -659,7 +659,7 @@ class RenderTool(object): self.gl.endFrame() t2 = time.time() - print(t2 - t1, "seconds") + # print(t2 - t1, "seconds") if show: self.gl.show(block=False)