Skip to content
Snippets Groups Projects
Commit de3e96a2 authored by Erik Nygren's avatar Erik Nygren
Browse files

initial changes to crossings. Now not detected anymore as unusable switches

parent bbfdae22
No related branches found
No related tags found
No related merge requests found
...@@ -347,6 +347,14 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -347,6 +347,14 @@ class TreeObsForRailEnv(ObservationBuilder):
# Cummulate the number of agents on branch with other direction # Cummulate the number of agents on branch with other direction
other_agent_opposite_direction += 1 other_agent_opposite_direction += 1
# Check number of possible transitions for agent and total number of transitions in cell (type)
cell_transitions = self.env.rail.get_transitions(*position, direction)
transition_bit = bin(self.env.rail.get_full_transitions(*position))
total_transitions = transition_bit.count("1")
crossing_found = False
if int(transition_bit, 2) == int('1000010000100001', 2):
crossing_found = True
# Register possible future conflict # Register possible future conflict
if self.predictor and num_steps < self.max_prediction_depth: if self.predictor and num_steps < self.max_prediction_depth:
int_position = coordinate_to_position(self.env.width, [position]) int_position = coordinate_to_position(self.env.width, [position])
...@@ -354,7 +362,7 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -354,7 +362,7 @@ class TreeObsForRailEnv(ObservationBuilder):
pre_step = max(0, tot_dist - 1) pre_step = max(0, tot_dist - 1)
post_step = min(self.max_prediction_depth - 1, tot_dist + 1) post_step = min(self.max_prediction_depth - 1, tot_dist + 1)
# Look for opposing paths at distance num_step # Look for conflicting paths at distance num_step
if int_position in np.delete(self.predicted_pos[tot_dist], handle, 0): if int_position in np.delete(self.predicted_pos[tot_dist], handle, 0):
conflicting_agent = np.where(self.predicted_pos[tot_dist] == int_position) conflicting_agent = np.where(self.predicted_pos[tot_dist] == int_position)
for ca in conflicting_agent[0]: for ca in conflicting_agent[0]:
...@@ -362,7 +370,8 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -362,7 +370,8 @@ class TreeObsForRailEnv(ObservationBuilder):
potential_conflict = tot_dist potential_conflict = tot_dist
if self.env.dones[ca] and tot_dist < potential_conflict: if self.env.dones[ca] and tot_dist < potential_conflict:
potential_conflict = tot_dist potential_conflict = tot_dist
# Look for opposing paths at distance num_step-1
# Look for conflicting paths at distance num_step-1
elif int_position in np.delete(self.predicted_pos[pre_step], handle, 0): elif int_position in np.delete(self.predicted_pos[pre_step], handle, 0):
conflicting_agent = np.where(self.predicted_pos[pre_step] == int_position) conflicting_agent = np.where(self.predicted_pos[pre_step] == int_position)
for ca in conflicting_agent[0]: for ca in conflicting_agent[0]:
...@@ -370,7 +379,8 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -370,7 +379,8 @@ class TreeObsForRailEnv(ObservationBuilder):
potential_conflict = tot_dist potential_conflict = tot_dist
if self.env.dones[ca] and tot_dist < potential_conflict: if self.env.dones[ca] and tot_dist < potential_conflict:
potential_conflict = tot_dist potential_conflict = tot_dist
# Look for opposing paths at distance num_step+1
# Look for conflicting paths at distance num_step+1
elif int_position in np.delete(self.predicted_pos[post_step], handle, 0): elif int_position in np.delete(self.predicted_pos[post_step], handle, 0):
conflicting_agent = np.where(self.predicted_pos[post_step] == int_position) conflicting_agent = np.where(self.predicted_pos[post_step] == int_position)
for ca in conflicting_agent[0]: for ca in conflicting_agent[0]:
...@@ -398,12 +408,8 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -398,12 +408,8 @@ class TreeObsForRailEnv(ObservationBuilder):
last_is_target = True last_is_target = True
break break
cell_transitions = self.env.rail.get_transitions(*position, direction)
transition_bit = bin(self.env.rail.get_full_transitions(*position))
total_transitions = transition_bit.count("1")
# Check if crossing is found --> Not an unusable switch # Check if crossing is found --> Not an unusable switch
if int(transition_bit, 2) == int('1000010000100001', 2): if crossing_found:
# Treat the crossing as a straight rail cell # Treat the crossing as a straight rail cell
total_transitions = 2 total_transitions = 2
num_transitions = np.count_nonzero(cell_transitions) num_transitions = np.count_nonzero(cell_transitions)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment