Commit de946638 authored by nimishsantosh107's avatar nimishsantosh107
Browse files

line generators sample orientations fix, rail_from_manual_spec fix

parent 7e4db340
Pipeline #8729 failed with stages
in 6 minutes and 20 seconds
......@@ -67,6 +67,15 @@ class SparseLineGen(BaseLineGen):
:param seed: Initiate random seed generator
"""
def decide_orientation(self, rail, start, target, possible_orientations, np_random: RandomState) -> int:
feasible_orientations = []
for orientation in possible_orientations:
if rail.check_path_exists(start[0], orientation, target[0]):
feasible_orientations.append(orientation)
return np_random.choice(feasible_orientations)
def generate(self, rail: GridTransitionMap, num_agents: int, hints: dict, num_resets: int,
np_random: RandomState) -> Line:
"""
......@@ -116,7 +125,8 @@ class SparseLineGen(BaseLineGen):
agent_start = train_stations[city1][agent_start_idx]
agent_target = train_stations[city2][agent_target_idx]
agent_orientation = np_random.choice(city1_possible_orientations)
agent_orientation = self.decide_orientation(
rail, agent_start, agent_target, city1_possible_orientations, np_random)
else:
......@@ -125,8 +135,9 @@ class SparseLineGen(BaseLineGen):
agent_start = train_stations[city2][agent_start_idx]
agent_target = train_stations[city1][agent_target_idx]
agent_orientation = np_random.choice(city2_possible_orientations)
agent_orientation = self.decide_orientation(
rail, agent_start, agent_target, city2_possible_orientations, np_random)
# agent1 details
......
......@@ -67,7 +67,7 @@ class EmptyRailGen(RailGen):
return grid_map, None
def rail_from_manual_specifications_generator(rail_spec):
def rail_from_manual_specifications_generator(rail_spec, optionals):
"""
Utility to convert a rail given by manual specification as a map of tuples
(cell_type, rotation), to a transition map with the correct 16-bit
......@@ -107,7 +107,7 @@ def rail_from_manual_specifications_generator(rail_spec):
effective_transition_cell = rail_env_transitions.rotate_transition(basic_type_of_cell_, rotation_cell_)
rail.set_transitions((r, c), effective_transition_cell)
return [rail, None]
return [rail, optionals]
return generator
......@@ -285,15 +285,14 @@ class SparseRailGen(RailGen):
# Fix all transition elements
self._fix_transitions(city_cells, inter_city_lines, grid_map, vector_field)
return grid_map, {'agents_hints': {
'num_agents': num_agents,
'city_positions': city_positions,
'train_stations': train_stations,
'city_orientations': city_orientations
}}
def _generate_random_city_positions(self, num_cities: int, city_radius: int, width: int,
height: int, np_random: RandomState = None) -> (
IntVector2DArray, IntVector2DArray):
height: int, np_random: RandomState = None) -> Tuple[
IntVector2DArray, IntVector2DArray]:
"""
Distribute the cities randomly in the environment while respecting city sizes and guaranteeing that they
don't overlap.
......@@ -352,7 +351,7 @@ class SparseRailGen(RailGen):
return city_positions
def _generate_evenly_distr_city_positions(self, num_cities: int, city_radius: int, width: int, height: int
) -> (IntVector2DArray, IntVector2DArray):
) -> Tuple[IntVector2DArray, IntVector2DArray]:
"""
Distribute the cities in an evenly spaced grid
......@@ -399,11 +398,11 @@ class SparseRailGen(RailGen):
def _generate_city_connection_points(self, city_positions: IntVector2DArray, city_radius: int,
vector_field: IntVector2DArray, rails_between_cities: int,
rail_pairs_in_city: int = 1, np_random: RandomState = None) -> (
rail_pairs_in_city: int = 1, np_random: RandomState = None) -> Tuple[
List[List[List[IntVector2D]]],
List[List[List[IntVector2D]]],
List[np.ndarray],
List[Grid4TransitionsEnum]):
List[Grid4TransitionsEnum]]:
"""
Generate the city connection points. Internal connection points are used to generate the parallel paths
within the city.
......@@ -609,7 +608,7 @@ class SparseRailGen(RailGen):
def _build_inner_cities(self, city_positions: IntVector2DArray, inner_connection_points: List[List[List[IntVector2D]]],
outer_connection_points: List[List[List[IntVector2D]]], rail_trans: RailEnvTransitions,
grid_map: GridTransitionMap) -> (List[IntVector2DArray], List[List[List[IntVector2D]]]):
grid_map: GridTransitionMap) -> Tuple[List[IntVector2DArray], List[List[List[IntVector2D]]]]:
"""
Set the parallel tracks within the city. The center track of the city is of the length of the city, the lenght
of the tracks decrease by 2 for every parallel track away from the center
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment