diff --git a/examples/training_navigation.py b/examples/training_navigation.py index 4b73decfe9dda44d41fc85c9ffa93f26a86dd8e0..feb479646fb6d7283e94f87881cc4a82afb215b3 100644 --- a/examples/training_navigation.py +++ b/examples/training_navigation.py @@ -23,7 +23,6 @@ transition_probability = [15, # empty cell - Case 0 1, # Case 1c (9) - simple turn left 1] # Case 2b (10) - simple switch mirrored - # Example generate a random rail """ env = RailEnv(width=10, @@ -31,11 +30,10 @@ env = RailEnv(width=10, rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability), number_of_agents=1) """ -env = RailEnv(width=15, - height=15, +env = RailEnv(width=10, + height=10, rail_generator=complex_rail_generator(nr_start_goal=3, min_dist=5, max_dist=99999, seed=0), number_of_agents=3) - """ env = RailEnv(width=20, height=20, @@ -61,7 +59,7 @@ scores = [] dones_list = [] action_prob = [0] * 4 agent = Agent(state_size, action_size, "FC", 0) -#agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint15000.pth')) +#agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint1300.pth')) demo = False @@ -94,15 +92,34 @@ def min_lt(seq, val): return min +def norm_obs_clip(obs, clip_min=-1, clip_max=1): + """ + This function returns the difference between min and max value of an observation + :param obs: Observation that should be normalized + :param clip_min: min value where observation will be clipped + :param clip_max: max value where observation will be clipped + :return: returnes normalized and clipped observatoin + """ + max_obs = max(1, max_lt(obs, 1000)) + min_obs = max(0, min_lt(obs, 0)) + if max_obs == min_obs: + return np.clip(np.array(obs)/ max_obs, clip_min, clip_max) + norm = np.abs(max_obs - min_obs) + if norm == 0: + norm = 1. + return np.clip((np.array(obs)-min_obs)/ norm, clip_min, clip_max) + + for trials in range(1, n_trials + 1): # Reset environment obs = env.reset() final_obs = obs.copy() for a in range(env.get_num_agents()): - norm = max(1, max_lt(obs[a], np.inf)) - obs[a] = np.clip(np.array(obs[a]) / norm, -1, 1) - + data, distance = env.obs_builder.split_tree(tree=np.array(obs[a]), num_features_per_node=5, current_depth=0) + data = norm_obs_clip(data) + distance = norm_obs_clip(distance) + obs[a] = np.concatenate((data, distance)) # env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5) score = 0 @@ -119,13 +136,17 @@ for trials in range(1, n_trials + 1): action = agent.act(np.array(obs[a]), eps=eps) action_prob[action] += 1 action_dict.update({a: action}) - #env.obs_builder.util_print_obs_subtree(tree=obs[a], num_features_per_node=5) + # Environment step next_obs, all_rewards, done, _ = env.step(action_dict) for a in range(env.get_num_agents()): - norm = max(1, max_lt(next_obs[a], np.inf)) - next_obs[a] = np.clip(np.array(next_obs[a]) / norm, -1, 1) + data, distance = env.obs_builder.split_tree(tree=np.array(next_obs[a]), num_features_per_node=5, + current_depth=0) + data = norm_obs_clip(data) + distance = norm_obs_clip(distance) + next_obs[a] = np.concatenate((data, distance)) + # Update replay buffer and train agent for a in range(env.get_num_agents()): if done[a]: @@ -135,7 +156,6 @@ for trials in range(1, n_trials + 1): agent.step(obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a]) score += all_rewards[a] - obs = next_obs.copy() if done['__all__']: env_done = 1 diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py index a357f428371e90028156384eadd451a7a565d394..547f349489812e096ac2f6b64f9190ad264536a5 100644 --- a/flatland/core/env_observation_builder.py +++ b/flatland/core/env_observation_builder.py @@ -140,7 +140,7 @@ class TreeObsForRailEnv(ObservationBuilder): new_cell = self._new_position(position, neigh_direction) if new_cell[0] >= 0 and new_cell[0] < self.env.height and \ - new_cell[1] >= 0 and new_cell[1] < self.env.width: + new_cell[1] >= 0 and new_cell[1] < self.env.width: desired_movement_from_new_cell = (neigh_direction + 2) % 4 @@ -270,7 +270,6 @@ class TreeObsForRailEnv(ObservationBuilder): num_cells_to_fill_in += pow4 pow4 *= 4 observation = observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf] * num_cells_to_fill_in - return observation def _explore_branch(self, handle, position, direction, root_observation, depth): @@ -292,21 +291,26 @@ class TreeObsForRailEnv(ObservationBuilder): visited = set() - other_agent_encountered = False - other_target_encountered = False + # other_agent_encountered = False + # other_target_encountered = False + other_agent_encountered = np.inf + other_target_encountered = np.inf + num_steps = 1 while exploring: # ############################# # ############################# # Modify here to compute any useful data required to build the end node's features. This code is called # for each cell visited between the previous branching node and the next switch / target / dead-end. - if position in self.location_has_agent: - other_agent_encountered = True + # other_agent_encountered = True + if num_steps < other_agent_encountered: + other_agent_encountered = num_steps if position in self.location_has_target: - other_target_encountered = True - + # other_target_encountered = True + if num_steps < other_target_encountered: + other_target_encountered = num_steps # ############################# # ############################# @@ -361,7 +365,7 @@ class TreeObsForRailEnv(ObservationBuilder): # ############################# # ############################# # Modify here to append new / different features for each visited cell! - + """ if last_isTarget: observation = [0, 1 if other_target_encountered else 0, @@ -381,12 +385,30 @@ class TreeObsForRailEnv(ObservationBuilder): 1 if other_agent_encountered else 0, root_observation[3] + num_steps, self.distance_map[handle, position[0], position[1], direction]] + """ + if last_isTarget: + observation = [0, + other_target_encountered, + other_agent_encountered, + root_observation[3] + num_steps, + 0] + elif last_isTerminal: + observation = [0, + other_target_encountered, + other_agent_encountered, + np.inf, + np.inf] + else: + observation = [0, + other_target_encountered, + other_agent_encountered, + root_observation[3] + num_steps, + self.distance_map[handle, position[0], position[1], direction]] # ############################# # ############################# new_root_observation = observation[:] - # Start from the current orientation, and see which transitions are available; # organize them as [left, forward, right, back], relative to the current orientation # Get the possible transitions @@ -450,6 +472,40 @@ class TreeObsForRailEnv(ObservationBuilder): prompt=prompt_[children], current_depth=current_depth + 1) + def split_tree(self, tree, num_features_per_node=5, current_depth=0): + """ + + :param tree: + :param num_features_per_node: + :param prompt: + :param current_depth: + :return: + """ + + if len(tree) < num_features_per_node: + return [], [] + + depth = 0 + tmp = len(tree) / num_features_per_node - 1 + pow4 = 4 + while tmp > 0: + tmp -= pow4 + depth += 1 + pow4 *= 4 + child_size = (len(tree) - num_features_per_node) // 4 + tree_data = tree[0:num_features_per_node - 1].tolist() + distance_data = [tree[num_features_per_node - 1]] + for children in range(4): + child_tree = tree[(num_features_per_node + children * child_size): + (num_features_per_node + (children + 1) * child_size)] + tmp_tree_data, tmp_distance_data = self.split_tree(child_tree, + num_features_per_node, + current_depth=current_depth + 1) + if len(tmp_tree_data) > 0: + tree_data.extend(tmp_tree_data) + distance_data.extend(tmp_distance_data) + return tree_data, distance_data + class GlobalObsForRailEnv(ObservationBuilder): """ @@ -490,7 +546,7 @@ class GlobalObsForRailEnv(ObservationBuilder): obs[1][agent.target] += 1 for i in range(len(agents)): - if i != handle: # TODO: handle used as index...? + if i != handle: # TODO: handle used as index...? agent2 = agents[i] obs[3][agent2.position] += 1 obs[2][agent2.target] += 1 diff --git a/notebooks/Editor2.ipynb b/notebooks/Editor2.ipynb index 078afa943f79ee5a999c3e5f41814f9ec177848d..f2481d086d67376ec6018b758e1a88edd4222183 100644 --- a/notebooks/Editor2.ipynb +++ b/notebooks/Editor2.ipynb @@ -160,9 +160,9 @@ "metadata": { "hide_input": false, "kernelspec": { - "display_name": "ve367", + "display_name": "Python 3", "language": "python", - "name": "ve367" + "name": "python3" }, "language_info": { "codemirror_mode": { @@ -174,7 +174,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.7" + "version": "3.6.5" }, "latex_envs": { "LaTeX_envs_menu_present": true,