From 2be1a6c4db0587b0bfb191fe9f94da4a31825f86 Mon Sep 17 00:00:00 2001 From: u214892 <u214892@sbb.ch> Date: Thu, 13 Jun 2019 08:47:43 +0200 Subject: [PATCH] fix master --- flatland/core/env.py | 15 --------------- flatland/envs/observations.py | 4 ---- flatland/envs/rail_env.py | 2 -- flatland/utils/editor.py | 3 ++- flatland/utils/graphics_pil.py | 2 +- tests/test_env_prediction_builder.py | 7 +++---- 6 files changed, 6 insertions(+), 27 deletions(-) diff --git a/flatland/core/env.py b/flatland/core/env.py index 3618d965..1bc5b6f3 100644 --- a/flatland/core/env.py +++ b/flatland/core/env.py @@ -84,21 +84,6 @@ class Environment: """ raise NotImplementedError() - def predict(self): - """ - Predictions step. - - Returns predictions for the agents. - The returns are dicts mapping from agent_id strings to values. - - Returns - ------- - predictions : dict - New predictions for each ready agent. - - """ - raise NotImplementedError() - def get_agent_handles(self): """ Returns a list of agents' handles to be used as keys in the step() diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index a3d88d77..76bed8a4 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -173,10 +173,6 @@ class TreeObsForRailEnv(ObservationBuilder): Called whenever an observation has to be computed for the `env' environment, for each agent with handle in the `handles' list. """ - - # TODO: @Erik this is where the predictions should be computed, storing any temporary data inside this object. - if self.predictor: - print(self.predictor.get(0)) observations = {} for h in handles: observations[h] = self.get(h) diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 5d20a5d9..7773f86c 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -292,7 +292,6 @@ class RailEnv(Environment): np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1)) return cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid - def check_action(self, agent, action): transition_isValid = None possible_transitions = self.rail.get_transitions((*agent.position, agent.direction)) @@ -324,7 +323,6 @@ class RailEnv(Environment): self.obs_dict = self.obs_builder.get_many(list(range(self.get_num_agents()))) return self.obs_dict - def get_full_state_msg(self): grid_data = self.rail.grid.tolist() agent_static_data = [agent.to_list() for agent in self.agents_static] diff --git a/flatland/utils/editor.py b/flatland/utils/editor.py index d4e5c38e..81565d62 100644 --- a/flatland/utils/editor.py +++ b/flatland/utils/editor.py @@ -323,7 +323,8 @@ class Controller(object): def restartAgents(self, event): self.log("Restart Agents - nAgents:", self.view.wRegenNAgents.value) if self.model.init_agents_static is not None: - self.model.env.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in self.model.init_agents_static] + self.model.env.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in + self.model.init_agents_static] self.model.env.agents = None self.model.init_agents_static = None self.player = None diff --git a/flatland/utils/graphics_pil.py b/flatland/utils/graphics_pil.py index bca964c9..0f67421e 100644 --- a/flatland/utils/graphics_pil.py +++ b/flatland/utils/graphics_pil.py @@ -396,7 +396,7 @@ class PILSVG(PILGL): } # "paint" color of the train images we load - this is the color we will change. - # a3BaseColor = self.rgb_s2i("0091ea") + # a3BaseColor = self.rgb_s2i("0091ea") \# noqa: E800 # temporary workaround for trains / agents renamed with different colour: a3BaseColor = self.rgb_s2i("d50000") diff --git a/tests/test_env_prediction_builder.py b/tests/test_env_prediction_builder.py index 35a6a27b..be065d3d 100644 --- a/tests/test_env_prediction_builder.py +++ b/tests/test_env_prediction_builder.py @@ -5,7 +5,7 @@ import numpy as np from flatland.core.transition_map import GridTransitionMap, Grid4Transitions from flatland.envs.generators import rail_from_GridTransitionMap_generator -from flatland.envs.observations import GlobalObsForRailEnv +from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import DummyPredictorForRailEnv from flatland.envs.rail_env import RailEnv @@ -64,8 +64,7 @@ def test_predictions(): height=rail_map.shape[0], rail_generator=rail_from_GridTransitionMap_generator(rail), number_of_agents=1, - obs_builder_object=GlobalObsForRailEnv(), - prediction_builder_object=DummyPredictorForRailEnv(max_depth=20) + obs_builder_object=TreeObsForRailEnv(max_depth=20, predictor=DummyPredictorForRailEnv(max_depth=20)), ) env.reset() @@ -74,7 +73,7 @@ def test_predictions(): env.agents[0].position = (5, 6) env.agents[0].direction = 0 - predictions = env.predict() + predictions = env.obs_builder.predictor.get() positions = np.array(list(map(lambda prediction: [prediction[1], prediction[2]], predictions[0]))) directions = np.array(list(map(lambda prediction: [prediction[3]], predictions[0]))) time_offsets = np.array(list(map(lambda prediction: [prediction[0]], predictions[0]))) -- GitLab