Skip to content
Snippets Groups Projects
Commit 8a8a7853 authored by spiglerg's avatar spiglerg
Browse files

action_space and observation_space, issue #46

parent fa54bee7
No related branches found
No related tags found
No related merge requests found
......@@ -9,6 +9,10 @@ class Environment:
"""
Base interface for multi-agent environments in Flatland.
Derived environments should implement the following attributes:
action_space: tuple with the dimensions of the actions to be passed to the step method
observation_space: tuple with the dimensions of the observations returned by reset and step
Agents are identified by agent ids (handles).
Examples:
>>> obs = env.reset()
......@@ -39,6 +43,8 @@ class Environment:
"""
def __init__(self):
    """Set up empty placeholders for the space descriptors.

    Both attributes start out as empty tuples; the class documentation asks
    derived environments to provide their own values for them.
    """
    self.observation_space = ()
    self.action_space = ()
def reset(self):
......
......@@ -12,9 +12,13 @@ case of multi-agent environments.
class ObservationBuilder:
"""
ObservationBuilder base class.
Derived objects must implement an `observation_space` attribute as a tuple with the dimensions of the returned
observations.
"""
def __init__(self):
    """Start with an empty observation-space descriptor.

    The empty tuple is only a placeholder; the class documentation asks
    derived builders to supply the real dimensions.
    """
    self.observation_space = tuple()
def _set_env(self, env):
......
......@@ -19,6 +19,14 @@ class TreeObsForRailEnv(ObservationBuilder):
def __init__(self, max_depth):
    """Store the exploration depth and derive the observation vector size.

    Args:
        max_depth: deepest tree level to include; levels 0..max_depth are counted.

    The resulting `observation_space` is a one-element list holding the
    total length of the flattened observation vector.
    """
    self.max_depth = max_depth

    # Level `k` contributes 4**k entries, so summing over levels 0..max_depth
    # gives the same total as accumulating successive powers of four.
    node_count = sum(4 ** level for level in range(self.max_depth + 1))

    # The count is scaled by 5 (presumably five values per tree node -- TODO confirm).
    # NOTE(review): ObservationBuilder's docs describe this attribute as a tuple;
    # the list form is kept here so existing callers see the same type.
    self.observation_space = [node_count * 5]
def reset(self):
agents = self.env.agents
nAgents = len(agents)
......@@ -158,10 +166,6 @@ class TreeObsForRailEnv(ObservationBuilder):
the transitions. The order is:
[data from 'left'] + [data from 'forward'] + [data from 'right'] + [data from 'back']
Each branch data is organized as:
[root node information] +
[recursive branch data from 'left'] +
......
......@@ -112,6 +112,10 @@ class RailEnv(Environment):
self.valid_positions = None
self.action_space = [1]
self.observation_space = self.obs_builder.observation_space # updated on resets?
# no more agent_handles
def get_agent_handles(self):
    """Return the agent handles as a range of consecutive integers.

    The handles run from 0 up to, but not including, the value returned by
    `get_num_agents()`.
    """
    agent_count = self.get_num_agents()
    return range(agent_count)
......@@ -160,6 +164,7 @@ class RailEnv(Environment):
# Reset the state of the observation builder with the new environment
self.obs_builder.reset()
self.observation_space = self.obs_builder.observation_space # <-- change on reset?
# Return the new observation vectors for each agent
return self._get_observations()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment