diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index 5c92fd1cde5024b330bc10050203d74cd0d74af5..6396f4123adf895c381c2a21f6d8dc6e4e823b92 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -1,3 +1,5 @@ +import time + import numpy as np from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv @@ -30,14 +32,14 @@ speed_ration_map = {1.: 0.25, # Fast passenger train env = RailEnv(width=50, height=50, - rail_generator=sparse_rail_generator(num_cities=20, # Number of cities in map (where train stations are) + rail_generator=sparse_rail_generator(num_cities=10, # Number of cities in map (where train stations are) seed=1, # Random seed grid_mode=False, max_inter_city_rails=2, max_tracks_in_city=4, ), schedule_generator=sparse_schedule_generator(), - number_of_agents=50, + number_of_agents=10, stochastic_data=stochastic_data, # Malfunction data generator obs_builder_object=GlobalObsForRailEnv()) @@ -111,6 +113,7 @@ for step in range(500): # reward and whether their are done next_obs, all_rewards, done, _ = env.step(action_dict) env_renderer.render_env(show=True, show_observations=False, show_predictions=False) + time.sleep(100) frame_step += 1 # Update replay buffer and train agent for a in range(env.get_num_agents()): diff --git a/flatland/envs/grid4_generators_utils.py b/flatland/envs/grid4_generators_utils.py index d0e75e27418cfb179ae907f4a3036d6d124618f4..089cea5e9f3694e92d9f84fec68b935592e7d537 100644 --- a/flatland/envs/grid4_generators_utils.py +++ b/flatland/envs/grid4_generators_utils.py @@ -5,6 +5,8 @@ Generator functions are functions that take width, height and num_resets as argu a GridTransitionMap object. """ +import numpy as np + from flatland.core.grid.grid4_astar import a_star from flatland.core.grid.grid4_utils import get_direction, mirror from flatland.core.grid.grid_utils import IntVector2D, IntVector2DDistance, IntVector2DArray @@ -77,6 +79,39 @@ def connect_basic_operation( return path +def connect_line(rail_trans, grid_map, start, end, openend=False): + # Set start cell + current_cell = start + path = [current_cell] + new_trans = grid_map.grid[current_cell] + direction = (np.clip(end[0] - start[0], -1, 1), np.clip(end[1] - start[1], -1, 1)) + if direction[0] == 0: + if direction[1] > 0: + direction_int = 1 + else: + direction_int = 3 + else: + if direction[0] > 0: + direction_int = 2 + else: + direction_int = 0 + new_trans = rail_trans.set_transition(new_trans, direction_int, direction_int, 1) + new_trans = rail_trans.set_transition(new_trans, mirror(direction_int), mirror(direction_int), 1) + grid_map.grid[current_cell] = new_trans + if openend: + grid_map.grid[current_cell] = 0 + # Set path + while current_cell != end: + current_cell = tuple(map(lambda x, y: x + y, current_cell, direction)) + new_trans = grid_map.grid[current_cell] + new_trans = rail_trans.set_transition(new_trans, direction_int, direction_int, 1) + new_trans = rail_trans.set_transition(new_trans, mirror(direction_int), mirror(direction_int), 1) + grid_map.grid[current_cell] = new_trans + if current_cell == end and openend: + grid_map.grid[current_cell] = 0 + path.append(current_cell) + return path + def connect_rail(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, start: IntVector2D, end: IntVector2D, a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance) -> IntVector2DArray: @@ -106,3 +141,8 @@ def connect_to_nodes(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap start: IntVector2D, end: IntVector2D, a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance) -> IntVector2DArray: return connect_basic_operation(rail_trans, grid_map, start, end, True, False, a_star_distance_function) + + +def connect_straigt_line(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, start: IntVector2D, + end: IntVector2D, openend=False) -> IntVector2DArray: + return connect_line(rail_trans, grid_map, start, end, openend) diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index d6e14459215fbcf2ea8086b3ef82833384e45b4a..6cfbece9a3203772e05ddb08dff51575f4f012d4 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -10,7 +10,7 @@ from flatland.core.grid.grid4_utils import get_direction, mirror from flatland.core.grid.grid_utils import distance_on_rail, direction_to_point from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.core.transition_map import GridTransitionMap -from flatland.envs.grid4_generators_utils import connect_rail, connect_cities +from flatland.envs.grid4_generators_utils import connect_rail, connect_cities, connect_straigt_line RailGeneratorProduct = Tuple[GridTransitionMap, Optional[Dict]] RailGenerator = Callable[[int, int, int, int], RailGeneratorProduct] @@ -552,7 +552,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, rail_array = grid_map.grid rail_array.fill(0) np.random.seed(seed + num_resets) - node_radius = int(np.ceil((max_tracks_in_city + 2) / 2.0)) + 2 + node_radius = int(np.ceil((max_tracks_in_city + 2) / 2.0)) + 1 max_inter_city_rails_allowed = max_inter_city_rails if max_inter_city_rails_allowed > max_tracks_in_city: max_inter_city_rails_allowed = max_tracks_in_city @@ -604,7 +604,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, # Fix all transition elements grid_fix_time = time.time() - _fix_transitions(grid_map) + _fix_transitions(city_cells, grid_map) print("Grid fix time", time.time() - grid_fix_time) # Generate start target pairs @@ -791,14 +791,17 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, opposite_boarder = (boarder + 2) % 4 boarder_one = inner_connection_points[current_city][boarder] boarder_two = inner_connection_points[current_city][opposite_boarder] - connect_cities(rail_trans, grid_map, boarder_one[0], boarder_one[-1]) - connect_cities(rail_trans, grid_map, boarder_two[0], boarder_two[-1]) + # Connect the ends of the tracks + connect_straigt_line(rail_trans, grid_map, boarder_one[0], boarder_one[-1], False) + connect_straigt_line(rail_trans, grid_map, boarder_two[0], boarder_two[-1], False) + + # Connect parallel tracks for track_id in range(len(inner_connection_points[current_city][boarder])): if track_id % 2 == 0: source = inner_connection_points[current_city][boarder][track_id] target = inner_connection_points[current_city][opposite_boarder][track_id] - current_track = connect_cities(rail_trans, grid_map, source, target, city_boarder) + current_track = connect_straigt_line(rail_trans, grid_map, source, target) if target in all_outer_connection_points and source in \ all_outer_connection_points and len(through_path_cells[current_city]) < 1: through_path_cells[current_city].extend(current_track) @@ -806,7 +809,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, source = inner_connection_points[current_city][opposite_boarder][track_id] target = inner_connection_points[current_city][boarder][track_id] - current_track = connect_cities(rail_trans, grid_map, source, target, city_boarder) + current_track = connect_straigt_line(rail_trans, grid_map, source, target) if target in all_outer_connection_points and source in \ all_outer_connection_points and len(through_path_cells[current_city]) < 1: through_path_cells[current_city].extend(current_track) @@ -888,23 +891,22 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, num_agents -= 1 return agent_start_targets_nodes, num_agents - def _fix_transitions(grid_map): + def _fix_transitions(city_cells, grid_map): """ Function to fix all transition elements in environment """ # Fix all nodes with illegal transition maps empty_to_fix = [] rails_to_fix = [] - height, width = np.shape(grid_map.grid) - for r in range(height): - for c in range(width): - rc_pos = (r, c) - check = grid_map.cell_neighbours_valid(rc_pos, True) - if not check: - if grid_map.grid[rc_pos] == 0: - empty_to_fix.append(rc_pos) - else: - rails_to_fix.append(rc_pos) + for cell in city_cells: + check = grid_map.cell_neighbours_valid(cell, True) + if grid_map.grid[cell] == int('1000010000100001', 2): + grid_map.fix_transitions(cell) + if not check: + if grid_map.grid[cell] == 0: + empty_to_fix.append(cell) + else: + rails_to_fix.append(cell) # Fix empty cells first to avoid cutting the network for cell in empty_to_fix: