Commit 725b98de authored by Erik Nygren's avatar Erik Nygren
Browse files

fixed bugs in tree observation

parent 4bef7d4f
Pipeline #1086 failed with stages
in 9 minutes and 24 seconds
......@@ -81,7 +81,6 @@ for trials in range(1, n_trials + 1):
# reward and whether their are done
next_obs, all_rewards, done, _ = env.step(action_dict)
TreeObservation.util_print_obs_subtree(next_obs[0], num_features_per_node=8)
print(len(next_obs[0]))
# Update replay buffer and train agent
for a in range(env.get_num_agents()):
agent.step((obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a]))
......
......@@ -212,11 +212,11 @@ class TreeObsForRailEnv(ObservationBuilder):
Finally, each node information is composed of 5 floating point values:
#1:
#1: 1 if own target lies on the explored branch
#2: 1 if a target of another agent is detected between the previous node and the current one.
#2: distance toa target of another agent is detected between the previous node and the current one.
#3: 1 if another agent is detected between the previous node and the current one.
#3: distance to another agent is detected between the previous node and the current one.
#4: distance of agent to the current branch node
......@@ -255,6 +255,7 @@ class TreeObsForRailEnv(ObservationBuilder):
agent = self.env.agents[handle] # TODO: handle being treated as index
possible_transitions = self.env.rail.get_transitions((*agent.position, agent.direction))
num_transitions = np.count_nonzero(possible_transitions)
# Root node - current position
observation = [0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], 0, 0, 0]
......@@ -304,7 +305,8 @@ class TreeObsForRailEnv(ObservationBuilder):
last_isTarget = False
visited = set()
agent = self.env.agents[handle]
own_target_encountered = np.inf
other_agent_encountered = np.inf
other_target_encountered = np.inf
other_agent_same_direction = 0
......@@ -334,9 +336,14 @@ class TreeObsForRailEnv(ObservationBuilder):
handle):
potential_conflict = 1
if position in self.location_has_target:
if position in self.location_has_target and position != agent.target:
if num_steps < other_target_encountered:
other_target_encountered = num_steps
if position == agent.target:
if num_steps < own_target_encountered:
own_target_encountered = num_steps
# #############################
# #############################
......@@ -428,7 +435,7 @@ class TreeObsForRailEnv(ObservationBuilder):
"""
if last_isTarget:
observation = [0,
observation = [own_target_encountered,
other_target_encountered,
other_agent_encountered,
root_observation[3] + num_steps,
......@@ -439,7 +446,7 @@ class TreeObsForRailEnv(ObservationBuilder):
]
elif last_isTerminal:
observation = [0,
observation = [own_target_encountered,
other_target_encountered,
other_agent_encountered,
np.inf,
......@@ -449,7 +456,7 @@ class TreeObsForRailEnv(ObservationBuilder):
potential_conflict
]
else:
observation = [0,
observation = [own_target_encountered,
other_target_encountered,
other_agent_encountered,
root_observation[3] + num_steps,
......@@ -549,8 +556,11 @@ class TreeObsForRailEnv(ObservationBuilder):
pow4 *= 4
child_size = (len(tree) - num_features_per_node) // 4
tree_data = tree[:4].tolist()
# print("data",tree_data)
distance_data = [tree[4]]
#print("distance",distance_data)
agent_data = tree[5:num_features_per_node].tolist()
#print("agent_data",agent_data)
for children in range(4):
child_tree = tree[(num_features_per_node + children * child_size):
(num_features_per_node + (children + 1) * child_size)]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment