diff --git a/examples/training_example.py b/examples/training_example.py index 1342107767a599cb1440fd08f506903a5059492e..ad222cb0889a92023c726f68ce8ecff0d2dd6cee 100644 --- a/examples/training_example.py +++ b/examples/training_example.py @@ -1,6 +1,8 @@ import numpy as np from flatland.envs.generators import complex_rail_generator +from flatland.envs.observations import TreeObsForRailEnv +from flatland.envs.predictions import DummyPredictorForRailEnv from flatland.envs.rail_env import RailEnv np.random.seed(1) @@ -8,10 +10,13 @@ np.random.seed(1) # Use the complex_rail_generator to generate feasible network configurations with corresponding tasks # Training on simple small tasks is the best way to get familiar with the environment # -env = RailEnv(width=15, - height=15, - rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=10, min_dist=10, max_dist=99999, seed=0), - number_of_agents=5) + +TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv()) +env = RailEnv(width=20, + height=20, + rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0), + obs_builder_object=TreeObservation, + number_of_agents=2) # Import your own Agent or use RLlib to train agents on Flatland @@ -56,6 +61,7 @@ n_trials = 5 # Empty dictionary for all agent action action_dict = dict() print("Starting Training...") + for trials in range(1, n_trials + 1): # Reset environment and get initial observations for all agents @@ -74,7 +80,8 @@ for trials in range(1, n_trials + 1): # Environment step which returns the observations for all agents, their corresponding # reward and whether their are done next_obs, all_rewards, done, _ = env.step(action_dict) - + TreeObservation.util_print_obs_subtree(next_obs[0], num_features_per_node=8) + print(len(next_obs[0])) # Update replay buffer and train agent for a in range(env.get_num_agents()): agent.step((obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a])) diff --git a/flatland/envs/env_utils.py b/flatland/envs/env_utils.py index 35e6560818acb0d5d7b2bf56bfd8feb33f26c87a..c9595b7693497daa7db110b8fc8b4ae040d39cc9 100644 --- a/flatland/envs/env_utils.py +++ b/flatland/envs/env_utils.py @@ -89,10 +89,12 @@ def coordinate_to_position(width, coords): :param coords: :return: """ - position = [] + position = np.empty(len(coords), dtype=int) + idx = 0 for t in coords: - position.append((t[1] * width + t[0])) - return np.asarray(position).flatten() + position[idx] = int(t[1] * width + t[0]) + idx += 1 + return position class AStarNode(): diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 7bc2f994f8ae05f86e569cdb64bae08f6dc877f2..e0f0c7f1b62bcd81ed2db5eb8f7bbe67bcee8bcf 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -27,7 +27,7 @@ class TreeObsForRailEnv(ObservationBuilder): for i in range(self.max_depth + 1): size += pow4 pow4 *= 4 - self.observation_dim = 7 + self.observation_dim = 8 self.observation_space = [size * self.observation_dim] self.location_has_agent = {} self.location_has_agent_direction = {} @@ -187,10 +187,7 @@ class TreeObsForRailEnv(ObservationBuilder): dir_list.append(self.predictions[a][t][3]) self.predicted_pos.update({t: coordinate_to_position(self.env.width, pos_list)}) self.predicted_dir.update({t: dir_list}) - - pred_pos = np.concatenate([[x[:, 1:3]] for x in list(self.predictions.values())], axis=0) - pred_pos = list(map(list, zip(*pred_pos))) - + self.max_prediction_depth = len(self.predicted_pos) observations = {} for h in handles: observations[h] = self.get(h) @@ -256,7 +253,7 @@ class TreeObsForRailEnv(ObservationBuilder): possible_transitions = self.env.rail.get_transitions((*agent.position, agent.direction)) num_transitions = np.count_nonzero(possible_transitions) # Root node - current position - observation = [0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], 0, 0] + observation = [0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], 0, 0, 0] root_observation = observation[:] visited = set() @@ -309,7 +306,7 @@ class TreeObsForRailEnv(ObservationBuilder): other_target_encountered = np.inf other_agent_same_direction = 0 other_agent_opposite_direction = 0 - + potential_conflict = 0 num_steps = 1 while exploring: # ############################# @@ -329,6 +326,10 @@ class TreeObsForRailEnv(ObservationBuilder): other_agent_opposite_direction += 1 # Register possible conflict + if self.predictor and num_steps < self.max_prediction_depth: + if coordinate_to_position(self.env.width, [position]) in np.delete(self.predicted_pos[num_steps], + handle): + potential_conflict = 1 if position in self.location_has_target: if num_steps < other_target_encountered: @@ -430,7 +431,8 @@ class TreeObsForRailEnv(ObservationBuilder): root_observation[3] + num_steps, 0, other_agent_same_direction, - other_agent_opposite_direction + other_agent_opposite_direction, + potential_conflict ] elif last_isTerminal: @@ -440,7 +442,8 @@ class TreeObsForRailEnv(ObservationBuilder): np.inf, np.inf, other_agent_same_direction, - other_agent_opposite_direction + other_agent_opposite_direction, + potential_conflict ] else: observation = [0, @@ -449,7 +452,8 @@ class TreeObsForRailEnv(ObservationBuilder): root_observation[3] + num_steps, self.distance_map[handle, position[0], position[1], direction], other_agent_same_direction, - other_agent_opposite_direction + other_agent_opposite_direction, + potential_conflict ] # ############################# # ############################# @@ -493,7 +497,7 @@ class TreeObsForRailEnv(ObservationBuilder): return observation, visited - def util_print_obs_subtree(self, tree, num_features_per_node=5, prompt='', current_depth=0): + def util_print_obs_subtree(self, tree, num_features_per_node=8, prompt='', current_depth=0): """ Utility function to pretty-print tree observations returned by this object. """ @@ -520,7 +524,7 @@ class TreeObsForRailEnv(ObservationBuilder): prompt=prompt_[children], current_depth=current_depth + 1) - def split_tree(self, tree, num_features_per_node=7, current_depth=0): + def split_tree(self, tree, num_features_per_node=8, current_depth=0): """ :param tree: @@ -541,9 +545,10 @@ class TreeObsForRailEnv(ObservationBuilder): depth += 1 pow4 *= 4 child_size = (len(tree) - num_features_per_node) // 4 - tree_data = tree[0:num_features_per_node - 3].tolist() - distance_data = [tree[num_features_per_node - 3]] - agent_data = tree[-2:].tolist() + tree_data = tree[0:4].tolist() + distance_data = [tree[4]] + agent_data = tree[-3:].tolist() + for children in range(4): child_tree = tree[(num_features_per_node + children * child_size): (num_features_per_node + (children + 1) * child_size)]