From c02330107d29fe10a7f72b1fde19c7b65e248b5c Mon Sep 17 00:00:00 2001 From: u214892 <u214892@sbb.ch> Date: Thu, 29 Aug 2019 14:33:09 +0200 Subject: [PATCH] bugfix #141: check_path_exists and tests --- flatland/core/transition_map.py | 40 +++++++++++++++++++++++++--- flatland/envs/schedule_generators.py | 16 +++++------ 2 files changed, 44 insertions(+), 12 deletions(-) diff --git a/flatland/core/transition_map.py b/flatland/core/transition_map.py index 2ae34607..bb954998 100644 --- a/flatland/core/transition_map.py +++ b/flatland/core/transition_map.py @@ -301,9 +301,16 @@ class GridTransitionMap(TransitionMap): def is_dead_end(self, rcPos): """ - Check if the cell is a dead-end - :param rcPos: tuple(row, column) with grid coordinate - :return: False : if not a dead-end else True + Check if the cell is a dead-end. + + Parameters + ---------- + rcPos: Tuple[int,int] + tuple(row, column) with grid coordinate + Returns + ------- + boolean + True if and only if the cell is a dead-end. """ nbits = 0 tmp = self.get_full_transitions(rcPos[0], rcPos[1]) @@ -312,6 +319,33 @@ class GridTransitionMap(TransitionMap): tmp = tmp >> 1 return nbits == 1 + def is_simple_turn(self, rcPos): + """ + Check if the cell is a left/right simple turn + + Parameters + ---------- + rcPos: Tuple[int,int] + tuple(row, column) with grid coordinate + Returns + ------- + boolean + True if and only if the cell is a left/right simple turn. + """ + tmp = self.get_full_transitions(rcPos[0], rcPos[1]) + + def is_simple_turn(trans): + all_simple_turns = set() + for trans in [int('0100000000000010', 2), # Case 1b (8) - simple turn right + int('0001001000000000', 2) # Case 1c (9) - simple turn left]: + ]: + for _ in range(3): + trans = self.transitions.rotate_transition(trans, rotation=90) + all_simple_turns.add(trans) + return trans in all_simple_turns + + return is_simple_turn(tmp) + def check_path_exists(self, start, direction, end): # print("_path_exists({},{},{}".format(start, direction, end)) # BFS - Check if a path exists between the 2 nodes diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py index d1539aa7..a8a5be77 100644 --- a/flatland/envs/schedule_generators.py +++ b/flatland/envs/schedule_generators.py @@ -152,8 +152,12 @@ def random_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> re_generate = True cnt = 0 - while re_generate and cnt < 100: + while re_generate: cnt += 1 + if cnt >= 1: + print("re_generate cnt={}".format(cnt)) + if cnt > 1000: + raise Exception("After 1000 re_generates still not success, giving up.") # update position for i in range(num_agents): if update_agents[i] == 1: @@ -171,8 +175,6 @@ def random_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> re_generate = False for i in range(num_agents): valid_movements = [] - if rail.is_dead_end(agents_position[i]): - print(" dead_end", agents_position[i]) for direction in range(4): position = agents_position[i] moves = rail.get_transitions(position[0], position[1], direction) @@ -183,23 +185,19 @@ def random_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> valid_starting_directions = [] for m in valid_movements: new_position = get_new_position(agents_position[i], m[1]) - if m[0] not in valid_starting_directions and rail.check_path_exists(new_position, m[0], + if m[0] not in valid_starting_directions and rail.check_path_exists(new_position, m[1], agents_target[i]): valid_starting_directions.append(m[0]) if len(valid_starting_directions) == 0: update_agents[i] = 1 - print("reset position for agents:", i, agents_position[i], agents_target[i]) - print(" dead_end", rail.is_dead_end(agents_position[i])) + warnings.warn("reset position for agents:", i, agents_position[i], agents_target[i]) re_generate = True break else: agents_direction[i] = valid_starting_directions[ np.random.choice(len(valid_starting_directions), 1)[0]] - if re_generate: - print("re_generate") - agents_speed = speed_initialization_helper(num_agents, speed_ratio_map) return agents_position, agents_direction, agents_target, agents_speed -- GitLab