Commit f92415ee authored by spiglerg's avatar spiglerg
Browse files

added get_many() method to observation builders, base class and derived. Also...

added get_many() method to observation builders, base class and derived. Also removed the render() methods in the env base class and railenv
parent 6d48dcd1
Pipeline #1049 failed with stage
in 8 minutes and 7 seconds
......@@ -99,12 +99,6 @@ class Environment:
"""
raise NotImplementedError()
def render(self):
"""
Perform rendering of the environment.
"""
raise NotImplementedError()
def get_agent_handles(self):
"""
Returns a list of agents' handles to be used as keys in the step()
......
......@@ -30,6 +30,27 @@ class ObservationBuilder:
"""
raise NotImplementedError()
def get_many(self, handles=[]):
"""
Called whenever an observation has to be computed for the `env' environment, for each agent with handle
in the `handles' list.
Parameters
-------
handles : list of handles (optional)
List with the handles of the agents for which to compute the observation vector.
Returns
-------
function
A dictionary of observation structures, specific to the corresponding environment, with handles from
`handles' as keys.
"""
observations = {}
for h in handles:
observations[h] = self.get(h)
return observations
def get(self, handle=0):
"""
Called whenever an observation has to be computed for the `env' environment, possibly
......
......@@ -167,6 +167,19 @@ class TreeObsForRailEnv(ObservationBuilder):
elif movement == 3: # WEST
return (position[0], position[1] - 1)
def get_many(self, handles=[]):
"""
Called whenever an observation has to be computed for the `env' environment, for each agent with handle
in the `handles' list.
"""
# TODO: @Erik this is where the predictions should be computed, storing any temporary data inside this object.
observations = {}
for h in handles:
observations[h] = self.get(h)
return observations
def get(self, handle):
"""
Computes the current observation for agent `handle' in env
......
......@@ -330,10 +330,7 @@ class RailEnv(Environment):
return new_direction, transition_isValid
def _get_observations(self):
self.obs_dict = {}
self.debug_obs_dict = {}
for iAgent in range(self.get_num_agents()):
self.obs_dict[iAgent] = self.obs_builder.get(iAgent)
self.obs_dict = self.obs_builder.get_many(list(range(self.get_num_agents())))
return self.obs_dict
def _get_predictions(self):
......@@ -341,10 +338,6 @@ class RailEnv(Environment):
return {}
return {}
def render(self):
# TODO:
pass
def get_full_state_msg(self):
grid_data = self.rail.grid.tolist()
agent_static_data = [agent.to_list() for agent in self.agents_static]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment