From 725b98de64ae9abf10e9ca01844bf3ea16de3e23 Mon Sep 17 00:00:00 2001 From: Erik Nygren <erik.nygren@sbb.ch> Date: Thu, 13 Jun 2019 18:38:44 +0200 Subject: [PATCH] fixed bugs in tree observation --- examples/training_example.py | 1 - flatland/envs/observations.py | 26 ++++++++++++++++++-------- 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/examples/training_example.py b/examples/training_example.py index ad222cb..c785804 100644 --- a/examples/training_example.py +++ b/examples/training_example.py @@ -81,7 +81,6 @@ for trials in range(1, n_trials + 1): # reward and whether their are done next_obs, all_rewards, done, _ = env.step(action_dict) TreeObservation.util_print_obs_subtree(next_obs[0], num_features_per_node=8) - print(len(next_obs[0])) # Update replay buffer and train agent for a in range(env.get_num_agents()): agent.step((obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a])) diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index bab7e63..4fa0e09 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -212,11 +212,11 @@ class TreeObsForRailEnv(ObservationBuilder): Finally, each node information is composed of 5 floating point values: - #1: + #1: distance to own target if it lies on the explored branch - #2: 1 if a target of another agent is detected between the previous node and the current one. + #2: distance to a target of another agent, if one is detected between the previous node and the current one. - #3: 1 if another agent is detected between the previous node and the current one. + #3: distance to another agent, if one is detected between the previous node and the current one. 
#4: distance of agent to the current branch node @@ -255,6 +255,7 @@ class TreeObsForRailEnv(ObservationBuilder): agent = self.env.agents[handle] # TODO: handle being treated as index possible_transitions = self.env.rail.get_transitions((*agent.position, agent.direction)) num_transitions = np.count_nonzero(possible_transitions) + # Root node - current position observation = [0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], 0, 0, 0] @@ -304,7 +305,8 @@ class TreeObsForRailEnv(ObservationBuilder): last_isTarget = False visited = set() - + agent = self.env.agents[handle] + own_target_encountered = np.inf other_agent_encountered = np.inf other_target_encountered = np.inf other_agent_same_direction = 0 @@ -334,9 +336,14 @@ class TreeObsForRailEnv(ObservationBuilder): handle): potential_conflict = 1 - if position in self.location_has_target: + if position in self.location_has_target and position != agent.target: if num_steps < other_target_encountered: other_target_encountered = num_steps + + if position == agent.target: + if num_steps < own_target_encountered: + own_target_encountered = num_steps + # ############################# # ############################# @@ -428,7 +435,7 @@ class TreeObsForRailEnv(ObservationBuilder): """ if last_isTarget: - observation = [0, + observation = [own_target_encountered, other_target_encountered, other_agent_encountered, root_observation[3] + num_steps, @@ -439,7 +446,7 @@ class TreeObsForRailEnv(ObservationBuilder): ] elif last_isTerminal: - observation = [0, + observation = [own_target_encountered, other_target_encountered, other_agent_encountered, np.inf, @@ -449,7 +456,7 @@ class TreeObsForRailEnv(ObservationBuilder): potential_conflict ] else: - observation = [0, + observation = [own_target_encountered, other_target_encountered, other_agent_encountered, root_observation[3] + num_steps, @@ -549,8 +556,11 @@ class TreeObsForRailEnv(ObservationBuilder): pow4 *= 4 child_size = (len(tree) - 
num_features_per_node) // 4 tree_data = tree[:4].tolist() + # print("data",tree_data) distance_data = [tree[4]] + #print("distance",distance_data) agent_data = tree[5:num_features_per_node].tolist() + #print("agent_data",agent_data) for children in range(4): child_tree = tree[(num_features_per_node + children * child_size): (num_features_per_node + (children + 1) * child_size)] -- GitLab