diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 2887610b06d5455bd7612d3c6f0ad386e086c6be..498d98eef653faa97b1f4f9d7a17048f0a8b9b70 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -347,6 +347,14 @@ class TreeObsForRailEnv(ObservationBuilder): # Cummulate the number of agents on branch with other direction other_agent_opposite_direction += 1 + # Check number of possible transitions for agent and total number of transitions in cell (type) + cell_transitions = self.env.rail.get_transitions(*position, direction) + transition_bit = bin(self.env.rail.get_full_transitions(*position)) + total_transitions = transition_bit.count("1") + crossing_found = False + if int(transition_bit, 2) == int('1000010000100001', 2): + crossing_found = True + # Register possible future conflict if self.predictor and num_steps < self.max_prediction_depth: int_position = coordinate_to_position(self.env.width, [position]) @@ -354,7 +362,7 @@ class TreeObsForRailEnv(ObservationBuilder): pre_step = max(0, tot_dist - 1) post_step = min(self.max_prediction_depth - 1, tot_dist + 1) - # Look for opposing paths at distance num_step + # Look for conflicting paths at distance num_step if int_position in np.delete(self.predicted_pos[tot_dist], handle, 0): conflicting_agent = np.where(self.predicted_pos[tot_dist] == int_position) for ca in conflicting_agent[0]: @@ -362,7 +370,8 @@ class TreeObsForRailEnv(ObservationBuilder): potential_conflict = tot_dist if self.env.dones[ca] and tot_dist < potential_conflict: potential_conflict = tot_dist - # Look for opposing paths at distance num_step-1 + + # Look for conflicting paths at distance num_step-1 elif int_position in np.delete(self.predicted_pos[pre_step], handle, 0): conflicting_agent = np.where(self.predicted_pos[pre_step] == int_position) for ca in conflicting_agent[0]: @@ -370,7 +379,8 @@ class TreeObsForRailEnv(ObservationBuilder): potential_conflict = tot_dist if self.env.dones[ca] and tot_dist < potential_conflict: potential_conflict = tot_dist - # Look for opposing paths at distance num_step+1 + + # Look for conflicting paths at distance num_step+1 elif int_position in np.delete(self.predicted_pos[post_step], handle, 0): conflicting_agent = np.where(self.predicted_pos[post_step] == int_position) for ca in conflicting_agent[0]: @@ -398,8 +408,10 @@ class TreeObsForRailEnv(ObservationBuilder): last_is_target = True break - cell_transitions = self.env.rail.get_transitions(*position, direction) - total_transitions = bin(self.env.rail.get_full_transitions(*position)).count("1") + # Check if crossing is found --> Not an unusable switch + if crossing_found: + # Treat the crossing as a straight rail cell + total_transitions = 2 num_transitions = np.count_nonzero(cell_transitions) exploring = False