Skip to content
Snippets Groups Projects
Commit 285e714f authored by Erik Nygren's avatar Erik Nygren :bullettrain_front:
Browse files

added funcitonality to a-star to avoid other tracks when looking for paths

parent b4aca5f6
No related branches found
No related tags found
No related merge requests found
...@@ -2,7 +2,7 @@ import time ...@@ -2,7 +2,7 @@ import time
import numpy as np import numpy as np
from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator from flatland.envs.rail_generators import sparse_rail_generator
...@@ -32,17 +32,17 @@ speed_ration_map = {1.: 0.25, # Fast passenger train ...@@ -32,17 +32,17 @@ speed_ration_map = {1.: 0.25, # Fast passenger train
env = RailEnv(width=100, env = RailEnv(width=100,
height=100, height=100,
rail_generator=sparse_rail_generator(max_num_cities=20, rail_generator=sparse_rail_generator(max_num_cities=30,
# Number of cities in map (where train stations are) # Number of cities in map (where train stations are)
seed=1, # Random seed seed=14, # Random seed
grid_mode=False, grid_mode=False,
max_rails_between_cities=5, max_rails_between_cities=2,
max_rails_in_city=8, max_rails_in_city=6,
), ),
schedule_generator=sparse_schedule_generator(speed_ration_map), schedule_generator=sparse_schedule_generator(speed_ration_map),
number_of_agents=50, number_of_agents=100,
stochastic_data=stochastic_data, # Malfunction data generator stochastic_data=stochastic_data, # Malfunction data generator
obs_builder_object=TreeObservation, obs_builder_object=GlobalObsForRailEnv(),
remove_agents_at_target=True remove_agents_at_target=True
) )
...@@ -51,8 +51,8 @@ env = RailEnv(width=100, ...@@ -51,8 +51,8 @@ env = RailEnv(width=100,
env_renderer = RenderTool(env, gl="PILSVG", env_renderer = RenderTool(env, gl="PILSVG",
agent_render_variant=AgentRenderVariant.AGENT_SHOWS_OPTIONS_AND_BOX, agent_render_variant=AgentRenderVariant.AGENT_SHOWS_OPTIONS_AND_BOX,
show_debug=True, show_debug=True,
screen_height=800, screen_height=1000,
screen_width=800) screen_width=1000)
# Import your own Agent or use RLlib to train agents on Flatland # Import your own Agent or use RLlib to train agents on Flatland
......
import numpy as np
from flatland.core.grid.grid_utils import IntVector2D, IntVector2DDistance from flatland.core.grid.grid_utils import IntVector2D, IntVector2DDistance
from flatland.core.grid.grid_utils import IntVector2DArray from flatland.core.grid.grid_utils import IntVector2DArray
from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
...@@ -36,10 +38,11 @@ class AStarNode: ...@@ -36,10 +38,11 @@ class AStarNode:
def a_star(grid_map: GridTransitionMap, start: IntVector2D, end: IntVector2D, def a_star(grid_map: GridTransitionMap, start: IntVector2D, end: IntVector2D,
a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance, a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance, avoid_rails=False,
respect_transition_validity=True, forbidden_cells: IntVector2DArray = None) -> IntVector2DArray: respect_transition_validity=True, forbidden_cells: IntVector2DArray = None) -> IntVector2DArray:
""" """
:param avoid_rails:
:param grid_map: Grid Map where the path is found in :param grid_map: Grid Map where the path is found in
:param start: Start positions as (row,column) :param start: Start positions as (row,column)
:param end: End position as (row,column) :param end: End position as (row,column)
...@@ -129,7 +132,10 @@ def a_star(grid_map: GridTransitionMap, start: IntVector2D, end: IntVector2D, ...@@ -129,7 +132,10 @@ def a_star(grid_map: GridTransitionMap, start: IntVector2D, end: IntVector2D,
# create the f, g, and h values # create the f, g, and h values
child.g = current_node.g + 1.0 child.g = current_node.g + 1.0
# this heuristic avoids diagonal paths # this heuristic avoids diagonal paths
child.h = a_star_distance_function(child.pos, end_node.pos) if avoid_rails:
child.h = a_star_distance_function(child.pos, end_node.pos) + np.clip(grid_map.grid[child.pos], 0, 1)
else:
child.h = a_star_distance_function(child.pos, end_node.pos)
child.f = child.g + child.h child.f = child.g + child.h
# already in the open list? # already in the open list?
......
...@@ -19,11 +19,12 @@ def connect_rail_in_grid_map(grid_map: GridTransitionMap, start: IntVector2D, en ...@@ -19,11 +19,12 @@ def connect_rail_in_grid_map(grid_map: GridTransitionMap, start: IntVector2D, en
rail_trans: RailEnvTransitions, rail_trans: RailEnvTransitions,
a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance, a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance,
flip_start_node_trans: bool = False, flip_end_node_trans: bool = False, flip_start_node_trans: bool = False, flip_end_node_trans: bool = False,
respect_transition_validity: bool = True, respect_transition_validity: bool = True, forbidden_cells: IntVector2DArray = None,
forbidden_cells: IntVector2DArray = None) -> IntVector2DArray: avoid_rail=False) -> IntVector2DArray:
""" """
Creates a new path [start,end] in `grid_map.grid`, based on rail_trans, and Creates a new path [start,end] in `grid_map.grid`, based on rail_trans, and
returns the path created as a list of positions. returns the path created as a list of positions.
:param avoid_rail:
:param rail_trans: basic rail transition object :param rail_trans: basic rail transition object
:param grid_map: grid map :param grid_map: grid map
:param start: start position of rail :param start: start position of rail
...@@ -36,7 +37,8 @@ def connect_rail_in_grid_map(grid_map: GridTransitionMap, start: IntVector2D, en ...@@ -36,7 +37,8 @@ def connect_rail_in_grid_map(grid_map: GridTransitionMap, start: IntVector2D, en
:return: List of cells in the path :return: List of cells in the path
""" """
path: IntVector2DArray = a_star(grid_map, start, end, a_star_distance_function, respect_transition_validity, path: IntVector2DArray = a_star(grid_map, start, end, a_star_distance_function, avoid_rail,
respect_transition_validity,
forbidden_cells) forbidden_cells)
if len(path) < 2: if len(path) < 2:
return [] return []
......
...@@ -561,7 +561,8 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_ ...@@ -561,7 +561,8 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
city_radius = int(np.ceil((max_rails_in_city + 2) / 2.0)) + 2 city_radius = int(np.ceil((max_rails_in_city + 2) / 2.0)) + 2
vector_field = np.zeros(shape=(height, width)) - 1. vector_field = np.zeros(shape=(height, width)) - 1.
min_nr_rails_in_city = 3 min_nr_rails_in_city = 2
max_nr_rail_in_city = 6
rails_in_city = min_nr_rails_in_city if max_rails_in_city < min_nr_rails_in_city else max_rails_in_city rails_in_city = min_nr_rails_in_city if max_rails_in_city < min_nr_rails_in_city else max_rails_in_city
rails_between_cities = rails_in_city if max_rails_between_cities > rails_in_city else max_rails_between_cities rails_between_cities = rails_in_city if max_rails_between_cities > rails_in_city else max_rails_between_cities
...@@ -604,7 +605,7 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_ ...@@ -604,7 +605,7 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
# Populate cities # Populate cities
train_stations = _set_trainstation_positions(city_positions, city_radius, free_rails) train_stations = _set_trainstation_positions(city_positions, city_radius, free_rails)
# Fix all transition elements # Fix all transition elements
_fix_transitions(city_cells, inter_city_lines, grid_map, vector_field) _fix_transitions(city_cells, inter_city_lines, grid_map, vector_field, rail_trans)
# Generate start target pairs # Generate start target pairs
agent_start_targets_cities = _generate_start_target_pairs(num_agents, num_cities, train_stations, agent_start_targets_cities = _generate_start_target_pairs(num_agents, num_cities, train_stations,
city_orientations) city_orientations)
...@@ -702,7 +703,9 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_ ...@@ -702,7 +703,9 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
start_idx = int((nr_of_connection_points - number_of_out_rails) / 2) start_idx = int((nr_of_connection_points - number_of_out_rails) / 2)
for direction in range(4): for direction in range(4):
connection_slots = np.arange(nr_of_connection_points) - start_idx connection_slots = np.arange(nr_of_connection_points) - start_idx
inner_point_offset = np.abs(connection_slots) + np.clip(connection_slots, 0, 1) offset_distances = np.arange(nr_of_connection_points) - int(nr_of_connection_points / 2)
inner_point_offset = np.abs(offset_distances) + np.clip(offset_distances, 0, 1) + 1
for connection_idx in range(connections_per_direction[direction]): for connection_idx in range(connections_per_direction[direction]):
if direction == 0: if direction == 0:
tmp_coordinates = ( tmp_coordinates = (
...@@ -774,6 +777,7 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_ ...@@ -774,6 +777,7 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
new_line = connect_rail_in_grid_map(grid_map, city_out_connection_point, neighbour_connection_point, new_line = connect_rail_in_grid_map(grid_map, city_out_connection_point, neighbour_connection_point,
rail_trans, flip_start_node_trans=False, rail_trans, flip_start_node_trans=False,
flip_end_node_trans=False, respect_transition_validity=False, flip_end_node_trans=False, respect_transition_validity=False,
avoid_rail=True,
forbidden_cells=city_cells) forbidden_cells=city_cells)
all_paths.extend(new_line) all_paths.extend(new_line)
...@@ -887,9 +891,10 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_ ...@@ -887,9 +891,10 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
return agent_start_targets_cities return agent_start_targets_cities
def _fix_transitions(city_cells: IntVector2DArray, inter_city_lines: List[IntVector2DArray], def _fix_transitions(city_cells: IntVector2DArray, inter_city_lines: List[IntVector2DArray],
grid_map: GridTransitionMap, vector_field): grid_map: GridTransitionMap, vector_field, rail_trans: RailEnvTransitions, ):
""" """
Function to fix all transition elements in environment Function to fix all transition elements in environment
:param rail_trans:
:param vector_field: :param vector_field:
""" """
# Fix all cities with illegal transition maps # Fix all cities with illegal transition maps
......
This diff is collapsed.
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment