From de946638b4c24ee94c6f0c612f69111ca1a585f7 Mon Sep 17 00:00:00 2001 From: nimishsantosh107 <nimishsantosh107@icloud.com> Date: Tue, 12 Oct 2021 15:14:37 +0530 Subject: [PATCH] line generators sample orientations fix, rail_from_manual_spec fix --- flatland/envs/line_generators.py | 17 ++++++++++++++--- flatland/envs/rail_generators.py | 17 ++++++++--------- 2 files changed, 22 insertions(+), 12 deletions(-) diff --git a/flatland/envs/line_generators.py b/flatland/envs/line_generators.py index 154a65bb..7f50abb2 100644 --- a/flatland/envs/line_generators.py +++ b/flatland/envs/line_generators.py @@ -67,6 +67,15 @@ class SparseLineGen(BaseLineGen): :param seed: Initiate random seed generator """ + def decide_orientation(self, rail, start, target, possible_orientations, np_random: RandomState) -> int: + feasible_orientations = [] + + for orientation in possible_orientations: + if rail.check_path_exists(start[0], orientation, target[0]): + feasible_orientations.append(orientation) + + return np_random.choice(feasible_orientations) + def generate(self, rail: GridTransitionMap, num_agents: int, hints: dict, num_resets: int, np_random: RandomState) -> Line: """ @@ -116,7 +125,8 @@ class SparseLineGen(BaseLineGen): agent_start = train_stations[city1][agent_start_idx] agent_target = train_stations[city2][agent_target_idx] - agent_orientation = np_random.choice(city1_possible_orientations) + agent_orientation = self.decide_orientation( + rail, agent_start, agent_target, city1_possible_orientations, np_random) else: @@ -125,8 +135,9 @@ class SparseLineGen(BaseLineGen): agent_start = train_stations[city2][agent_start_idx] agent_target = train_stations[city1][agent_target_idx] - - agent_orientation = np_random.choice(city2_possible_orientations) + + agent_orientation = self.decide_orientation( + rail, agent_start, agent_target, city2_possible_orientations, np_random) # agent1 details diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index 457574ca..14035e56 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -67,7 +67,7 @@ class EmptyRailGen(RailGen): return grid_map, None -def rail_from_manual_specifications_generator(rail_spec): +def rail_from_manual_specifications_generator(rail_spec, optionals): """ Utility to convert a rail given by manual specification as a map of tuples (cell_type, rotation), to a transition map with the correct 16-bit @@ -107,7 +107,7 @@ def rail_from_manual_specifications_generator(rail_spec): effective_transition_cell = rail_env_transitions.rotate_transition(basic_type_of_cell_, rotation_cell_) rail.set_transitions((r, c), effective_transition_cell) - return [rail, None] + return [rail, optionals] return generator @@ -285,15 +285,14 @@ class SparseRailGen(RailGen): # Fix all transition elements self._fix_transitions(city_cells, inter_city_lines, grid_map, vector_field) return grid_map, {'agents_hints': { - 'num_agents': num_agents, 'city_positions': city_positions, 'train_stations': train_stations, 'city_orientations': city_orientations }} def _generate_random_city_positions(self, num_cities: int, city_radius: int, width: int, - height: int, np_random: RandomState = None) -> ( - IntVector2DArray, IntVector2DArray): + height: int, np_random: RandomState = None) -> Tuple[ + IntVector2DArray, IntVector2DArray]: """ Distribute the cities randomly in the environment while respecting city sizes and guaranteeing that they don't overlap. @@ -352,7 +351,7 @@ class SparseRailGen(RailGen): return city_positions def _generate_evenly_distr_city_positions(self, num_cities: int, city_radius: int, width: int, height: int - ) -> (IntVector2DArray, IntVector2DArray): + ) -> Tuple[IntVector2DArray, IntVector2DArray]: """ Distribute the cities in an evenly spaced grid @@ -399,11 +398,11 @@ class SparseRailGen(RailGen): def _generate_city_connection_points(self, city_positions: IntVector2DArray, city_radius: int, vector_field: IntVector2DArray, rails_between_cities: int, - rail_pairs_in_city: int = 1, np_random: RandomState = None) -> ( + rail_pairs_in_city: int = 1, np_random: RandomState = None) -> Tuple[ List[List[List[IntVector2D]]], List[List[List[IntVector2D]]], List[np.ndarray], - List[Grid4TransitionsEnum]): + List[Grid4TransitionsEnum]]: """ Generate the city connection points. Internal connection points are used to generate the parallel paths within the city. @@ -609,7 +608,7 @@ class SparseRailGen(RailGen): def _build_inner_cities(self, city_positions: IntVector2DArray, inner_connection_points: List[List[List[IntVector2D]]], outer_connection_points: List[List[List[IntVector2D]]], rail_trans: RailEnvTransitions, - grid_map: GridTransitionMap) -> (List[IntVector2DArray], List[List[List[IntVector2D]]]): + grid_map: GridTransitionMap) -> Tuple[List[IntVector2DArray], List[List[List[IntVector2D]]]]: """ Set the parallel tracks within the city. The center track of the city is of the length of the city, the lenght of the tracks decrease by 2 for every parallel track away from the center -- GitLab