diff --git a/flatland/core/env.py b/flatland/core/env.py index 284afdffb6ce46ac481018af469c7d2e024fc792..5334b22f5762840f400743c80512bc6a118062ee 100644 --- a/flatland/core/env.py +++ b/flatland/core/env.py @@ -9,6 +9,10 @@ class Environment: """ Base interface for multi-agent environments in Flatland. + Derived environments should implement the following attributes: + action_space: tuple with the dimensions of the actions to be passed to the step method + observation_space: tuple with the dimensions of the observations returned by reset and step + Agents are identified by agent ids (handles). Examples: >>> obs = env.reset() @@ -39,6 +43,8 @@ class Environment: """ def __init__(self): + self.action_space = () + self.observation_space = () pass def reset(self): diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py index 09a624e872200e29ede834d272f7de506d6de076..3cef545c1658e6bfe2a292ee26c3e665ce6a5abc 100644 --- a/flatland/core/env_observation_builder.py +++ b/flatland/core/env_observation_builder.py @@ -12,9 +12,13 @@ case of multi-agent environments. class ObservationBuilder: """ ObservationBuilder base class. + + Derived objects must implement an `observation_space` attribute as a tuple with the dimensions of the returned + observations. 
""" def __init__(self): + self.observation_space = () pass def _set_env(self, env): diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index ba159cb7fb6df9fd171a3b509f352d3ba8d2d30c..af789f5d064f0e2b46b2328ebbfa55057bffdc14 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -19,6 +19,14 @@ class TreeObsForRailEnv(ObservationBuilder): def __init__(self, max_depth): self.max_depth = max_depth + # Compute the size of the returned observation vector + size = 0 + pow4 = 1 + for i in range(self.max_depth+1): + size += pow4 + pow4 *= 4 + self.observation_space = [size * 5] + def reset(self): agents = self.env.agents nAgents = len(agents) @@ -158,10 +166,6 @@ class TreeObsForRailEnv(ObservationBuilder): the transitions. The order is: [data from 'left'] + [data from 'forward'] + [data from 'right'] + [data from 'back'] - - - - Each branch data is organized as: [root node information] + [recursive branch data from 'left'] + diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 74e7526caa93ca8a1821eb5b2a47576231eb95c3..4ac00f182ba75303e4a103a9bd16ee40d26ce9d0 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -112,6 +112,10 @@ class RailEnv(Environment): self.valid_positions = None + self.action_space = [1] + self.observation_space = self.obs_builder.observation_space # updated on resets? + + # no more agent_handles def get_agent_handles(self): return range(self.get_num_agents()) @@ -160,6 +164,7 @@ class RailEnv(Environment): # Reset the state of the observation builder with the new environment self.obs_builder.reset() + self.observation_space = self.obs_builder.observation_space # <-- change on reset? # Return the new observation vectors for each agent return self._get_observations()