diff --git a/flatland/core/transition_map.py b/flatland/core/transition_map.py index 2ae3460716eeabc53bc1833efb34d5d3f71948cb..bb954998688772a7ce69e5228cff3e16d037f2af 100644 --- a/flatland/core/transition_map.py +++ b/flatland/core/transition_map.py @@ -301,9 +301,16 @@ class GridTransitionMap(TransitionMap): def is_dead_end(self, rcPos): """ - Check if the cell is a dead-end - :param rcPos: tuple(row, column) with grid coordinate - :return: False : if not a dead-end else True + Check if the cell is a dead-end. + + Parameters + ---------- + rcPos: Tuple[int,int] + tuple(row, column) with grid coordinate + Returns + ------- + boolean + True if and only if the cell is a dead-end. """ nbits = 0 tmp = self.get_full_transitions(rcPos[0], rcPos[1]) @@ -312,6 +319,33 @@ class GridTransitionMap(TransitionMap): tmp = tmp >> 1 return nbits == 1 + def is_simple_turn(self, rcPos): + """ + Check if the cell is a left/right simple turn + + Parameters + ---------- + rcPos: Tuple[int,int] + tuple(row, column) with grid coordinate + Returns + ------- + boolean + True if and only if the cell is a left/right simple turn. + """ + tmp = self.get_full_transitions(rcPos[0], rcPos[1]) + + def is_simple_turn(trans): + all_simple_turns = set() + for trans in [int('0100000000000010', 2), # Case 1b (8) - simple turn right + int('0001001000000000', 2) # Case 1c (9) - simple turn left]: + ]: + for _ in range(3): + trans = self.transitions.rotate_transition(trans, rotation=90) + all_simple_turns.add(trans) + return trans in all_simple_turns + + return is_simple_turn(tmp) + def check_path_exists(self, start, direction, end): # print("_path_exists({},{},{}".format(start, direction, end)) # BFS - Check if a path exists between the 2 nodes diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py index d1539aa7a2283ee908acfa4fa37b9659c91ed224..a8a5be77b75000900be67d2fc64dd22b46b68fae 100644 --- a/flatland/envs/schedule_generators.py +++ b/flatland/envs/schedule_generators.py @@ -152,8 +152,12 @@ def random_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> re_generate = True cnt = 0 - while re_generate and cnt < 100: + while re_generate: cnt += 1 + if cnt >= 1: + print("re_generate cnt={}".format(cnt)) + if cnt > 1000: + raise Exception("After 1000 re_generates still not success, giving up.") # update position for i in range(num_agents): if update_agents[i] == 1: @@ -171,8 +175,6 @@ def random_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> re_generate = False for i in range(num_agents): valid_movements = [] - if rail.is_dead_end(agents_position[i]): - print(" dead_end", agents_position[i]) for direction in range(4): position = agents_position[i] moves = rail.get_transitions(position[0], position[1], direction) @@ -183,23 +185,19 @@ def random_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> valid_starting_directions = [] for m in valid_movements: new_position = get_new_position(agents_position[i], m[1]) - if m[0] not in valid_starting_directions and rail.check_path_exists(new_position, m[0], + if m[0] not in valid_starting_directions and rail.check_path_exists(new_position, m[1], agents_target[i]): valid_starting_directions.append(m[0]) if len(valid_starting_directions) == 0: update_agents[i] = 1 - print("reset position for agents:", i, agents_position[i], agents_target[i]) - print(" dead_end", rail.is_dead_end(agents_position[i])) + warnings.warn("reset position for agents:", i, agents_position[i], agents_target[i]) re_generate = True break else: agents_direction[i] = valid_starting_directions[ np.random.choice(len(valid_starting_directions), 1)[0]] - if re_generate: - print("re_generate") - agents_speed = speed_initialization_helper(num_agents, speed_ratio_map) return agents_position, agents_direction, agents_target, agents_speed