From 5d186dc59797823578c4332f9d930983179ca8e6 Mon Sep 17 00:00:00 2001 From: MLErik <baerenjesus@gmail.com> Date: Wed, 2 Oct 2019 20:14:30 -0400 Subject: [PATCH] still buggy with treeeobservation --- examples/flatland_2_0_example.py | 10 +++++----- flatland/envs/observations.py | 31 +++++++++++++++++++++++-------- 2 files changed, 28 insertions(+), 13 deletions(-) diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index 32b9f611..40b41591 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -28,13 +28,13 @@ speed_ration_map = {1.: 0.25, # Fast passenger train 1. / 3.: 0.25, # Slow commuter train 1. / 4.: 0.25} # Slow freight train -env = RailEnv(width=100, - height=20, - rail_generator=sparse_rail_generator(num_cities=2, # Number of cities in map (where train stations are) +env = RailEnv(width=40, + height=40, + rail_generator=sparse_rail_generator(num_cities=8, # Number of cities in map (where train stations are) seed=1, # Random seed - grid_mode=True, + grid_mode=False, max_inter_city_rails=2, - max_tracks_in_city=8, + max_tracks_in_city=4, ), schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=20, diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index d97c476f..7b7aacaf 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -161,13 +161,20 @@ class TreeObsForRailEnv(ObservationBuilder): # Update local lookup table for all agents' positions self.location_has_agent = dict() + self.location_has_agent_direction = dict() for agent in self.env.agents: if tuple(agent.position) in self.location_has_agent: self.location_has_agent[tuple(agent.position)] = self.location_has_agent[tuple(agent.position)] + 1 else: self.location_has_agent[tuple(agent.position)] = 1 - # TODO: Update this to handle number of agents at same location - self.location_has_agent_direction = {tuple(agent.position): agent.direction for agent in self.env.agents} + + if (agent.position, agent.direction) in self.location_has_agent_direction: + self.location_has_agent_direction[(agent.position, agent.direction)] = \ + self.location_has_agent_direction[(agent.position, agent.direction)] + 1 + else: + self.location_has_agent_direction[(agent.position, agent.direction)] = 1 + + self.location_has_agent_speed = {tuple(agent.position): agent.speed_data['speed'] for agent in self.env.agents} self.location_has_agent_malfunction = {tuple(agent.position): agent.malfunction_data['malfunction'] for agent in self.env.agents} @@ -264,20 +271,28 @@ class TreeObsForRailEnv(ObservationBuilder): if self.location_has_agent_malfunction[position] > malfunctioning_agent: malfunctioning_agent = self.location_has_agent_malfunction[position] - if self.location_has_agent_direction[position] == direction: + if (agent.position, agent.direction) in self.location_has_agent_direction: # Cummulate the number of agents on branch with same direction - other_agent_same_direction += 1 + other_agent_same_direction += self.location_has_agent_direction[(agent.position, agent.direction)] # Check fractional speed of agents current_fractional_speed = self.location_has_agent_speed[position] if current_fractional_speed < min_fractional_speed: min_fractional_speed = current_fractional_speed - if self.location_has_agent_direction[position] != direction: - # Cummulate the number of agents on branch with other direction - other_agent_opposite_direction += 1 + # Other direction agents + # TODO: This does not work as expected yet + other_agent_opposite_direction += self.location_has_agent[position] - \ + self.location_has_agent_direction[ + (agent.position, agent.direction)] + else: + # If no agent in the same direction was found all agents in that position are other direction + other_agent_opposite_direction += self.location_has_agent[position] + print("went in here") + if other_agent_opposite_direction > 0: + print("Other agents", other_agent_opposite_direction) - # Check number of possible transitions for agent and total number of transitions in cell (type) + # Check number of possible transitions for agent and total number of transitions in cell (type) cell_transitions = self.env.rail.get_transitions(*position, direction) transition_bit = bin(self.env.rail.get_full_transitions(*position)) total_transitions = transition_bit.count("1") -- GitLab