From 20d93455bcd09aef130d5deaa3900940c4b2b84b Mon Sep 17 00:00:00 2001 From: Erik Nygren <erik.nygren@sbb.ch> Date: Thu, 13 Jun 2019 16:32:50 +0200 Subject: [PATCH] Added potential conflict to tree observation as an 8th feature. ATTENTION: this means that the observation space dimension has increased! I will still check that this is handled correctly everywhere, but it looks good. --- examples/training_example.py | 17 ++++++++++++----- flatland/envs/env_utils.py | 8 +++++--- flatland/envs/observations.py | 35 ++++++++++++++++++++--------------- 3 files changed, 37 insertions(+), 23 deletions(-) diff --git a/examples/training_example.py b/examples/training_example.py index 1342107..ad222cb 100644 --- a/examples/training_example.py +++ b/examples/training_example.py @@ -1,6 +1,8 @@ import numpy as np from flatland.envs.generators import complex_rail_generator +from flatland.envs.observations import TreeObsForRailEnv +from flatland.envs.predictions import DummyPredictorForRailEnv from flatland.envs.rail_env import RailEnv np.random.seed(1) @@ -8,10 +10,13 @@ np.random.seed(1) # Use the complex_rail_generator to generate feasible network configurations with corresponding tasks # Training on simple small tasks is the best way to get familiar with the environment # -env = RailEnv(width=15, - height=15, - rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=10, min_dist=10, max_dist=99999, seed=0), - number_of_agents=5) + +TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv()) +env = RailEnv(width=20, + height=20, + rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0), + obs_builder_object=TreeObservation, + number_of_agents=2) # Import your own Agent or use RLlib to train agents on Flatland @@ -56,6 +61,7 @@ n_trials = 5 # Empty dictionary for all agent action action_dict = dict() print("Starting Training...") + for trials in range(1, n_trials + 1): # Reset environment and get initial 
observations for all agents @@ -74,7 +80,8 @@ for trials in range(1, n_trials + 1): # Environment step which returns the observations for all agents, their corresponding # reward and whether their are done next_obs, all_rewards, done, _ = env.step(action_dict) - + TreeObservation.util_print_obs_subtree(next_obs[0], num_features_per_node=8) + print(len(next_obs[0])) # Update replay buffer and train agent for a in range(env.get_num_agents()): agent.step((obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a])) diff --git a/flatland/envs/env_utils.py b/flatland/envs/env_utils.py index 35e6560..c9595b7 100644 --- a/flatland/envs/env_utils.py +++ b/flatland/envs/env_utils.py @@ -89,10 +89,12 @@ def coordinate_to_position(width, coords): :param coords: :return: """ - position = [] + position = np.empty(len(coords), dtype=int) + idx = 0 for t in coords: - position.append((t[1] * width + t[0])) - return np.asarray(position).flatten() + position[idx] = int(t[1] * width + t[0]) + idx += 1 + return position class AStarNode(): diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 7bc2f99..e0f0c7f 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -27,7 +27,7 @@ class TreeObsForRailEnv(ObservationBuilder): for i in range(self.max_depth + 1): size += pow4 pow4 *= 4 - self.observation_dim = 7 + self.observation_dim = 8 self.observation_space = [size * self.observation_dim] self.location_has_agent = {} self.location_has_agent_direction = {} @@ -187,10 +187,7 @@ class TreeObsForRailEnv(ObservationBuilder): dir_list.append(self.predictions[a][t][3]) self.predicted_pos.update({t: coordinate_to_position(self.env.width, pos_list)}) self.predicted_dir.update({t: dir_list}) - - pred_pos = np.concatenate([[x[:, 1:3]] for x in list(self.predictions.values())], axis=0) - pred_pos = list(map(list, zip(*pred_pos))) - + self.max_prediction_depth = len(self.predicted_pos) observations = {} for h in handles: observations[h] = 
self.get(h) @@ -256,7 +253,7 @@ class TreeObsForRailEnv(ObservationBuilder): possible_transitions = self.env.rail.get_transitions((*agent.position, agent.direction)) num_transitions = np.count_nonzero(possible_transitions) # Root node - current position - observation = [0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], 0, 0] + observation = [0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], 0, 0, 0] root_observation = observation[:] visited = set() @@ -309,7 +306,7 @@ class TreeObsForRailEnv(ObservationBuilder): other_target_encountered = np.inf other_agent_same_direction = 0 other_agent_opposite_direction = 0 - + potential_conflict = 0 num_steps = 1 while exploring: # ############################# @@ -329,6 +326,10 @@ class TreeObsForRailEnv(ObservationBuilder): other_agent_opposite_direction += 1 # Register possible conflict + if self.predictor and num_steps < self.max_prediction_depth: + if coordinate_to_position(self.env.width, [position]) in np.delete(self.predicted_pos[num_steps], + handle): + potential_conflict = 1 if position in self.location_has_target: if num_steps < other_target_encountered: @@ -430,7 +431,8 @@ class TreeObsForRailEnv(ObservationBuilder): root_observation[3] + num_steps, 0, other_agent_same_direction, - other_agent_opposite_direction + other_agent_opposite_direction, + potential_conflict ] elif last_isTerminal: @@ -440,7 +442,8 @@ class TreeObsForRailEnv(ObservationBuilder): np.inf, np.inf, other_agent_same_direction, - other_agent_opposite_direction + other_agent_opposite_direction, + potential_conflict ] else: observation = [0, @@ -449,7 +452,8 @@ class TreeObsForRailEnv(ObservationBuilder): root_observation[3] + num_steps, self.distance_map[handle, position[0], position[1], direction], other_agent_same_direction, - other_agent_opposite_direction + other_agent_opposite_direction, + potential_conflict ] # ############################# # ############################# @@ -493,7 +497,7 @@ 
class TreeObsForRailEnv(ObservationBuilder): return observation, visited - def util_print_obs_subtree(self, tree, num_features_per_node=5, prompt='', current_depth=0): + def util_print_obs_subtree(self, tree, num_features_per_node=8, prompt='', current_depth=0): """ Utility function to pretty-print tree observations returned by this object. """ @@ -520,7 +524,7 @@ class TreeObsForRailEnv(ObservationBuilder): prompt=prompt_[children], current_depth=current_depth + 1) - def split_tree(self, tree, num_features_per_node=7, current_depth=0): + def split_tree(self, tree, num_features_per_node=8, current_depth=0): """ :param tree: @@ -541,9 +545,10 @@ class TreeObsForRailEnv(ObservationBuilder): depth += 1 pow4 *= 4 child_size = (len(tree) - num_features_per_node) // 4 - tree_data = tree[0:num_features_per_node - 3].tolist() - distance_data = [tree[num_features_per_node - 3]] - agent_data = tree[-2:].tolist() + tree_data = tree[0:4].tolist() + distance_data = [tree[4]] + agent_data = tree[-3:].tolist() + for children in range(4): child_tree = tree[(num_features_per_node + children * child_size): (num_features_per_node + (children + 1) * child_size)] -- GitLab