diff --git a/examples/training_navigation.py b/examples/training_navigation.py index 4a21ad11417317ab50337116cccb772ac220d7cc..ec19ff20541e94b004b6c788d8e5df707019b2c0 100644 --- a/examples/training_navigation.py +++ b/examples/training_navigation.py @@ -24,18 +24,17 @@ transition_probability = [15, # empty cell - Case 0 1] # Case 2b (10) - simple switch mirrored # Example generate a random rail - +""" env = RailEnv(width=10, height=10, rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability), - number_of_agents=1) + number_of_agents=5) """ env = RailEnv(width=15, height=15, - rail_generator=complex_rail_generator(nr_start_goal=10, min_dist=5, max_dist=99999, seed=0), + rail_generator=complex_rail_generator(nr_start_goal=3, min_dist=5, max_dist=99999, seed=0), number_of_agents=3) """ -""" env = RailEnv(width=20, height=20, rail_generator=rail_from_list_of_saved_GridTransitionMap_generator( @@ -117,12 +116,14 @@ def norm_obs_clip(obs, clip_min=-1, clip_max=1): for trials in range(1, n_trials + 1): # Reset environment - obs, dev_obs = env.reset() - env.dev_obs_dict = dev_obs + obs = env.reset() + final_obs = obs.copy() final_obs_next = obs.copy() + for a in range(env.get_num_agents()): data, distance = env.obs_builder.split_tree(tree=np.array(obs[a]), num_features_per_node=5, current_depth=0) + data = norm_obs_clip(data) distance = norm_obs_clip(distance) obs[a] = np.concatenate((data, distance)) @@ -150,8 +151,7 @@ for trials in range(1, n_trials + 1): action_dict.update({a: action}) # Environment step - (next_obs, dev_obs), all_rewards, done, _ = env.step(action_dict) - env.dev_obs_dict = dev_obs + next_obs, all_rewards, done, _ = env.step(action_dict) for a in range(env.get_num_agents()): data, distance = env.obs_builder.split_tree(tree=np.array(next_obs[a]), num_features_per_node=5, current_depth=0) diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index dba5faf956f720bcf00b39f8173676f6902c9e5c..b097b5df1e65a249d850e55e164ea5e8124767f9 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -229,7 +229,8 @@ class TreeObsForRailEnv(ObservationBuilder): num_cells_to_fill_in += pow4 pow4 *= 4 observation = observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf] * num_cells_to_fill_in - return observation, visited + self.env.dev_obs_dict[handle] = visited + return observation def _explore_branch(self, handle, position, direction, root_observation, depth): """ diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 640eb5c537c303d11e3e31aac9b3e098af1149ed..118ebf4d3ea122519a09c8b7b5a00c53964063a1 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -318,8 +318,8 @@ class RailEnv(Environment): self.debug_obs_dict = {} # for handle in self.agents_handles: for iAgent in range(self.get_num_agents()): - self.obs_dict[iAgent], self.debug_obs_dict[iAgent] = self.obs_builder.get(iAgent) - return self.obs_dict, self.debug_obs_dict + self.obs_dict[iAgent] = self.obs_builder.get(iAgent) + return self.obs_dict def render(self): # TODO: