diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 83584a302f0e4064cef187171db291b637126242..97e5a050aa299ae5a1e37763c2ac75cc85f946c2 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -15,6 +15,20 @@ from flatland.envs.agent_utils import RailAgentStatus, EnvAgent from flatland.utils.ordered_set import OrderedSet +Node = collections.namedtuple('Node', 'dist_own_target_encountered ' + 'dist_other_target_encountered ' + 'dist_other_agent_encountered ' + 'dist_potential_conflict ' + 'dist_unusable_switch ' + 'dist_to_next_branch ' + 'dist_min_to_target ' + 'num_agents_same_direction ' + 'num_agents_opposite_direction ' + 'num_agents_malfunctioning ' + 'speed_min_fractional ' + 'num_agents_ready_to_depart ' + 'childs') + class TreeObsForRailEnv(ObservationBuilder): """ TreeObsForRailEnv object. @@ -25,19 +39,7 @@ class TreeObsForRailEnv(ObservationBuilder): For details about the features in the tree observation see the get() function. """ - Node = collections.namedtuple('Node', 'dist_own_target_encountered ' - 'dist_other_target_encountered ' - 'dist_other_agent_encountered ' - 'dist_potential_conflict ' - 'dist_unusable_switch ' - 'dist_to_next_branch ' - 'dist_min_to_target ' - 'num_agents_same_direction ' - 'num_agents_opposite_direction ' - 'num_agents_malfunctioning ' - 'speed_min_fractional ' - 'num_agents_ready_to_depart ' - 'childs') + tree_explored_actions_char = ['L', 'F', 'R', 'B'] @@ -205,7 +207,8 @@ class TreeObsForRailEnv(ObservationBuilder): # Here information about the agent itself is stored distance_map = self.env.distance_map.get() - root_node_observation = TreeObsForRailEnv.Node(dist_own_target_encountered=0, dist_other_target_encountered=0, + # was referring to TreeObsForRailEnv.Node + root_node_observation = Node(dist_own_target_encountered=0, dist_other_target_encountered=0, dist_other_agent_encountered=0, dist_potential_conflict=0, dist_unusable_switch=0, dist_to_next_branch=0, dist_min_to_target=distance_map[ @@ -431,7 +434,8 @@ class TreeObsForRailEnv(ObservationBuilder): dist_to_next_branch = tot_dist dist_min_to_target = self.env.distance_map.get()[handle, position[0], position[1], direction] - node = TreeObsForRailEnv.Node(dist_own_target_encountered=own_target_encountered, + # TreeObsForRailEnv.Node + node = Node(dist_own_target_encountered=own_target_encountered, dist_other_target_encountered=other_target_encountered, dist_other_agent_encountered=other_agent_encountered, dist_potential_conflict=potential_conflict, diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index 9232914c41c9a1e4c54c625a6f84ae5912d8642e..5fc74da30e8626281a1654345ee7d242e7ab98d9 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -17,17 +17,47 @@ from flatland.envs.grid4_generators_utils import connect_rail_in_grid_map, conne fix_inner_nodes, align_cell_to_city from flatland.envs import persistence + RailGeneratorProduct = Tuple[GridTransitionMap, Optional[Dict]] +""" A rail generator returns a RailGenerator Product, which is just + a GridTransitionMap followed by an (optional) dict/ +""" + RailGenerator = Callable[[int, int, int, int], RailGeneratorProduct] +class RailGen(object): + """ Base class for RailGen(erator) replacement + + WIP to replace bare generators with classes / objects without unnamed local variables + which prevent pickling. + """ + def __init__(self, *args, **kwargs): + """ constructor to record any state to be reused in each "generation" + """ + pass + + def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0, + np_random: RandomState = None) -> RailGeneratorProduct: + pass + + def __call__(self, *args, **kwargs) -> RailGeneratorProduct: + return self.generate(*args, **kwargs) + + + + + def empty_rail_generator() -> RailGenerator: + return EmptyRailGen() + +class EmptyRailGen(RailGen): """ Returns a generator which returns an empty rail mail with no agents. Primarily used by the editor """ - def generator(width: int, height: int, num_agents: int, num_resets: int = 0, + def generate(width: int, height: int, num_agents: int, num_resets: int = 0, np_random: RandomState = None) -> RailGenerator: rail_trans = RailEnvTransitions() grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans) @@ -36,7 +66,6 @@ def empty_rail_generator() -> RailGenerator: return grid_map, None - return generator def complex_rail_generator(nr_start_goal=1, @@ -255,8 +284,19 @@ def rail_from_file(filename, load_from_package=None) -> RailGenerator: return generator +class RailFromGridGen(RailGen): + def __init__(self, rail_map): + self.rail_map = rail_map + + def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0, + np_random: RandomState = None) -> RailGenerator: + return self.rail_map, None + def rail_from_grid_transition_map(rail_map) -> RailGenerator: + return RailFromGridGen(rail_map) + +def rail_from_grid_transition_map_old(rail_map) -> RailGenerator: """ Utility to convert a rail given by a GridTransitionMap map with the correct 16-bit transitions specifications. @@ -561,13 +601,6 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11, seed=1) -> R -class RailGen(object): - def __init__(self): - pass - - def generate(self): - pass - def sparse_rail_generator(*args, **kwargs): return SparseRailGen(*args, **kwargs) diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py index 40a89f799c9f8ae18eb12ce3f1a0f19f9bb76478..1cd10a4bd31162a85345cfe58831b960958c9a59 100644 --- a/flatland/envs/schedule_generators.py +++ b/flatland/envs/schedule_generators.py @@ -40,6 +40,20 @@ def speed_initialization_helper(nb_agents: int, speed_ratio_map: Mapping[float, return list(map(lambda index: speeds[index], np_random.choice(nb_classes, nb_agents, p=speed_ratios))) +class BaseSchedGen(object): + def __init__(self, speed_ratio_map: Mapping[float, float] = None, seed: int = 1): + self.speed_ratio_map = speed_ratio_map + self.seed = seed + + def generate(self, rail: GridTransitionMap, num_agents: int, hints: Any=None, num_resets: int = 0, + np_random: RandomState = None) -> Schedule: + pass + + def __call__(self, *args, **kwargs): + return self.generate(*args, **kwargs) + + + def complex_schedule_generator(speed_ratio_map: Mapping[float, float] = None, seed: int = 1) -> ScheduleGenerator: """ @@ -88,6 +102,10 @@ def complex_schedule_generator(speed_ratio_map: Mapping[float, float] = None, se def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None, seed: int = 1) -> ScheduleGenerator: + return SparseSchedGen(speed_ratio_map, seed) + + +class SparseSchedGen(BaseSchedGen): """ This is the schedule generator which is used for Round 2 of the Flatland challenge. It produces schedules @@ -97,7 +115,7 @@ def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None, see :param seed: Initiate random seed generator """ - def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None, num_resets: int = 0, + def generate(self, rail: GridTransitionMap, num_agents: int, hints: Any = None, num_resets: int = 0, np_random: RandomState = None) -> Schedule: """ @@ -109,7 +127,7 @@ def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None, see :return: Returns the generator to the rail constructor """ - _runtime_seed = seed + num_resets + _runtime_seed = self.seed + num_resets train_stations = hints['train_stations'] city_positions = hints['city_positions'] @@ -162,8 +180,8 @@ def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None, see agents_direction.append(agent_orientation) # Orient the agent correctly - if speed_ratio_map: - speeds = speed_initialization_helper(num_agents, speed_ratio_map, seed=_runtime_seed, np_random=np_random) + if self.speed_ratio_map: + speeds = speed_initialization_helper(num_agents, self.speed_ratio_map, seed=_runtime_seed, np_random=np_random) else: speeds = [1.0] * len(agents_position) @@ -178,11 +196,13 @@ def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None, see agent_targets=agents_target, agent_speeds=speeds, agent_malfunction_rates=None, max_episode_steps=max_episode_steps) - return generator +def random_schedule_generator(speed_ratio_map: Mapping[float, float] = None, seed: int = 1) -> ScheduleGenerator: + return RandomSchedGen(speed_ratio_map, seed) -def random_schedule_generator(speed_ratio_map: Optional[Mapping[float, float]] = None, - seed: int = 1) -> ScheduleGenerator: + +class RandomSchedGen(BaseSchedGen): + """ Given a `rail` GridTransitionMap, return a random placement of agents (initial position, direction and target). @@ -197,9 +217,9 @@ def random_schedule_generator(speed_ratio_map: Optional[Mapping[float, float]] = initial positions, directions, targets speeds """ - def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None, num_resets: int = 0, + def generate(self, rail: GridTransitionMap, num_agents: int, hints: Any = None, num_resets: int = 0, np_random: RandomState = None) -> Schedule: - _runtime_seed = seed + num_resets + _runtime_seed = self.seed + num_resets valid_positions = [] for r in range(rail.height): @@ -270,7 +290,8 @@ def random_schedule_generator(speed_ratio_map: Optional[Mapping[float, float]] = agents_direction[i] = valid_starting_directions[ np_random.choice(len(valid_starting_directions), 1)[0]] - agents_speed = speed_initialization_helper(num_agents, speed_ratio_map, seed=_runtime_seed, np_random=np_random) + agents_speed = speed_initialization_helper(num_agents, self.speed_ratio_map, seed=_runtime_seed, + np_random=np_random) # Compute max number of steps with given schedule extra_time_factor = 1.5 # Factor to allow for more then minimal time @@ -280,7 +301,6 @@ def random_schedule_generator(speed_ratio_map: Optional[Mapping[float, float]] = agent_targets=agents_target, agent_speeds=agents_speed, agent_malfunction_rates=None, max_episode_steps=max_episode_steps) - return generator def schedule_from_file(filename, load_from_package=None) -> ScheduleGenerator: