Commit cc5a5b70 authored by Erik Nygren's avatar Erik Nygren
Browse files

Moved Prediction from Env to Observation. You can now pass your own predictor...

Moved Prediction from Env to Observation. You can now pass your own predictor to the tree observation builder. if no predictor is passed everything stays the same.
parent f92415ee
......@@ -17,7 +17,7 @@ class TreeObsForRailEnv(ObservationBuilder):
network to simplify the representation of the state of the environment for each agent.
"""
def __init__(self, max_depth):
def __init__(self, max_depth, predictor=None):
self.max_depth = max_depth
# Compute the size of the returned observation vector
......@@ -30,6 +30,7 @@ class TreeObsForRailEnv(ObservationBuilder):
self.observation_space = [size * self.observation_dim]
self.location_has_agent = {}
self.location_has_agent_direction = {}
self.predictor = predictor
self.agents_previous_reset = None
......@@ -174,7 +175,8 @@ class TreeObsForRailEnv(ObservationBuilder):
"""
# TODO: @Erik this is where the predictions should be computed, storing any temporary data inside this object.
if self.predictor:
print(self.predictor.get(0))
observations = {}
for h in handles:
observations[h] = self.get(h)
......@@ -537,6 +539,11 @@ class TreeObsForRailEnv(ObservationBuilder):
agent_data.extend(tmp_agent_data)
return tree_data, distance_data, agent_data
def _set_env(self, env):
self.env = env
if self.predictor:
self.predictor._set_env(self.env)
class GlobalObsForRailEnv(ObservationBuilder):
"""
......
......@@ -58,7 +58,6 @@ class RailEnv(Environment):
rail_generator=random_rail_generator(),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2),
prediction_builder_object=None
):
"""
Environment init.
......@@ -99,10 +98,6 @@ class RailEnv(Environment):
self.obs_builder = obs_builder_object
self.obs_builder._set_env(self)
self.prediction_builder = prediction_builder_object
if self.prediction_builder:
self.prediction_builder._set_env(self)
self.action_space = [1]
self.observation_space = self.obs_builder.observation_space # updated on resets?
......@@ -297,10 +292,6 @@ class RailEnv(Environment):
np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1))
return cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid
def predict(self):
if not self.prediction_builder:
return {}
return self.prediction_builder.get()
def check_action(self, agent, action):
transition_isValid = None
......@@ -333,10 +324,6 @@ class RailEnv(Environment):
self.obs_dict = self.obs_builder.get_many(list(range(self.get_num_agents())))
return self.obs_dict
def _get_predictions(self):
if not self.prediction_builder:
return {}
return {}
def get_full_state_msg(self):
grid_data = self.rail.grid.tolist()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment