diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 8590922131d57cd76ba4a638b10e3e7e169a984a..83584a302f0e4064cef187171db291b637126242 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -295,21 +295,16 @@ class TreeObsForRailEnv(ObservationBuilder): if self.location_has_agent_direction[position] == direction: # Cummulate the number of agents on branch with same direction - other_agent_same_direction += self.location_has_agent_direction.get((position, direction), 0) + other_agent_same_direction += 1 # Check fractional speed of agents current_fractional_speed = self.location_has_agent_speed[position] if current_fractional_speed < min_fractional_speed: min_fractional_speed = current_fractional_speed - # Other direction agents - # TODO: Test that this behavior is as expected - other_agent_opposite_direction += \ - self.location_has_agent[position] - self.location_has_agent_direction.get((position, direction), - 0) - else: # If no agent in the same direction was found all agents in that position are other direction + # Attention this counts to many agents as a few might be going off on a switch. other_agent_opposite_direction += self.location_has_agent[position] # Check number of possible transitions for agent and total number of transitions in cell (type)