Skip to content
Snippets Groups Projects
Commit 8a8a7853 authored by spiglerg's avatar spiglerg
Browse files

action_space and observation_space, issue #46

parent fa54bee7
No related branches found
No related tags found
No related merge requests found
...@@ -9,6 +9,10 @@ class Environment: ...@@ -9,6 +9,10 @@ class Environment:
""" """
Base interface for multi-agent environments in Flatland. Base interface for multi-agent environments in Flatland.
Derived environments should implement the following attributes:
action_space: tuple with the dimensions of the actions to be passed to the step method
observation_space: tuple with the dimensions of the observations returned by reset and step
Agents are identified by agent ids (handles). Agents are identified by agent ids (handles).
Examples: Examples:
>>> obs = env.reset() >>> obs = env.reset()
...@@ -39,6 +43,8 @@ class Environment: ...@@ -39,6 +43,8 @@ class Environment:
""" """
def __init__(self): def __init__(self):
self.action_space = ()
self.observation_space = ()
pass pass
def reset(self): def reset(self):
......
...@@ -12,9 +12,13 @@ case of multi-agent environments. ...@@ -12,9 +12,13 @@ case of multi-agent environments.
class ObservationBuilder: class ObservationBuilder:
""" """
ObservationBuilder base class. ObservationBuilder base class.
Derived objects must implement and `observation_space' attribute as a tuple with the dimensuions of the returned
observations.
""" """
def __init__(self): def __init__(self):
self.observation_space = ()
pass pass
def _set_env(self, env): def _set_env(self, env):
......
...@@ -19,6 +19,14 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -19,6 +19,14 @@ class TreeObsForRailEnv(ObservationBuilder):
def __init__(self, max_depth): def __init__(self, max_depth):
self.max_depth = max_depth self.max_depth = max_depth
# Compute the size of the returned observation vector
size = 0
pow4 = 1
for i in range(self.max_depth+1):
size += pow4
pow4 *= 4
self.observation_space = [size * 5]
def reset(self): def reset(self):
agents = self.env.agents agents = self.env.agents
nAgents = len(agents) nAgents = len(agents)
...@@ -158,10 +166,6 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -158,10 +166,6 @@ class TreeObsForRailEnv(ObservationBuilder):
the transitions. The order is: the transitions. The order is:
[data from 'left'] + [data from 'forward'] + [data from 'right'] + [data from 'back'] [data from 'left'] + [data from 'forward'] + [data from 'right'] + [data from 'back']
Each branch data is organized as: Each branch data is organized as:
[root node information] + [root node information] +
[recursive branch data from 'left'] + [recursive branch data from 'left'] +
......
...@@ -112,6 +112,10 @@ class RailEnv(Environment): ...@@ -112,6 +112,10 @@ class RailEnv(Environment):
self.valid_positions = None self.valid_positions = None
self.action_space = [1]
self.observation_space = self.obs_builder.observation_space # updated on resets?
# no more agent_handles # no more agent_handles
def get_agent_handles(self): def get_agent_handles(self):
return range(self.get_num_agents()) return range(self.get_num_agents())
...@@ -160,6 +164,7 @@ class RailEnv(Environment): ...@@ -160,6 +164,7 @@ class RailEnv(Environment):
# Reset the state of the observation builder with the new environment # Reset the state of the observation builder with the new environment
self.obs_builder.reset() self.obs_builder.reset()
self.observation_space = self.obs_builder.observation_space # <-- change on reset?
# Return the new observation vectors for each agent # Return the new observation vectors for each agent
return self._get_observations() return self._get_observations()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment