From f92415eedc114cce4aef7bf52f1829021f30762e Mon Sep 17 00:00:00 2001 From: Giacomo Spigler <spiglerg@gmail.com> Date: Wed, 12 Jun 2019 18:37:53 +0200 Subject: [PATCH] added get_many() method to observation builders, base class and derived. Also removed the render() methods in the env base class and railenv --- flatland/core/env.py | 6 ------ flatland/core/env_observation_builder.py | 21 +++++++++++++++++++++ flatland/envs/observations.py | 13 +++++++++++++ flatland/envs/rail_env.py | 9 +-------- 4 files changed, 35 insertions(+), 14 deletions(-) diff --git a/flatland/core/env.py b/flatland/core/env.py index 32691f50..3618d965 100644 --- a/flatland/core/env.py +++ b/flatland/core/env.py @@ -99,12 +99,6 @@ class Environment: """ raise NotImplementedError() - def render(self): - """ - Perform rendering of the environment. - """ - raise NotImplementedError() - def get_agent_handles(self): """ Returns a list of agents' handles to be used as keys in the step() diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py index b30c2b1f..53e7a068 100644 --- a/flatland/core/env_observation_builder.py +++ b/flatland/core/env_observation_builder.py @@ -30,6 +30,27 @@ class ObservationBuilder: """ raise NotImplementedError() + def get_many(self, handles=[]): + """ + Called whenever an observation has to be computed for the `env' environment, for each agent with handle + in the `handles' list. + + Parameters + ------- + handles : list of handles (optional) + List with the handles of the agents for which to compute the observation vector. + + Returns + ------- + function + A dictionary of observation structures, specific to the corresponding environment, with handles from + `handles' as keys. + """ + observations = {} + for h in handles: + observations[h] = self.get(h) + return observations + def get(self, handle=0): """ Called whenever an observation has to be computed for the `env' environment, possibly diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 676051d8..f3c0a7c3 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -167,6 +167,19 @@ class TreeObsForRailEnv(ObservationBuilder): elif movement == 3: # WEST return (position[0], position[1] - 1) + def get_many(self, handles=[]): + """ + Called whenever an observation has to be computed for the `env' environment, for each agent with handle + in the `handles' list. + """ + + # TODO: @Erik this is where the predictions should be computed, storing any temporary data inside this object. + + observations = {} + for h in handles: + observations[h] = self.get(h) + return observations + def get(self, handle): """ Computes the current observation for agent `handle' in env diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 6cd64514..44ed3f77 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -330,10 +330,7 @@ class RailEnv(Environment): return new_direction, transition_isValid def _get_observations(self): - self.obs_dict = {} - self.debug_obs_dict = {} - for iAgent in range(self.get_num_agents()): - self.obs_dict[iAgent] = self.obs_builder.get(iAgent) + self.obs_dict = self.obs_builder.get_many(list(range(self.get_num_agents()))) return self.obs_dict def _get_predictions(self): @@ -341,10 +338,6 @@ class RailEnv(Environment): return {} return {} - def render(self): - # TODO: - pass - def get_full_state_msg(self): grid_data = self.rail.grid.tolist() agent_static_data = [agent.to_list() for agent in self.agents_static] -- GitLab