Skip to content
Snippets Groups Projects
Commit dadd84be authored by Erik Nygren's avatar Erik Nygren
Browse files

Added visited list to observations.py Tree observation

parent 19fac966
No related branches found
No related tags found
No related merge requests found
......@@ -116,7 +116,7 @@ def norm_obs_clip(obs, clip_min=-1, clip_max=1):
for trials in range(1, n_trials + 1):
# Reset environment
obs = env.reset()
obs, _ = env.reset()
final_obs = obs.copy()
final_obs_next = obs.copy()
for a in range(env.get_num_agents()):
......@@ -148,7 +148,7 @@ for trials in range(1, n_trials + 1):
action_dict.update({a: action})
# Environment step
next_obs, all_rewards, done, _ = env.step(action_dict)
(next_obs,_), all_rewards, done, _ = env.step(action_dict)
for a in range(env.get_num_agents()):
data, distance = env.obs_builder.split_tree(tree=np.array(next_obs[a]), num_features_per_node=5,
......
......@@ -219,7 +219,8 @@ class TreeObsForRailEnv(ObservationBuilder):
if possible_transitions[branch_direction]:
new_cell = self._new_position(agent.position, branch_direction)
branch_observation = self._explore_branch(handle, new_cell, branch_direction, root_observation, 1)
branch_observation, visited = self._explore_branch(handle, new_cell, branch_direction, root_observation,
1)
observation = observation + branch_observation
else:
num_cells_to_fill_in = 0
......@@ -228,7 +229,7 @@ class TreeObsForRailEnv(ObservationBuilder):
num_cells_to_fill_in += pow4
pow4 *= 4
observation = observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf] * num_cells_to_fill_in
return observation
return observation, visited
def _explore_branch(self, handle, position, direction, root_observation, depth):
"""
......@@ -236,7 +237,7 @@ class TreeObsForRailEnv(ObservationBuilder):
"""
# [Recursive branch opened]
if depth >= self.max_depth + 1:
return []
return [], []
# Continue along direction until next switch or
# until no transitions are possible along the current direction (i.e., dead-ends)
......@@ -377,22 +378,24 @@ class TreeObsForRailEnv(ObservationBuilder):
# Swap forward and back in case of dead-end, so that an agent can learn that going forward takes
# it back
new_cell = self._new_position(position, (branch_direction + 2) % 4)
branch_observation = self._explore_branch(handle,
new_cell,
(branch_direction + 2) % 4,
new_root_observation,
depth + 1)
branch_observation, branch_visited = self._explore_branch(handle,
new_cell,
(branch_direction + 2) % 4,
new_root_observation,
depth + 1)
observation = observation + branch_observation
if len(branch_visited) != 0:
visited.union(branch_visited)
elif last_isSwitch and possible_transitions[branch_direction]:
new_cell = self._new_position(position, branch_direction)
branch_observation = self._explore_branch(handle,
new_cell,
branch_direction,
new_root_observation,
depth + 1)
branch_observation, branch_visited = self._explore_branch(handle,
new_cell,
branch_direction,
new_root_observation,
depth + 1)
observation = observation + branch_observation
if len(branch_visited) != 0:
visited.union(branch_visited)
else:
num_cells_to_fill_in = 0
pow4 = 1
......@@ -401,7 +404,7 @@ class TreeObsForRailEnv(ObservationBuilder):
pow4 *= 4
observation = observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf] * num_cells_to_fill_in
return observation
return observation, visited
def util_print_obs_subtree(self, tree, num_features_per_node=5, prompt='', current_depth=0):
"""
......
......@@ -315,10 +315,11 @@ class RailEnv(Environment):
def _get_observations(self):
self.obs_dict = {}
self.debug_obs_dict = {}
# for handle in self.agents_handles:
for iAgent in range(self.get_num_agents()):
self.obs_dict[iAgent] = self.obs_builder.get(iAgent)
return self.obs_dict
self.obs_dict[iAgent], self.debug_obs_dict[iAgent] = self.obs_builder.get(iAgent)
return self.obs_dict, self.debug_obs_dict
def render(self):
# TODO:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment