diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index f3c0a7c3c06ff27ecd956e1734a6fa33fd1236b3..a3d88d773db9edaa0777e2aee94593a0392a956c 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -17,7 +17,7 @@ class TreeObsForRailEnv(ObservationBuilder): network to simplify the representation of the state of the environment for each agent. """ - def __init__(self, max_depth): + def __init__(self, max_depth, predictor=None): self.max_depth = max_depth # Compute the size of the returned observation vector @@ -30,6 +30,7 @@ class TreeObsForRailEnv(ObservationBuilder): self.observation_space = [size * self.observation_dim] self.location_has_agent = {} self.location_has_agent_direction = {} + self.predictor = predictor self.agents_previous_reset = None @@ -174,7 +175,8 @@ class TreeObsForRailEnv(ObservationBuilder): """ # TODO: @Erik this is where the predictions should be computed, storing any temporary data inside this object. - + if self.predictor: + print(self.predictor.get(0)) observations = {} for h in handles: observations[h] = self.get(h) @@ -537,6 +539,11 @@ class TreeObsForRailEnv(ObservationBuilder): agent_data.extend(tmp_agent_data) return tree_data, distance_data, agent_data + def _set_env(self, env): + self.env = env + if self.predictor: + self.predictor._set_env(self.env) + class GlobalObsForRailEnv(ObservationBuilder): """ diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 44ed3f770600ba1aae1745926109deb1fa7398ef..5d20a5d9f38230f353b0a9616c49ede333206c49 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -58,7 +58,6 @@ class RailEnv(Environment): rail_generator=random_rail_generator(), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2), - prediction_builder_object=None ): """ Environment init. @@ -99,10 +98,6 @@ class RailEnv(Environment): self.obs_builder = obs_builder_object self.obs_builder._set_env(self) - self.prediction_builder = prediction_builder_object - if self.prediction_builder: - self.prediction_builder._set_env(self) - self.action_space = [1] self.observation_space = self.obs_builder.observation_space # updated on resets? @@ -297,10 +292,6 @@ class RailEnv(Environment): np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1)) return cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid - def predict(self): - if not self.prediction_builder: - return {} - return self.prediction_builder.get() def check_action(self, agent, action): transition_isValid = None @@ -333,10 +324,6 @@ class RailEnv(Environment): self.obs_dict = self.obs_builder.get_many(list(range(self.get_num_agents()))) return self.obs_dict - def _get_predictions(self): - if not self.prediction_builder: - return {} - return {} def get_full_state_msg(self): grid_data = self.rail.grid.tolist()