diff --git a/examples/temporary_example.py b/examples/temporary_example.py index 1d13d2ce865658d6f421c7e3feef9414ff14ffde..662bfe94d0cd72ff685d5546efbc70a3d641c057 100644 --- a/examples/temporary_example.py +++ b/examples/temporary_example.py @@ -9,24 +9,14 @@ from flatland.utils.rendertools import * random.seed(0) np.random.seed(0) -""" -transition_probability = [1.0, # empty cell - Case 0 - 3.0, # Case 1 - straight - 1.0, # Case 2 - simple switch - 3.0, # Case 3 - diamond drossing - 2.0, # Case 4 - single slip - 1.0, # Case 5 - double slip - 1.0, # Case 6 - symmetrical - 1.0] # Case 7 - dead end -""" transition_probability = [1.0, # empty cell - Case 0 1.0, # Case 1 - straight - 0.5, # Case 2 - simple switch - 0.2, # Case 3 - diamond drossing + 1.0, # Case 2 - simple switch + 0.3, # Case 3 - diamond drossing 0.5, # Case 4 - single slip - 0.1, # Case 5 - double slip + 0.5, # Case 5 - double slip 0.2, # Case 6 - symmetrical - 0.01] # Case 7 - dead end + 0.0] # Case 7 - dead end # Example generate a random rail env = RailEnv(width=20, @@ -38,12 +28,12 @@ env.reset() env_renderer = RenderTool(env) env_renderer.renderEnv(show=True) +""" # Example generate a rail given a manual specification, # a map of tuples (cell_type, rotation) specs = [[(0, 0), (0, 0), (0, 0), (0, 0), (7, 0), (0, 0)], [(7, 270), (1, 90), (1, 90), (1, 90), (2, 90), (7, 90)]] -""" env = RailEnv(width=6, height=2, rail_generator=rail_from_manual_specifications_generator(specs), @@ -56,20 +46,20 @@ env.agents_target[0] = [1, 1] env.agents_direction[0] = 1 # TODO: watch out: if these variables are overridden, the obs_builder object has to be reset, too! env.obs_builder.reset() -#""" - +""" env = RailEnv(width=7, height=7, rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability), number_of_agents=2) -# TODO: delete next line -#for i in range(4): -# print(env.obs_builder.distance_map[0, :, :, i]) +# Print the distance map of each cell to the target of the first agent +# for i in range(4): +# print(env.obs_builder.distance_map[0, :, :, i]) +# Print the observation vector for agent 0 obs, all_rewards, done, _ = env.step({0:0}) for i in range(env.number_of_agents): - env.obs_builder.util_print_obs_subtree(tree=obs[i], num_elements_per_node=5) + env.obs_builder.util_print_obs_subtree(tree=obs[i], num_features_per_node=5) env_renderer = RenderTool(env) env_renderer.renderEnv(show=True) diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py index a0def07ebbf7e598ba04de6094ba08b93ff06590..d7bee9301ea8b4d17ef431d4616c13c19490669f 100644 --- a/flatland/core/env_observation_builder.py +++ b/flatland/core/env_observation_builder.py @@ -383,15 +383,15 @@ class TreeObsForRailEnv(ObservationBuilder): return observation - def util_print_obs_subtree(self, tree, num_elements_per_node=5, prompt='', current_depth=0): + def util_print_obs_subtree(self, tree, num_features_per_node=5, prompt='', current_depth=0): """ Utility function to pretty-print tree observations returned by this object. """ - if len(tree) < num_elements_per_node: + if len(tree) < num_features_per_node: return depth = 0 - tmp = len(tree)/num_elements_per_node-1 + tmp = len(tree)/num_features_per_node-1 pow4 = 4 while tmp > 0: tmp -= pow4 @@ -400,12 +400,12 @@ class TreeObsForRailEnv(ObservationBuilder): prompt_ = ['L:', 'F:', 'R:', 'B:'] - print(" "*current_depth + prompt, tree[0:num_elements_per_node]) - child_size = (len(tree)-num_elements_per_node)//4 + print(" "*current_depth + prompt, tree[0:num_features_per_node]) + child_size = (len(tree)-num_features_per_node)//4 for children in range(4): - child_tree = tree[(num_elements_per_node+children*child_size): - (num_elements_per_node+(children+1)*child_size)] + child_tree = tree[(num_features_per_node+children*child_size): + (num_features_per_node+(children+1)*child_size)] self.util_print_obs_subtree(child_tree, - num_elements_per_node, + num_features_per_node, prompt=prompt_[children], current_depth=current_depth+1) diff --git a/flatland/core/transition_map.py b/flatland/core/transition_map.py index d3fcf5c8467586053ca4ab624f9b8536bdfba2de..73bb6eeffd9bf61d418adf67e9d49bec5a12234b 100644 --- a/flatland/core/transition_map.py +++ b/flatland/core/transition_map.py @@ -118,7 +118,7 @@ class GridTransitionMap(TransitionMap): Width of the grid. height : int Height of the grid. - transitions_class : Transitions object + transitions : Transitions object The Transitions object to use to encode/decode transitions over the grid. @@ -243,6 +243,47 @@ class GridTransitionMap(TransitionMap): return self.transitions.set_transition(self.grid[cell_id[0]][cell_id[1]], cell_id[2], transition_index, new_transition) + def save_transition_map(self, filename): + """ + Save the transitions grid as `filename', in npy format. + + Parameters + ---------- + filename : string + Name of the file to which to save the transitions grid. + + """ + np.save(filename, self.grid) + + def load_transition_map(self, filename, override_gridsize=True): + """ + Load the transitions grid from `filename' (npy format). + The load function only updates the transitions grid, and possibly width and height, but the object has to be + initialized with the correct `transitions' object anyway. + + Parameters + ---------- + filename : string + Name of the file from which to load the transitions grid. + override_gridsize : bool + If override_gridsize=True, the width and height of the GridTransitionMap object are replaced with the size + of the map loaded from `filename'. If override_gridsize=False, the transitions grid is either cropped (if + the grid size is larger than (height,width) ) or padded with zeros (if the grid size is smaller than (height,width) ) + + """ + new_grid = np.load(filename) + + new_height = new_grid.shape[0] + new_width = new_grid.shape[1] + + if override_gridsize: + self.width = new_width + self.height = new_height + self.grid = new_grid + + else: + self.grid = self.grid * 0 + self.grid[0:min(self.height, new_height), 0:min(self.width, new_width)] = new_grid[0:min(self.height, new_height), 0:min(self.width, new_width)] # TODO: GIACOMO: is it better to provide those methods with lists of cell_ids # (most general implementation) or to make Grid-class specific methods for