diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 53630476874072c613bede5daf1ac76bdab33625..908df4322d58ec13729ce31cb2eebdb8ac0a74c2 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -263,6 +263,9 @@ class TreeObsForRailEnv(ObservationBuilder): other_agent_encountered = np.inf other_target_encountered = np.inf + other_agent_same_direction = 0 + other_agent_opposite_direction = 0 + num_steps = 1 while exploring: @@ -274,6 +277,14 @@ class TreeObsForRailEnv(ObservationBuilder): if num_steps < other_agent_encountered: other_agent_encountered = num_steps + if self.location_has_agent_direction[position] == direction: + # Cummulate the number of agents on branch with same direction + other_agent_same_direction += 1 + + if self.location_has_agent_direction[position] != direction: + # Cummulate the number of agents on branch with other direction + other_agent_opposite_direction += 1 + if position in self.location_has_target: if num_steps < other_target_encountered: other_target_encountered = num_steps @@ -366,10 +377,7 @@ class TreeObsForRailEnv(ObservationBuilder): other_agent_opposite_direction ] """ - other_agent_same_direction = \ - 1 if position in self.location_has_agent and self.location_has_agent_direction[position] == direction else 0 - other_agent_opposite_direction = \ - 1 if position in self.location_has_agent and self.location_has_agent_direction[position] != direction else 0 + if last_isTarget: observation = [0,