diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 0db0ab3b9538d1f84db199fe7cf4de5268a3d625..9acbeaf806258491d33297473469bb43bd0f4128 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -164,25 +164,25 @@ class TreeObsForRailEnv(ObservationBuilder): # Update local lookup table for all agents' positions # ignore other agents not in the grid (only status active and done) - self.location_has_agent = {tuple(agent.position): 1 for agent in self.env.agents if - agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE]} + # self.location_has_agent = {tuple(agent.position): 1 for agent in self.env.agents if + # agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE]} + + self.location_has_agent = {} + self.location_has_agent_direction = {} + self.location_has_agent_speed = {} + self.location_has_agent_malfunction = {} self.location_has_agent_ready_to_depart = {} - for agent in self.env.agents: - if agent.status == RailAgentStatus.READY_TO_DEPART: - self.location_has_agent_ready_to_depart[tuple(agent.initial_position)] = \ - self.location_has_agent_ready_to_depart.get(tuple(agent.initial_position), 0) + 1 - self.location_has_agent_direction = { - tuple(agent.position): agent.direction - for agent in self.env.agents if agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE] - } - self.location_has_agent_speed = { - tuple(agent.position): agent.speed_data['speed'] - for agent in self.env.agents if agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE] - } - self.location_has_agent_malfunction = { - tuple(agent.position): agent.malfunction_data['malfunction'] - for agent in self.env.agents if agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE] - } + + for _agent in self.env.agents: + if _agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE]: + self.location_has_agent[tuple(_agent.position)] = 1 + self.location_has_agent_direction[tuple(_agent.position)] = _agent.direction + self.location_has_agent_speed[tuple(_agent.position)] = _agent.speed_data['speed'] + self.location_has_agent_malfunction[tuple(_agent.position)] = _agent.malfunction_data['malfunction'] + + if _agent.status in [RailAgentStatus.READY_TO_DEPART]: + self.location_has_agent_ready_to_depart[tuple(_agent.initial_position)] = \ + self.location_has_agent_ready_to_depart.get(tuple(_agent.initial_position), 0) + 1 if handle > len(self.env.agents): print("ERROR: obs _get - handle ", handle, " len(agents)", len(self.env.agents)) diff --git a/flatland/evaluators/service.py b/flatland/evaluators/service.py index d717f73e444fa277bb89acdc968c315c4e00b589..ef69fc53468d1570fa9f56a6dd4b6e1a8be31a53 100644 --- a/flatland/evaluators/service.py +++ b/flatland/evaluators/service.py @@ -41,6 +41,10 @@ m.patch() ######################################################## PER_STEP_TIMEOUT = 10 * 60 # 5 minutes RANDOM_SEED = int(os.getenv("FLATLAND_EVALUATION_RANDOM_SEED", 1001)) +SUPPORTED_CLIENT_VERSIONS = \ + [ + flatland.__version__ + ] class FlatlandRemoteEvaluationService: @@ -294,11 +298,6 @@ class FlatlandRemoteEvaluationService: _command_response = {} _command_response['type'] = messages.FLATLAND_RL.PONG _command_response['payload'] = {} - SUPPORTED_CLIENT_VERSIONS = \ - [ - flatland.__version__, - "2.1.5" - ] if client_version not in SUPPORTED_CLIENT_VERSIONS: _command_response['type'] = messages.FLATLAND_RL.ERROR _command_response['payload']['message'] = \