Skip to content
Snippets Groups Projects
Commit dadd84be authored by Erik Nygren
Browse files

Added visited list to observations.py Tree observation

parent 19fac966
No related branches found
No related tags found
No related merge requests found
...@@ -116,7 +116,7 @@ def norm_obs_clip(obs, clip_min=-1, clip_max=1): ...@@ -116,7 +116,7 @@ def norm_obs_clip(obs, clip_min=-1, clip_max=1):
for trials in range(1, n_trials + 1): for trials in range(1, n_trials + 1):
# Reset environment # Reset environment
obs = env.reset() obs, _ = env.reset()
final_obs = obs.copy() final_obs = obs.copy()
final_obs_next = obs.copy() final_obs_next = obs.copy()
for a in range(env.get_num_agents()): for a in range(env.get_num_agents()):
...@@ -148,7 +148,7 @@ for trials in range(1, n_trials + 1): ...@@ -148,7 +148,7 @@ for trials in range(1, n_trials + 1):
action_dict.update({a: action}) action_dict.update({a: action})
# Environment step # Environment step
next_obs, all_rewards, done, _ = env.step(action_dict) (next_obs,_), all_rewards, done, _ = env.step(action_dict)
for a in range(env.get_num_agents()): for a in range(env.get_num_agents()):
data, distance = env.obs_builder.split_tree(tree=np.array(next_obs[a]), num_features_per_node=5, data, distance = env.obs_builder.split_tree(tree=np.array(next_obs[a]), num_features_per_node=5,
......
...@@ -219,7 +219,8 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -219,7 +219,8 @@ class TreeObsForRailEnv(ObservationBuilder):
if possible_transitions[branch_direction]: if possible_transitions[branch_direction]:
new_cell = self._new_position(agent.position, branch_direction) new_cell = self._new_position(agent.position, branch_direction)
branch_observation = self._explore_branch(handle, new_cell, branch_direction, root_observation, 1) branch_observation, visited = self._explore_branch(handle, new_cell, branch_direction, root_observation,
1)
observation = observation + branch_observation observation = observation + branch_observation
else: else:
num_cells_to_fill_in = 0 num_cells_to_fill_in = 0
...@@ -228,7 +229,7 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -228,7 +229,7 @@ class TreeObsForRailEnv(ObservationBuilder):
num_cells_to_fill_in += pow4 num_cells_to_fill_in += pow4
pow4 *= 4 pow4 *= 4
observation = observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf] * num_cells_to_fill_in observation = observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf] * num_cells_to_fill_in
return observation return observation, visited
def _explore_branch(self, handle, position, direction, root_observation, depth): def _explore_branch(self, handle, position, direction, root_observation, depth):
""" """
...@@ -236,7 +237,7 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -236,7 +237,7 @@ class TreeObsForRailEnv(ObservationBuilder):
""" """
# [Recursive branch opened] # [Recursive branch opened]
if depth >= self.max_depth + 1: if depth >= self.max_depth + 1:
return [] return [], []
# Continue along direction until next switch or # Continue along direction until next switch or
# until no transitions are possible along the current direction (i.e., dead-ends) # until no transitions are possible along the current direction (i.e., dead-ends)
...@@ -377,22 +378,24 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -377,22 +378,24 @@ class TreeObsForRailEnv(ObservationBuilder):
# Swap forward and back in case of dead-end, so that an agent can learn that going forward takes # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes
# it back # it back
new_cell = self._new_position(position, (branch_direction + 2) % 4) new_cell = self._new_position(position, (branch_direction + 2) % 4)
branch_observation = self._explore_branch(handle, branch_observation, branch_visited = self._explore_branch(handle,
new_cell, new_cell,
(branch_direction + 2) % 4, (branch_direction + 2) % 4,
new_root_observation, new_root_observation,
depth + 1) depth + 1)
observation = observation + branch_observation observation = observation + branch_observation
if len(branch_visited) != 0:
visited.union(branch_visited)
elif last_isSwitch and possible_transitions[branch_direction]: elif last_isSwitch and possible_transitions[branch_direction]:
new_cell = self._new_position(position, branch_direction) new_cell = self._new_position(position, branch_direction)
branch_observation = self._explore_branch(handle, branch_observation, branch_visited = self._explore_branch(handle,
new_cell, new_cell,
branch_direction, branch_direction,
new_root_observation, new_root_observation,
depth + 1) depth + 1)
observation = observation + branch_observation observation = observation + branch_observation
if len(branch_visited) != 0:
visited.union(branch_visited)
else: else:
num_cells_to_fill_in = 0 num_cells_to_fill_in = 0
pow4 = 1 pow4 = 1
...@@ -401,7 +404,7 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -401,7 +404,7 @@ class TreeObsForRailEnv(ObservationBuilder):
pow4 *= 4 pow4 *= 4
observation = observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf] * num_cells_to_fill_in observation = observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf] * num_cells_to_fill_in
return observation return observation, visited
def util_print_obs_subtree(self, tree, num_features_per_node=5, prompt='', current_depth=0): def util_print_obs_subtree(self, tree, num_features_per_node=5, prompt='', current_depth=0):
""" """
......
...@@ -315,10 +315,11 @@ class RailEnv(Environment): ...@@ -315,10 +315,11 @@ class RailEnv(Environment):
def _get_observations(self): def _get_observations(self):
self.obs_dict = {} self.obs_dict = {}
self.debug_obs_dict = {}
# for handle in self.agents_handles: # for handle in self.agents_handles:
for iAgent in range(self.get_num_agents()): for iAgent in range(self.get_num_agents()):
self.obs_dict[iAgent] = self.obs_builder.get(iAgent) self.obs_dict[iAgent], self.debug_obs_dict[iAgent] = self.obs_builder.get(iAgent)
return self.obs_dict return self.obs_dict, self.debug_obs_dict
def render(self): def render(self):
# TODO: # TODO:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment