Skip to content
Snippets Groups Projects
Commit 725b98de authored by Erik Nygren's avatar Erik Nygren
Browse files

fixed bugs in tree observation

parent 4bef7d4f
No related branches found
No related tags found
No related merge requests found
...@@ -81,7 +81,6 @@ for trials in range(1, n_trials + 1): ...@@ -81,7 +81,6 @@ for trials in range(1, n_trials + 1):
# reward and whether their are done # reward and whether their are done
next_obs, all_rewards, done, _ = env.step(action_dict) next_obs, all_rewards, done, _ = env.step(action_dict)
TreeObservation.util_print_obs_subtree(next_obs[0], num_features_per_node=8) TreeObservation.util_print_obs_subtree(next_obs[0], num_features_per_node=8)
print(len(next_obs[0]))
# Update replay buffer and train agent # Update replay buffer and train agent
for a in range(env.get_num_agents()): for a in range(env.get_num_agents()):
agent.step((obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a])) agent.step((obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a]))
......
...@@ -212,11 +212,11 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -212,11 +212,11 @@ class TreeObsForRailEnv(ObservationBuilder):
Finally, each node information is composed of 5 floating point values: Finally, each node information is composed of 5 floating point values:
#1: #1: 1 if own target lies on the explored branch
#2: 1 if a target of another agent is detected between the previous node and the current one. #2: distance toa target of another agent is detected between the previous node and the current one.
#3: 1 if another agent is detected between the previous node and the current one. #3: distance to another agent is detected between the previous node and the current one.
#4: distance of agent to the current branch node #4: distance of agent to the current branch node
...@@ -255,6 +255,7 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -255,6 +255,7 @@ class TreeObsForRailEnv(ObservationBuilder):
agent = self.env.agents[handle] # TODO: handle being treated as index agent = self.env.agents[handle] # TODO: handle being treated as index
possible_transitions = self.env.rail.get_transitions((*agent.position, agent.direction)) possible_transitions = self.env.rail.get_transitions((*agent.position, agent.direction))
num_transitions = np.count_nonzero(possible_transitions) num_transitions = np.count_nonzero(possible_transitions)
# Root node - current position # Root node - current position
observation = [0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], 0, 0, 0] observation = [0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], 0, 0, 0]
...@@ -304,7 +305,8 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -304,7 +305,8 @@ class TreeObsForRailEnv(ObservationBuilder):
last_isTarget = False last_isTarget = False
visited = set() visited = set()
agent = self.env.agents[handle]
own_target_encountered = np.inf
other_agent_encountered = np.inf other_agent_encountered = np.inf
other_target_encountered = np.inf other_target_encountered = np.inf
other_agent_same_direction = 0 other_agent_same_direction = 0
...@@ -334,9 +336,14 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -334,9 +336,14 @@ class TreeObsForRailEnv(ObservationBuilder):
handle): handle):
potential_conflict = 1 potential_conflict = 1
if position in self.location_has_target: if position in self.location_has_target and position != agent.target:
if num_steps < other_target_encountered: if num_steps < other_target_encountered:
other_target_encountered = num_steps other_target_encountered = num_steps
if position == agent.target:
if num_steps < own_target_encountered:
own_target_encountered = num_steps
# ############################# # #############################
# ############################# # #############################
...@@ -428,7 +435,7 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -428,7 +435,7 @@ class TreeObsForRailEnv(ObservationBuilder):
""" """
if last_isTarget: if last_isTarget:
observation = [0, observation = [own_target_encountered,
other_target_encountered, other_target_encountered,
other_agent_encountered, other_agent_encountered,
root_observation[3] + num_steps, root_observation[3] + num_steps,
...@@ -439,7 +446,7 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -439,7 +446,7 @@ class TreeObsForRailEnv(ObservationBuilder):
] ]
elif last_isTerminal: elif last_isTerminal:
observation = [0, observation = [own_target_encountered,
other_target_encountered, other_target_encountered,
other_agent_encountered, other_agent_encountered,
np.inf, np.inf,
...@@ -449,7 +456,7 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -449,7 +456,7 @@ class TreeObsForRailEnv(ObservationBuilder):
potential_conflict potential_conflict
] ]
else: else:
observation = [0, observation = [own_target_encountered,
other_target_encountered, other_target_encountered,
other_agent_encountered, other_agent_encountered,
root_observation[3] + num_steps, root_observation[3] + num_steps,
...@@ -549,8 +556,11 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -549,8 +556,11 @@ class TreeObsForRailEnv(ObservationBuilder):
pow4 *= 4 pow4 *= 4
child_size = (len(tree) - num_features_per_node) // 4 child_size = (len(tree) - num_features_per_node) // 4
tree_data = tree[:4].tolist() tree_data = tree[:4].tolist()
# print("data",tree_data)
distance_data = [tree[4]] distance_data = [tree[4]]
#print("distance",distance_data)
agent_data = tree[5:num_features_per_node].tolist() agent_data = tree[5:num_features_per_node].tolist()
#print("agent_data",agent_data)
for children in range(4): for children in range(4):
child_tree = tree[(num_features_per_node + children * child_size): child_tree = tree[(num_features_per_node + children * child_size):
(num_features_per_node + (children + 1) * child_size)] (num_features_per_node + (children + 1) * child_size)]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment