Skip to content
Snippets Groups Projects
Commit 18007817 authored by Erik Nygren's avatar Erik Nygren :bullettrain_front:
Browse files

Merge branch 'master' into '228_test_sparse_rail_generator'

# Conflicts:
#   flatland/envs/schedule_generators.py
parents 51dad5ea 2b0eaa47
No related branches found
No related tags found
No related merge requests found
......@@ -352,9 +352,13 @@ class GridTransitionMap(TransitionMap):
return is_simple_turn(tmp)
def check_path_exists(self, start: IntVector2DArray, direction: int, end: IntVector2DArray):
# print("_path_exists({},{},{}".format(start, direction, end))
# BFS - Check if a path exists between the 2 nodes
"""
Breath first search for a possible path from one node with a certain orientation to a target node.
:param start: Start cell rom where we want to check the path
:param direction: Start direction for the path we are testing
:param end: Cell that we try to reach from the start cell
:return: True if a path exists, False otherwise
"""
visited = OrderedSet()
stack = [(start, direction)]
while stack:
......@@ -583,7 +587,16 @@ class GridTransitionMap(TransitionMap):
def validate_new_transition(self, prev_pos: IntVector2D, current_pos: IntVector2D,
new_pos: IntVector2D, end_pos: IntVector2D):
"""
Utility function to test that a path drawn by a-start algorithm uses valid transition objects.
We us this to quide a-star as there are many transition elements that are not allowed in RailEnv
:param prev_pos: The previous position we were checking
:param current_pos: The current position we are checking
:param new_pos: Possible child position we move into
:param end_pos: End cell of path we are drawing
:return: True if the transition is valid, False if transition element is illegal
"""
# start by getting direction used to get to current node
# and direction from current node to possible child node
new_dir = get_direction(current_pos, new_pos)
......
......@@ -55,6 +55,13 @@ class DistanceMap:
self.env_width = rail.width
def _compute(self, agents: List[EnvAgent], rail: GridTransitionMap):
"""
This function computes the distance maps for each unique target. Thus if several targets are the same
we only compute the distance for them once and copy to all targets with same position.
:param agents: All the agents in the environment, independent of their current status
:param rail: The rail transition map
"""
self.agents_previous_computation = self.agents
self.distance_map = np.inf * np.ones(shape=(len(agents),
self.env_height,
......
......@@ -112,6 +112,11 @@ def complex_rail_generator(nr_start_goal=1,
sg_new = [start, goal]
def check_all_dist(sg_new):
"""
Function to check the distance betweens start and goal
:param sg_new: start and goal tuple
:return: True if distance is larger than 2, False otherwise
"""
for sg in start_goal:
for i in range(2):
for j in range(2):
......
......@@ -43,8 +43,27 @@ def speed_initialization_helper(nb_agents: int, speed_ratio_map: Mapping[float,
def complex_schedule_generator(speed_ratio_map: Mapping[float, float] = None, seed: int = 1) -> ScheduleGenerator:
def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None, num_resets: int = 0):
"""
Generator used to generate the levels of Round 1 in the Flatland Challenge. It can only be used together
with complex_rail_generator. It places agents at end and start points provided by the rail generator.
It assigns speeds to the different agents according to the speed_ratio_map
:param speed_ratio_map: Speed ratios of all agents. They are probabilities of all different speeds and have to
add up to 1.
:param seed: Initiate random seed generator
:return:
"""
def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None, num_resets: int = 0):
"""
The generator that assigns tasks to all the agents
:param rail: Rail infrastructure given by the rail_generator
:param num_agents: Number of agents to include in the schedule
:param hints: Hints provided by the rail_generator These include positions of start/target positions
:param num_resets: How often the generator has been reset.
:return: Returns the generator to the rail constructor
"""
_runtime_seed = seed + num_resets
np.random.seed(_runtime_seed)
......@@ -65,7 +84,25 @@ def complex_schedule_generator(speed_ratio_map: Mapping[float, float] = None, se
def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None, seed: int = 1) -> ScheduleGenerator:
"""
This is the schedule generator which is used for Round 2 of the Flatland challenge. It produces schedules
to railway networks provided by sparse_rail_generator.
:param speed_ratio_map: Speed ratios of all agents. They are probabilities of all different speeds and have to
add up to 1.
:param seed: Initiate random seed generator
"""
def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None, num_resets: int = 0):
"""
The generator that assigns tasks to all the agents
:param rail: Rail infrastructure given by the rail_generator
:param num_agents: Number of agents to include in the schedule
:param hints: Hints provided by the rail_generator These include positions of start/target positions
:param num_resets: How often the generator has been reset.
:return: Returns the generator to the rail constructor
"""
_runtime_seed = seed + num_resets
np.random.seed(_runtime_seed)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment