From 8a8a7853c2df2b617fed29fe4a738bc9c6567cd2 Mon Sep 17 00:00:00 2001 From: Giacomo Spigler <spiglerg@gmail.com> Date: Thu, 23 May 2019 17:16:11 +0200 Subject: [PATCH] action_space and observation_space, issue #46 --- flatland/core/env.py | 6 ++++++ flatland/core/env_observation_builder.py | 4 ++++ flatland/envs/observations.py | 12 ++++++++---- flatland/envs/rail_env.py | 5 +++++ 4 files changed, 23 insertions(+), 4 deletions(-) diff --git a/flatland/core/env.py b/flatland/core/env.py index 284afdff..5334b22f 100644 --- a/flatland/core/env.py +++ b/flatland/core/env.py @@ -9,6 +9,10 @@ class Environment: """ Base interface for multi-agent environments in Flatland. + Derived environments should implement the following attributes: + action_space: tuple with the dimensions of the actions to be passed to the step method + observation_space: tuple with the dimensions of the observations returned by reset and step + Agents are identified by agent ids (handles). Examples: >>> obs = env.reset() @@ -39,6 +43,8 @@ class Environment: """ def __init__(self): + self.action_space = () + self.observation_space = () pass def reset(self): diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py index 09a624e8..3cef545c 100644 --- a/flatland/core/env_observation_builder.py +++ b/flatland/core/env_observation_builder.py @@ -12,9 +12,13 @@ case of multi-agent environments. class ObservationBuilder: """ ObservationBuilder base class. + + Derived objects must implement and `observation_space' attribute as a tuple with the dimensuions of the returned + observations. """ def __init__(self): + self.observation_space = () pass def _set_env(self, env): diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index ba159cb7..af789f5d 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -19,6 +19,14 @@ class TreeObsForRailEnv(ObservationBuilder): def __init__(self, max_depth): self.max_depth = max_depth + # Compute the size of the returned observation vector + size = 0 + pow4 = 1 + for i in range(self.max_depth+1): + size += pow4 + pow4 *= 4 + self.observation_space = [size * 5] + def reset(self): agents = self.env.agents nAgents = len(agents) @@ -158,10 +166,6 @@ class TreeObsForRailEnv(ObservationBuilder): the transitions. The order is: [data from 'left'] + [data from 'forward'] + [data from 'right'] + [data from 'back'] - - - - Each branch data is organized as: [root node information] + [recursive branch data from 'left'] + diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 74e7526c..4ac00f18 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -112,6 +112,10 @@ class RailEnv(Environment): self.valid_positions = None + self.action_space = [1] + self.observation_space = self.obs_builder.observation_space # updated on resets? + + # no more agent_handles def get_agent_handles(self): return range(self.get_num_agents()) @@ -160,6 +164,7 @@ class RailEnv(Environment): # Reset the state of the observation builder with the new environment self.obs_builder.reset() + self.observation_space = self.obs_builder.observation_space # <-- change on reset? # Return the new observation vectors for each agent return self._get_observations() -- GitLab