Skip to content
Snippets Groups Projects
Commit 9a6efb05 authored by spiglerg's avatar spiglerg
Browse files

added save/load gridmap for GridTransitionMap

parent 13ebd009
No related branches found
No related tags found
No related merge requests found
......@@ -9,24 +9,14 @@ from flatland.utils.rendertools import *
random.seed(0)
np.random.seed(0)
"""
transition_probability = [1.0, # empty cell - Case 0
3.0, # Case 1 - straight
1.0, # Case 2 - simple switch
3.0, # Case 3 - diamond drossing
2.0, # Case 4 - single slip
1.0, # Case 5 - double slip
1.0, # Case 6 - symmetrical
1.0] # Case 7 - dead end
"""
transition_probability = [1.0, # empty cell - Case 0
1.0, # Case 1 - straight
0.5, # Case 2 - simple switch
0.2, # Case 3 - diamond drossing
1.0, # Case 2 - simple switch
0.3, # Case 3 - diamond drossing
0.5, # Case 4 - single slip
0.1, # Case 5 - double slip
0.5, # Case 5 - double slip
0.2, # Case 6 - symmetrical
0.01] # Case 7 - dead end
0.0] # Case 7 - dead end
# Example generate a random rail
env = RailEnv(width=20,
......@@ -38,12 +28,12 @@ env.reset()
env_renderer = RenderTool(env)
env_renderer.renderEnv(show=True)
"""
# Example generate a rail given a manual specification,
# a map of tuples (cell_type, rotation)
specs = [[(0, 0), (0, 0), (0, 0), (0, 0), (7, 0), (0, 0)],
[(7, 270), (1, 90), (1, 90), (1, 90), (2, 90), (7, 90)]]
"""
env = RailEnv(width=6,
height=2,
rail_generator=rail_from_manual_specifications_generator(specs),
......@@ -56,20 +46,20 @@ env.agents_target[0] = [1, 1]
env.agents_direction[0] = 1
# TODO: watch out: if these variables are overridden, the obs_builder object has to be reset, too!
env.obs_builder.reset()
#"""
"""
env = RailEnv(width=7,
height=7,
rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
number_of_agents=2)
# TODO: delete next line
#for i in range(4):
# print(env.obs_builder.distance_map[0, :, :, i])
# Print the distance map of each cell to the target of the first agent
# for i in range(4):
# print(env.obs_builder.distance_map[0, :, :, i])
# Print the observation vector for agent 0
obs, all_rewards, done, _ = env.step({0:0})
for i in range(env.number_of_agents):
env.obs_builder.util_print_obs_subtree(tree=obs[i], num_elements_per_node=5)
env.obs_builder.util_print_obs_subtree(tree=obs[i], num_features_per_node=5)
env_renderer = RenderTool(env)
env_renderer.renderEnv(show=True)
......
......@@ -383,15 +383,15 @@ class TreeObsForRailEnv(ObservationBuilder):
return observation
def util_print_obs_subtree(self, tree, num_elements_per_node=5, prompt='', current_depth=0):
def util_print_obs_subtree(self, tree, num_features_per_node=5, prompt='', current_depth=0):
"""
Utility function to pretty-print tree observations returned by this object.
"""
if len(tree) < num_elements_per_node:
if len(tree) < num_features_per_node:
return
depth = 0
tmp = len(tree)/num_elements_per_node-1
tmp = len(tree)/num_features_per_node-1
pow4 = 4
while tmp > 0:
tmp -= pow4
......@@ -400,12 +400,12 @@ class TreeObsForRailEnv(ObservationBuilder):
prompt_ = ['L:', 'F:', 'R:', 'B:']
print(" "*current_depth + prompt, tree[0:num_elements_per_node])
child_size = (len(tree)-num_elements_per_node)//4
print(" "*current_depth + prompt, tree[0:num_features_per_node])
child_size = (len(tree)-num_features_per_node)//4
for children in range(4):
child_tree = tree[(num_elements_per_node+children*child_size):
(num_elements_per_node+(children+1)*child_size)]
child_tree = tree[(num_features_per_node+children*child_size):
(num_features_per_node+(children+1)*child_size)]
self.util_print_obs_subtree(child_tree,
num_elements_per_node,
num_features_per_node,
prompt=prompt_[children],
current_depth=current_depth+1)
......@@ -118,7 +118,7 @@ class GridTransitionMap(TransitionMap):
Width of the grid.
height : int
Height of the grid.
transitions_class : Transitions object
transitions : Transitions object
The Transitions object to use to encode/decode transitions over the
grid.
......@@ -243,6 +243,47 @@ class GridTransitionMap(TransitionMap):
return
self.transitions.set_transition(self.grid[cell_id[0]][cell_id[1]], cell_id[2], transition_index, new_transition)
def save_transition_map(self, filename):
"""
Save the transitions grid as `filename', in npy format.
Parameters
----------
filename : string
Name of the file to which to save the transitions grid.
"""
np.save(filename, self.grid)
def load_transition_map(self, filename, override_gridsize=True):
"""
Load the transitions grid from `filename' (npy format).
The load function only updates the transitions grid, and possibly width and height, but the object has to be
initialized with the correct `transitions' object anyway.
Parameters
----------
filename : string
Name of the file from which to load the transitions grid.
override_gridsize : bool
If override_gridsize=True, the width and height of the GridTransitionMap object are replaced with the size
of the map loaded from `filename'. If override_gridsize=False, the transitions grid is either cropped (if
the grid size is larger than (height,width) ) or padded with zeros (if the grid size is smaller than (height,width) )
"""
new_grid = np.load(filename)
new_height = new_grid.shape[0]
new_width = new_grid.shape[1]
if override_gridsize:
self.width = new_width
self.height = new_height
self.grid = new_grid
else:
self.grid = self.grid * 0
self.grid[0:min(self.height, new_height), 0:min(self.width, new_width)] = new_grid[0:min(self.height, new_height), 0:min(self.width, new_width)]
# TODO: GIACOMO: is it better to provide those methods with lists of cell_ids
# (most general implementation) or to make Grid-class specific methods for
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment