Skip to content
Snippets Groups Projects
Commit f92415ee authored by spiglerg's avatar spiglerg
Browse files

added get_many() method to observation builders, base class and derived. Also...

added get_many() method to observation builders, base class and derived. Also removed the render() methods in the env base class and railenv
parent 6d48dcd1
No related branches found
No related tags found
No related merge requests found
......@@ -99,12 +99,6 @@ class Environment:
"""
raise NotImplementedError()
def render(self):
"""
Perform rendering of the environment.
"""
raise NotImplementedError()
def get_agent_handles(self):
"""
Returns a list of agents' handles to be used as keys in the step()
......
......@@ -30,6 +30,27 @@ class ObservationBuilder:
"""
raise NotImplementedError()
def get_many(self, handles=[]):
"""
Called whenever an observation has to be computed for the `env' environment, for each agent with handle
in the `handles' list.
Parameters
-------
handles : list of handles (optional)
List with the handles of the agents for which to compute the observation vector.
Returns
-------
function
A dictionary of observation structures, specific to the corresponding environment, with handles from
`handles' as keys.
"""
observations = {}
for h in handles:
observations[h] = self.get(h)
return observations
def get(self, handle=0):
"""
Called whenever an observation has to be computed for the `env' environment, possibly
......
......@@ -167,6 +167,19 @@ class TreeObsForRailEnv(ObservationBuilder):
elif movement == 3: # WEST
return (position[0], position[1] - 1)
def get_many(self, handles=[]):
"""
Called whenever an observation has to be computed for the `env' environment, for each agent with handle
in the `handles' list.
"""
# TODO: @Erik this is where the predictions should be computed, storing any temporary data inside this object.
observations = {}
for h in handles:
observations[h] = self.get(h)
return observations
def get(self, handle):
"""
Computes the current observation for agent `handle' in env
......
......@@ -330,10 +330,7 @@ class RailEnv(Environment):
return new_direction, transition_isValid
def _get_observations(self):
self.obs_dict = {}
self.debug_obs_dict = {}
for iAgent in range(self.get_num_agents()):
self.obs_dict[iAgent] = self.obs_builder.get(iAgent)
self.obs_dict = self.obs_builder.get_many(list(range(self.get_num_agents())))
return self.obs_dict
def _get_predictions(self):
......@@ -341,10 +338,6 @@ class RailEnv(Environment):
return {}
return {}
def render(self):
# TODO:
pass
def get_full_state_msg(self):
grid_data = self.rail.grid.tolist()
agent_static_data = [agent.to_list() for agent in self.agents_static]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment