Commit 71576225 authored by spiglerg's avatar spiglerg
Browse files

added rail_from_list_of_saved_GridTransitionMap_generator rail generator function

parent 9a6efb05
Pipeline #317 failed with stage
in 2 minutes and 3 seconds
......@@ -23,6 +23,12 @@ env = RailEnv(width=20,
height=20,
rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
number_of_agents=10)
# env = RailEnv(width=20,
# height=20,
# rail_generator=rail_from_list_of_saved_GridTransitionMap_generator(['examples/sample_10_10_rail.npy']),
# number_of_agents=10)
env.reset()
env_renderer = RenderTool(env)
......
......@@ -282,7 +282,11 @@ class GridTransitionMap(TransitionMap):
self.grid = new_grid
else:
self.grid = self.grid * 0
if new_grid.dtype == np.uint16:
self.grid = np.zeros((self.height, self.width), dtype=np.uint16)
elif new_grid.dtype == np.uint64:
self.grid = np.zeros((self.height, self.width), dtype=np.uint64)
self.grid[0:min(self.height, new_height), 0:min(self.width, new_width)] = new_grid[0:min(self.height, new_height), 0:min(self.width, new_width)]
# TODO: GIACOMO: is it better to provide those methods with lists of cell_ids
......
......@@ -10,7 +10,7 @@ import numpy as np
from flatland.core.env import Environment
from flatland.core.env_observation_builder import TreeObsForRailEnv
from flatland.core.transitions import RailEnvTransitions
from flatland.core.transitions import Grid4Transitions, Grid8Transitions, RailEnvTransitions
from flatland.core.transition_map import GridTransitionMap
......@@ -75,6 +75,33 @@ def rail_from_GridTransitionMap_generator(rail_map):
return generator
def rail_from_list_of_saved_GridTransitionMap_generator(list_of_filenames):
"""
Utility to sequentially and cyclically return GridTransitionMap-s from a list of files, on each environment reset.
Parameters
-------
list_of_filenames : list
List of filenames with the saved grids to load.
Returns
-------
function
Generator function that always returns the given `rail_map' object.
"""
def generator(width, height, num_resets=0):
t_utils = RailEnvTransitions()
rail_map = GridTransitionMap(width=width, height=height, transitions=t_utils)
rail_map.load_transition_map(list_of_filenames[num_resets % len(list_of_filenames)], override_gridsize=False)
if rail_map.grid.dtype == np.uint64:
rail_map.transitions = Grid8Transitions()
return rail_map
return generator
"""
def generate_rail_from_list_of_manual_specifications(list_of_specifications)
def generator(width, height, num_resets=0):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment