From 2df4a43402e5f4691fb777ab7306d6d4091d85be Mon Sep 17 00:00:00 2001 From: MLErik <baerenjesus@gmail.com> Date: Wed, 2 Oct 2019 22:48:41 -0400 Subject: [PATCH] fixed bug in tree observation. --- examples/flatland_2_0_example.py | 4 ++-- flatland/envs/observations.py | 11 ++++------- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index 6453ccf7..40b41591 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -1,6 +1,6 @@ import numpy as np -from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv +from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import sparse_rail_generator @@ -39,7 +39,7 @@ env = RailEnv(width=40, schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=20, stochastic_data=stochastic_data, # Malfunction data generator - obs_builder_object=GlobalObsForRailEnv(), + obs_builder_object=TreeObservation, remove_agents_at_target=True ) diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 7b7aacaf..1edf943b 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -174,7 +174,6 @@ class TreeObsForRailEnv(ObservationBuilder): else: self.location_has_agent_direction[(agent.position, agent.direction)] = 1 - self.location_has_agent_speed = {tuple(agent.position): agent.speed_data['speed'] for agent in self.env.agents} self.location_has_agent_malfunction = {tuple(agent.position): agent.malfunction_data['malfunction'] for agent in self.env.agents} @@ -271,9 +270,9 @@ class TreeObsForRailEnv(ObservationBuilder): if self.location_has_agent_malfunction[position] > malfunctioning_agent: malfunctioning_agent = self.location_has_agent_malfunction[position] - if (agent.position, agent.direction) in self.location_has_agent_direction: + if (position, direction) in self.location_has_agent_direction: # Cummulate the number of agents on branch with same direction - other_agent_same_direction += self.location_has_agent_direction[(agent.position, agent.direction)] + other_agent_same_direction += self.location_has_agent_direction[(position, direction)] # Check fractional speed of agents current_fractional_speed = self.location_has_agent_speed[position] @@ -284,13 +283,11 @@ class TreeObsForRailEnv(ObservationBuilder): # TODO: This does not work as expected yet other_agent_opposite_direction += self.location_has_agent[position] - \ self.location_has_agent_direction[ - (agent.position, agent.direction)] + (position, direction)] + else: # If no agent in the same direction was found all agents in that position are other direction other_agent_opposite_direction += self.location_has_agent[position] - print("went in here") - if other_agent_opposite_direction > 0: - print("Other agents", other_agent_opposite_direction) # Check number of possible transitions for agent and total number of transitions in cell (type) cell_transitions = self.env.rail.get_transitions(*position, direction) -- GitLab