diff --git a/flatland/core/transition_map.py b/flatland/core/transition_map.py index bb954998688772a7ce69e5228cff3e16d037f2af..232d6fdab02c57da95bf04c631e4905986c71327 100644 --- a/flatland/core/transition_map.py +++ b/flatland/core/transition_map.py @@ -8,6 +8,7 @@ from numpy import array from flatland.core.grid.grid4 import Grid4Transitions from flatland.core.grid.grid4_utils import get_new_position +from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.core.transitions import Transitions @@ -482,6 +483,14 @@ class GridTransitionMap(TransitionMap): grcPos = array(rcPos) grcMax = self.grid.shape + # Transition elements + transitions = RailEnvTransitions() + cells = transitions.transition_list + simple_switch_east_south = transitions.rotate_transition(cells[10], 90) + simple_switch_west_south = transitions.rotate_transition(cells[2], 270) + symmetrical = cells[6] + double_slip = cells[5] + three_way_transitions = [simple_switch_east_south, simple_switch_west_south, symmetrical] # loop over available outbound directions (indices) for rcPos self.set_transitions(rcPos, 0) @@ -517,25 +526,18 @@ class GridTransitionMap(TransitionMap): self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[0])), connect_directions[1], 1) self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[1])), connect_directions[0], 1) - # Find feasible connection fro three entries + # Find feasible connection for three entries if number_of_incoming == 3: + transition = np.random.choice(three_way_transitions, 1) hole = np.argwhere(incoming_connections < 1)[0][0] - connect_directions = [(hole + 1) % 4, (hole + 2) % 4, (hole + 3) % 4] - self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[0])), connect_directions[1], 1) - self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[0])), connect_directions[2], 1) - self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[1])), connect_directions[0], 1) - self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[2])), connect_directions[0], 1) - # Make a cross + transition = transitions.rotate_transition(transition, int(hole * 90)) + self.set_transitions((rcPos[0], rcPos[1]), transition) + + # Make a double slip switch if number_of_incoming == 4: - connect_directions = np.arange(4) - self.set_transition((rcPos[0], rcPos[1], connect_directions[0]), connect_directions[0], 1) - self.set_transition((rcPos[0], rcPos[1], connect_directions[0]), connect_directions[1], 1) - self.set_transition((rcPos[0], rcPos[1], connect_directions[1]), connect_directions[0], 1) - self.set_transition((rcPos[0], rcPos[1], connect_directions[1]), connect_directions[1], 1) - self.set_transition((rcPos[0], rcPos[1], connect_directions[2]), connect_directions[2], 1) - self.set_transition((rcPos[0], rcPos[1], connect_directions[2]), connect_directions[3], 1) - self.set_transition((rcPos[0], rcPos[1], connect_directions[3]), connect_directions[2], 1) - self.set_transition((rcPos[0], rcPos[1], connect_directions[3]), connect_directions[3], 1) + rotation = np.random.randint(2) + transition = transitions.rotate_transition(double_slip, int(rotation * 90)) + self.set_transitions((rcPos[0], rcPos[1]), transition) return True