diff --git a/examples/training_navigation.py b/examples/training_navigation.py index 23970a9059f15ce89b59b1cbcb9b862d248d1cd5..a7920c7da1c330494f2f37e298dd9f691378f115 100644 --- a/examples/training_navigation.py +++ b/examples/training_navigation.py @@ -116,7 +116,7 @@ def norm_obs_clip(obs, clip_min=-1, clip_max=1): for trials in range(1, n_trials + 1): # Reset environment - obs = env.reset() + obs, _ = env.reset() final_obs = obs.copy() final_obs_next = obs.copy() for a in range(env.get_num_agents()): @@ -148,7 +148,7 @@ for trials in range(1, n_trials + 1): action_dict.update({a: action}) # Environment step - next_obs, all_rewards, done, _ = env.step(action_dict) + (next_obs,_), all_rewards, done, _ = env.step(action_dict) for a in range(env.get_num_agents()): data, distance = env.obs_builder.split_tree(tree=np.array(next_obs[a]), num_features_per_node=5, diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 8e4be0ba49d8405aa420d2bc8e4854f6300e3837..4fad2c01c4d353b6ba458a94d6fb10085213eb06 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -219,7 +219,8 @@ class TreeObsForRailEnv(ObservationBuilder): if possible_transitions[branch_direction]: new_cell = self._new_position(agent.position, branch_direction) - branch_observation = self._explore_branch(handle, new_cell, branch_direction, root_observation, 1) + branch_observation, visited = self._explore_branch(handle, new_cell, branch_direction, root_observation, + 1) observation = observation + branch_observation else: num_cells_to_fill_in = 0 @@ -228,7 +229,7 @@ class TreeObsForRailEnv(ObservationBuilder): num_cells_to_fill_in += pow4 pow4 *= 4 observation = observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf] * num_cells_to_fill_in - return observation + return observation, visited def _explore_branch(self, handle, position, direction, root_observation, depth): """ @@ -236,7 +237,7 @@ class TreeObsForRailEnv(ObservationBuilder): """ # [Recursive branch opened] if depth >= self.max_depth + 1: - return [] + return [], [] # Continue along direction until next switch or # until no transitions are possible along the current direction (i.e., dead-ends) @@ -377,22 +378,24 @@ class TreeObsForRailEnv(ObservationBuilder): # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes # it back new_cell = self._new_position(position, (branch_direction + 2) % 4) - branch_observation = self._explore_branch(handle, - new_cell, - (branch_direction + 2) % 4, - new_root_observation, - depth + 1) + branch_observation, branch_visited = self._explore_branch(handle, + new_cell, + (branch_direction + 2) % 4, + new_root_observation, + depth + 1) observation = observation + branch_observation - + if len(branch_visited) != 0: + visited.union(branch_visited) elif last_isSwitch and possible_transitions[branch_direction]: new_cell = self._new_position(position, branch_direction) - branch_observation = self._explore_branch(handle, - new_cell, - branch_direction, - new_root_observation, - depth + 1) + branch_observation, branch_visited = self._explore_branch(handle, + new_cell, + branch_direction, + new_root_observation, + depth + 1) observation = observation + branch_observation - + if len(branch_visited) != 0: + visited.union(branch_visited) else: num_cells_to_fill_in = 0 pow4 = 1 @@ -401,7 +404,7 @@ class TreeObsForRailEnv(ObservationBuilder): pow4 *= 4 observation = observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf] * num_cells_to_fill_in - return observation + return observation, visited def util_print_obs_subtree(self, tree, num_features_per_node=5, prompt='', current_depth=0): """ diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index de3401a4bd98a1f114f25d629e7e0f13a7c0337c..749e5e01aa94e8ddb59d33d5787e3bd55d05f9eb 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -315,10 +315,11 @@ class RailEnv(Environment): def _get_observations(self): self.obs_dict = {} + self.debug_obs_dict = {} # for handle in self.agents_handles: for iAgent in range(self.get_num_agents()): - self.obs_dict[iAgent] = self.obs_builder.get(iAgent) - return self.obs_dict + self.obs_dict[iAgent], self.debug_obs_dict[iAgent] = self.obs_builder.get(iAgent) + return self.obs_dict, self.debug_obs_dict def render(self): # TODO: