From 53c34f901e571f13cb00843e876f2af46f2031bc Mon Sep 17 00:00:00 2001 From: Erik Nygren <erik.nygren@sbb.ch> Date: Thu, 16 May 2019 18:25:44 +0200 Subject: [PATCH] fixed bugs in get() function in observation_builder --- examples/training_navigation.py | 16 ++++++++-------- flatland/envs/observations.py | 3 ++- flatland/envs/rail_env.py | 4 ++-- 3 files changed, 12 insertions(+), 11 deletions(-) diff --git a/examples/training_navigation.py b/examples/training_navigation.py index 4a21ad11..ec19ff20 100644 --- a/examples/training_navigation.py +++ b/examples/training_navigation.py @@ -24,18 +24,17 @@ transition_probability = [15, # empty cell - Case 0 1] # Case 2b (10) - simple switch mirrored # Example generate a random rail - +""" env = RailEnv(width=10, height=10, rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability), - number_of_agents=1) + number_of_agents=5) """ env = RailEnv(width=15, height=15, - rail_generator=complex_rail_generator(nr_start_goal=10, min_dist=5, max_dist=99999, seed=0), + rail_generator=complex_rail_generator(nr_start_goal=3, min_dist=5, max_dist=99999, seed=0), number_of_agents=3) """ -""" env = RailEnv(width=20, height=20, rail_generator=rail_from_list_of_saved_GridTransitionMap_generator( @@ -117,12 +116,14 @@ def norm_obs_clip(obs, clip_min=-1, clip_max=1): for trials in range(1, n_trials + 1): # Reset environment - obs, dev_obs = env.reset() - env.dev_obs_dict = dev_obs + obs = env.reset() + final_obs = obs.copy() final_obs_next = obs.copy() + for a in range(env.get_num_agents()): data, distance = env.obs_builder.split_tree(tree=np.array(obs[a]), num_features_per_node=5, current_depth=0) + data = norm_obs_clip(data) distance = norm_obs_clip(distance) obs[a] = np.concatenate((data, distance)) @@ -150,8 +151,7 @@ for trials in range(1, n_trials + 1): action_dict.update({a: action}) # Environment step - (next_obs, dev_obs), all_rewards, done, _ = env.step(action_dict) - env.dev_obs_dict = dev_obs + 
next_obs, all_rewards, done, _ = env.step(action_dict) for a in range(env.get_num_agents()): data, distance = env.obs_builder.split_tree(tree=np.array(next_obs[a]), num_features_per_node=5, current_depth=0) diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index dba5faf9..b097b5df 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -229,7 +229,8 @@ class TreeObsForRailEnv(ObservationBuilder): num_cells_to_fill_in += pow4 pow4 *= 4 observation = observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf] * num_cells_to_fill_in - return observation, visited + self.env.dev_obs_dict[handle] = visited + return observation def _explore_branch(self, handle, position, direction, root_observation, depth): """ diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 640eb5c5..118ebf4d 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -318,8 +318,8 @@ class RailEnv(Environment): self.debug_obs_dict = {} # for handle in self.agents_handles: for iAgent in range(self.get_num_agents()): - self.obs_dict[iAgent], self.debug_obs_dict[iAgent] = self.obs_builder.get(iAgent) - return self.obs_dict, self.debug_obs_dict + self.obs_dict[iAgent] = self.obs_builder.get(iAgent) + return self.obs_dict def render(self): # TODO: -- GitLab