diff --git a/flatland/core/env.py b/flatland/core/env.py index 32691f507f4cb5586f10b5645cc22ece718edc21..1bc5b6f3eba4ee4713bd3c8d6b88440006c215a5 100644 --- a/flatland/core/env.py +++ b/flatland/core/env.py @@ -84,27 +84,6 @@ class Environment: """ raise NotImplementedError() - def predict(self): - """ - Predictions step. - - Returns predictions for the agents. - The returns are dicts mapping from agent_id strings to values. - - Returns - ------- - predictions : dict - New predictions for each ready agent. - - """ - raise NotImplementedError() - - def render(self): - """ - Perform rendering of the environment. - """ - raise NotImplementedError() - def get_agent_handles(self): """ Returns a list of agents' handles to be used as keys in the step() diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py index b30c2b1f5ddab079c9b6c41e35f03c69ed4162c3..53e7a068b73f9907217777251bce0fdd704603be 100644 --- a/flatland/core/env_observation_builder.py +++ b/flatland/core/env_observation_builder.py @@ -30,6 +30,27 @@ class ObservationBuilder: """ raise NotImplementedError() + def get_many(self, handles=[]): + """ + Called whenever an observation has to be computed for the `env' environment, for each agent with handle + in the `handles' list. + + Parameters + ------- + handles : list of handles (optional) + List with the handles of the agents for which to compute the observation vector. + + Returns + ------- + function + A dictionary of observation structures, specific to the corresponding environment, with handles from + `handles' as keys. + """ + observations = {} + for h in handles: + observations[h] = self.get(h) + return observations + def get(self, handle=0): """ Called whenever an observation has to be computed for the `env' environment, possibly diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 676051d8338534c704ef98c90bee08a2836d4cfb..541f8ad592d1481afb8eb6da2eb7b887aacae419 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -17,7 +17,7 @@ class TreeObsForRailEnv(ObservationBuilder): network to simplify the representation of the state of the environment for each agent. """ - def __init__(self, max_depth): + def __init__(self, max_depth, predictor=None): self.max_depth = max_depth # Compute the size of the returned observation vector @@ -30,7 +30,7 @@ class TreeObsForRailEnv(ObservationBuilder): self.observation_space = [size * self.observation_dim] self.location_has_agent = {} self.location_has_agent_direction = {} - + self.predictor = predictor self.agents_previous_reset = None def reset(self): @@ -167,6 +167,21 @@ class TreeObsForRailEnv(ObservationBuilder): elif movement == 3: # WEST return (position[0], position[1] - 1) + def get_many(self, handles=[]): + """ + Called whenever an observation has to be computed for the `env' environment, for each agent with handle + in the `handles' list. + """ + + self.predictions = [] + if self.predictor: + for a in range(len(handles)): + self.predictions.append(self.predictor.get(a)) + observations = {} + for h in handles: + observations[h] = self.get(h) + return observations + def get(self, handle): """ Computes the current observation for agent `handle' in env @@ -207,6 +222,8 @@ class TreeObsForRailEnv(ObservationBuilder): (possible future use: number of other agents in other direction in this branch, ie. number of conflicts) 0 = no agent present other direction than myself + #8: possible conflict detected + Missing/padding nodes are filled in with -inf (truncated). Missing values in present node are filled in with +inf (truncated). @@ -241,7 +258,6 @@ class TreeObsForRailEnv(ObservationBuilder): for branch_direction in [(orientation + i) % 4 for i in range(-1, 3)]: if possible_transitions[branch_direction]: new_cell = self._new_position(agent.position, branch_direction) - branch_observation, branch_visited = \ self._explore_branch(handle, new_cell, branch_direction, root_observation, 1) observation = observation + branch_observation @@ -524,6 +540,11 @@ class TreeObsForRailEnv(ObservationBuilder): agent_data.extend(tmp_agent_data) return tree_data, distance_data, agent_data + def _set_env(self, env): + self.env = env + if self.predictor: + self.predictor._set_env(self.env) + class GlobalObsForRailEnv(ObservationBuilder): """ diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 795aabba67bf0deaf3d73c69a74788b9527abb58..c22e1c5120b54a170f9c59bb54c7666ca910f086 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -58,7 +58,6 @@ class RailEnv(Environment): rail_generator=random_rail_generator(), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2), - prediction_builder_object=None ): """ Environment init. @@ -99,10 +98,6 @@ class RailEnv(Environment): self.obs_builder = obs_builder_object self.obs_builder._set_env(self) - self.prediction_builder = prediction_builder_object - if self.prediction_builder: - self.prediction_builder._set_env(self) - self.action_space = [1] self.observation_space = self.obs_builder.observation_space # updated on resets? @@ -297,11 +292,6 @@ class RailEnv(Environment): np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1)) return cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid - def predict(self): - if not self.prediction_builder: - return {} - return self.prediction_builder.get() - def check_action(self, agent, action): transition_isValid = None possible_transitions = self.rail.get_transitions((*agent.position, agent.direction)) @@ -330,21 +320,9 @@ class RailEnv(Environment): return new_direction, transition_isValid def _get_observations(self): - self.obs_dict = {} - self.debug_obs_dict = {} - for iAgent in range(self.get_num_agents()): - self.obs_dict[iAgent] = self.obs_builder.get(iAgent) + self.obs_dict = self.obs_builder.get_many(list(range(self.get_num_agents()))) return self.obs_dict - def _get_predictions(self): - if not self.prediction_builder: - return {} - return {} - - def render(self): - # TODO: - pass - def get_full_state_msg(self): grid_data = self.rail.grid.tolist() agent_static_data = [agent.to_list() for agent in self.agents_static] diff --git a/flatland/utils/graphics_pil.py b/flatland/utils/graphics_pil.py index be804c0348cf0fca9cc0c089c64693f7a617064b..8e934904f86d2da0dae6d59038a2e0b499415271 100644 --- a/flatland/utils/graphics_pil.py +++ b/flatland/utils/graphics_pil.py @@ -383,7 +383,9 @@ class PILSVG(PILGL): (0, 3): "Zug_2_Weiche_#0091ea.svg" } - # "paint" color of the train images we load + # "paint" color of the train images we load - this is the color we will change. + # a3BaseColor = self.rgb_s2i("0091ea") \# noqa: E800 + # temporary workaround for trains / agents renamed with different colour: a3BaseColor = self.rgb_s2i("d50000") self.dPilZug = {} diff --git a/tests/test_env_prediction_builder.py b/tests/test_env_prediction_builder.py index 35a6a27b970ce54e1cabd3cf8c80d30a34800a25..a1c951d31f7121a8445f4309c31ce58653c7e463 100644 --- a/tests/test_env_prediction_builder.py +++ b/tests/test_env_prediction_builder.py @@ -5,7 +5,7 @@ import numpy as np from flatland.core.transition_map import GridTransitionMap, Grid4Transitions from flatland.envs.generators import rail_from_GridTransitionMap_generator -from flatland.envs.observations import GlobalObsForRailEnv +from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import DummyPredictorForRailEnv from flatland.envs.rail_env import RailEnv @@ -64,8 +64,7 @@ def test_predictions(): height=rail_map.shape[0], rail_generator=rail_from_GridTransitionMap_generator(rail), number_of_agents=1, - obs_builder_object=GlobalObsForRailEnv(), - prediction_builder_object=DummyPredictorForRailEnv(max_depth=20) + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv(max_depth=20)), ) env.reset() @@ -74,7 +73,7 @@ def test_predictions(): env.agents[0].position = (5, 6) env.agents[0].direction = 0 - predictions = env.predict() + predictions = env.obs_builder.predictor.get() positions = np.array(list(map(lambda prediction: [prediction[1], prediction[2]], predictions[0]))) directions = np.array(list(map(lambda prediction: [prediction[3]], predictions[0]))) time_offsets = np.array(list(map(lambda prediction: [prediction[0]], predictions[0])))