diff --git a/flatland/core/env.py b/flatland/core/env.py index 3618d965a39b5a71fd1cf24aa81f2f876d5c6365..1bc5b6f3eba4ee4713bd3c8d6b88440006c215a5 100644 --- a/flatland/core/env.py +++ b/flatland/core/env.py @@ -84,21 +84,6 @@ class Environment: """ raise NotImplementedError() - def predict(self): - """ - Predictions step. - - Returns predictions for the agents. - The returns are dicts mapping from agent_id strings to values. - - Returns - ------- - predictions : dict - New predictions for each ready agent. - - """ - raise NotImplementedError() - def get_agent_handles(self): """ Returns a list of agents' handles to be used as keys in the step() diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index a3d88d773db9edaa0777e2aee94593a0392a956c..76bed8a46b7b906a79574fbc76e64296475fd6c1 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -173,10 +173,6 @@ class TreeObsForRailEnv(ObservationBuilder): Called whenever an observation has to be computed for the `env' environment, for each agent with handle in the `handles' list. """ - - # TODO: @Erik this is where the predictions should be computed, storing any temporary data inside this object. - if self.predictor: - print(self.predictor.get(0)) observations = {} for h in handles: observations[h] = self.get(h) diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 5d20a5d9f38230f353b0a9616c49ede333206c49..7773f86c1407153f649c972398c9a58067a38947 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -292,7 +292,6 @@ class RailEnv(Environment): np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1)) return cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid - def check_action(self, agent, action): transition_isValid = None possible_transitions = self.rail.get_transitions((*agent.position, agent.direction)) @@ -324,7 +323,6 @@ class RailEnv(Environment): self.obs_dict = self.obs_builder.get_many(list(range(self.get_num_agents()))) return self.obs_dict - def get_full_state_msg(self): grid_data = self.rail.grid.tolist() agent_static_data = [agent.to_list() for agent in self.agents_static] diff --git a/flatland/utils/editor.py b/flatland/utils/editor.py index d4e5c38e975fbbfd5d357b78d1cd868ac828701e..81565d62d3ab17e740a5e1b635a97fd97d01980f 100644 --- a/flatland/utils/editor.py +++ b/flatland/utils/editor.py @@ -323,7 +323,8 @@ class Controller(object): def restartAgents(self, event): self.log("Restart Agents - nAgents:", self.view.wRegenNAgents.value) if self.model.init_agents_static is not None: - self.model.env.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in self.model.init_agents_static] + self.model.env.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in + self.model.init_agents_static] self.model.env.agents = None self.model.init_agents_static = None self.player = None diff --git a/flatland/utils/graphics_pil.py b/flatland/utils/graphics_pil.py index bca964c960ee55d0d110e9c9e71d6b4481d93215..0f67421e397160a9d91085d6bb41568c618c6b6e 100644 --- a/flatland/utils/graphics_pil.py +++ b/flatland/utils/graphics_pil.py @@ -396,7 +396,7 @@ class PILSVG(PILGL): } # "paint" color of the train images we load - this is the color we will change. - # a3BaseColor = self.rgb_s2i("0091ea") + # a3BaseColor = self.rgb_s2i("0091ea") \# noqa: E800 # temporary workaround for trains / agents renamed with different colour: a3BaseColor = self.rgb_s2i("d50000") diff --git a/tests/test_env_prediction_builder.py b/tests/test_env_prediction_builder.py index 35a6a27b970ce54e1cabd3cf8c80d30a34800a25..be065d3d10a4bf99804ab1556dba563d9ef27406 100644 --- a/tests/test_env_prediction_builder.py +++ b/tests/test_env_prediction_builder.py @@ -5,7 +5,7 @@ import numpy as np from flatland.core.transition_map import GridTransitionMap, Grid4Transitions from flatland.envs.generators import rail_from_GridTransitionMap_generator -from flatland.envs.observations import GlobalObsForRailEnv +from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import DummyPredictorForRailEnv from flatland.envs.rail_env import RailEnv @@ -64,8 +64,7 @@ def test_predictions(): height=rail_map.shape[0], rail_generator=rail_from_GridTransitionMap_generator(rail), number_of_agents=1, - obs_builder_object=GlobalObsForRailEnv(), - prediction_builder_object=DummyPredictorForRailEnv(max_depth=20) + obs_builder_object=TreeObsForRailEnv(max_depth=20, predictor=DummyPredictorForRailEnv(max_depth=20)), ) env.reset() @@ -74,7 +73,7 @@ def test_predictions(): env.agents[0].position = (5, 6) env.agents[0].direction = 0 - predictions = env.predict() + predictions = env.obs_builder.predictor.get() positions = np.array(list(map(lambda prediction: [prediction[1], prediction[2]], predictions[0]))) directions = np.array(list(map(lambda prediction: [prediction[3]], predictions[0]))) time_offsets = np.array(list(map(lambda prediction: [prediction[0]], predictions[0])))