diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index 6453ccf78252b0ca043c8280ad204de151bf5cf4..40b415915305df0384b31ad30037ec14ec0e985d 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -1,6 +1,6 @@ import numpy as np -from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv +from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import sparse_rail_generator @@ -39,7 +39,7 @@ env = RailEnv(width=40, schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=20, stochastic_data=stochastic_data, # Malfunction data generator - obs_builder_object=GlobalObsForRailEnv(), + obs_builder_object=TreeObservation, remove_agents_at_target=True ) diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 7b7aacaf3c7ef2a7c3cb4bccdf48c8ede1f93e03..1edf943b4f015e3881ab0420bf69f3f21ac06456 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -174,7 +174,6 @@ class TreeObsForRailEnv(ObservationBuilder): else: self.location_has_agent_direction[(agent.position, agent.direction)] = 1 - self.location_has_agent_speed = {tuple(agent.position): agent.speed_data['speed'] for agent in self.env.agents} self.location_has_agent_malfunction = {tuple(agent.position): agent.malfunction_data['malfunction'] for agent in self.env.agents} @@ -271,9 +270,9 @@ class TreeObsForRailEnv(ObservationBuilder): if self.location_has_agent_malfunction[position] > malfunctioning_agent: malfunctioning_agent = self.location_has_agent_malfunction[position] - if (agent.position, agent.direction) in self.location_has_agent_direction: + if (position, direction) in self.location_has_agent_direction: # Cummulate the number of agents on branch with same direction - other_agent_same_direction += self.location_has_agent_direction[(agent.position, agent.direction)] + other_agent_same_direction += self.location_has_agent_direction[(position, direction)] # Check fractional speed of agents current_fractional_speed = self.location_has_agent_speed[position] @@ -284,13 +283,11 @@ class TreeObsForRailEnv(ObservationBuilder): # TODO: This does not work as expected yet other_agent_opposite_direction += self.location_has_agent[position] - \ self.location_has_agent_direction[ - (agent.position, agent.direction)] + (position, direction)] + else: # If no agent in the same direction was found all agents in that position are other direction other_agent_opposite_direction += self.location_has_agent[position] - print("went in here") - if other_agent_opposite_direction > 0: - print("Other agents", other_agent_opposite_direction) # Check number of possible transitions for agent and total number of transitions in cell (type) cell_transitions = self.env.rail.get_transitions(*position, direction)