diff --git a/flatland/envs/line_generators.py b/flatland/envs/line_generators.py index 4a2ec64727cc12580c186b496f6ffed9cc1be421..254e574b3d27c272624b64ec5164f9f3ec879971 100644 --- a/flatland/envs/line_generators.py +++ b/flatland/envs/line_generators.py @@ -53,54 +53,6 @@ class BaseLineGen(object): return self.generate(*args, **kwargs) - -def complex_line_generator(speed_ratio_map: Mapping[float, float] = None, seed: int = 1) -> LineGenerator: - """ - - Generator used to generate the levels of Round 1 in the Flatland Challenge. It can only be used together - with complex_rail_generator. It places agents at end and start points provided by the rail generator. - It assigns speeds to the different agents according to the speed_ratio_map - :param speed_ratio_map: Speed ratios of all agents. They are probabilities of all different speeds and have to - add up to 1. - :param seed: Initiate random seed generator - :return: - """ - - def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None, num_resets: int = 0, - np_random: RandomState = None) -> Line: - """ - - The generator that assigns tasks to all the agents - :param rail: Rail infrastructure given by the rail_generator - :param num_agents: Number of agents to include in the line - :param hints: Hints provided by the rail_generator These include positions of start/target positions - :param num_resets: How often the generator has been reset. - :return: Returns the generator to the rail constructor - """ - # Todo: Remove parameters and variables not used for next version, Issue: <https://gitlab.aicrowd.com/flatland/flatland/issues/305> - _runtime_seed = seed + num_resets - - start_goal = hints['start_goal'] - start_dir = hints['start_dir'] - agents_position = [sg[0] for sg in start_goal[:num_agents]] - agents_target = [sg[1] for sg in start_goal[:num_agents]] - agents_direction = start_dir[:num_agents] - - if speed_ratio_map: - speeds = speed_initialization_helper(num_agents, speed_ratio_map, seed=_runtime_seed, np_random=np_random) - else: - speeds = [1.0] * len(agents_position) - - # Compute max number of steps with given line - extra_time_factor = 1.5 # Factor to allow for more then minimal time - max_episode_steps = int(extra_time_factor * rail.height * rail.width) - - return Line(agent_positions=agents_position, agent_directions=agents_direction, - agent_targets=agents_target, agent_speeds=speeds, agent_malfunction_rates=None) - - return generator - - def sparse_line_generator(speed_ratio_map: Mapping[float, float] = None, seed: int = 1) -> LineGenerator: return SparseLineGen(speed_ratio_map, seed) @@ -203,111 +155,6 @@ class SparseLineGen(BaseLineGen): agent_targets=agents_target, agent_speeds=speeds, agent_malfunction_rates=None) -def random_line_generator(speed_ratio_map: Mapping[float, float] = None, seed: int = 1) -> LineGenerator: - return RandomLineGen(speed_ratio_map, seed) - - -class RandomLineGen(BaseLineGen): - - """ - Given a `rail` GridTransitionMap, return a random placement of agents (initial position, direction and target). - - Parameters - ---------- - speed_ratio_map : Optional[Mapping[float, float]] - A map of speeds mapping to their ratio of appearance. The ratios must sum up to 1. - - Returns - ------- - Tuple[List[Tuple[int,int]], List[Tuple[int,int]], List[Tuple[int,int]], List[float]] - initial positions, directions, targets speeds - """ - - def generate(self, rail: GridTransitionMap, num_agents: int, hints: Any = None, num_resets: int = 0, - np_random: RandomState = None) -> Line: - _runtime_seed = self.seed + num_resets - - valid_positions = [] - for r in range(rail.height): - for c in range(rail.width): - if rail.get_full_transitions(r, c) > 0: - valid_positions.append((r, c)) - if len(valid_positions) == 0: - return Line(agent_positions=[], agent_directions=[], - agent_targets=[], agent_speeds=[], agent_malfunction_rates=None) - - if len(valid_positions) < num_agents: - warnings.warn("line_generators: len(valid_positions) < num_agents") - return Line(agent_positions=[], agent_directions=[], - agent_targets=[], agent_speeds=[], agent_malfunction_rates=None) - - agents_position_idx = [i for i in np_random.choice(len(valid_positions), num_agents, replace=False)] - agents_position = [valid_positions[agents_position_idx[i]] for i in range(num_agents)] - agents_target_idx = [i for i in np_random.choice(len(valid_positions), num_agents, replace=False)] - agents_target = [valid_positions[agents_target_idx[i]] for i in range(num_agents)] - update_agents = np.zeros(num_agents) - - re_generate = True - cnt = 0 - while re_generate: - cnt += 1 - if cnt > 1: - print("re_generate cnt={}".format(cnt)) - if cnt > 1000: - raise Exception("After 1000 re_generates still not success, giving up.") - # update position - for i in range(num_agents): - if update_agents[i] == 1: - x = np.setdiff1d(np.arange(len(valid_positions)), agents_position_idx) - agents_position_idx[i] = np_random.choice(x) - agents_position[i] = valid_positions[agents_position_idx[i]] - x = np.setdiff1d(np.arange(len(valid_positions)), agents_target_idx) - agents_target_idx[i] = np_random.choice(x) - agents_target[i] = valid_positions[agents_target_idx[i]] - update_agents = np.zeros(num_agents) - - # agents_direction must be a direction for which a solution is - # guaranteed. - agents_direction = [0] * num_agents - re_generate = False - for i in range(num_agents): - valid_movements = [] - for direction in range(4): - position = agents_position[i] - moves = rail.get_transitions(position[0], position[1], direction) - for move_index in range(4): - if moves[move_index]: - valid_movements.append((direction, move_index)) - - valid_starting_directions = [] - for m in valid_movements: - new_position = get_new_position(agents_position[i], m[1]) - if m[0] not in valid_starting_directions and rail.check_path_exists(new_position, m[1], - agents_target[i]): - valid_starting_directions.append(m[0]) - - if len(valid_starting_directions) == 0: - update_agents[i] = 1 - warnings.warn( - "reset position for agent[{}]: {} -> {}".format(i, agents_position[i], agents_target[i])) - re_generate = True - break - else: - agents_direction[i] = valid_starting_directions[ - np_random.choice(len(valid_starting_directions), 1)[0]] - - agents_speed = speed_initialization_helper(num_agents, self.speed_ratio_map, seed=_runtime_seed, - np_random=np_random) - - # Compute max number of steps with given line - extra_time_factor = 1.5 # Factor to allow for more then minimal time - max_episode_steps = int(extra_time_factor * rail.height * rail.width) - - return Line(agent_positions=agents_position, agent_directions=agents_direction, - agent_targets=agents_target, agent_speeds=agents_speed, agent_malfunction_rates=None) - - - def line_from_file(filename, load_from_package=None) -> LineGenerator: """ Utility to load pickle file diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index 9db7c8fa64455b6833b352dd731398c00208d029..b9e4ef0f65ceb385ac030f1d9e95baa00bdfaa72 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -65,148 +65,7 @@ class EmptyRailGen(RailGen): rail_array.fill(0) return grid_map, None - - - -def complex_rail_generator(nr_start_goal=1, - nr_extra=100, - min_dist=20, - max_dist=99999, - seed=1) -> RailGenerator: - """ - complex_rail_generator - - Parameters - ---------- - width : int - The width (number of cells) of the grid to generate. - height : int - The height (number of cells) of the grid to generate. - - Returns - ------- - numpy.ndarray of type numpy.uint16 - The matrix with the correct 16-bit bitmaps for each cell. - """ - - def generator(width: int, height: int, num_agents: int, num_resets: int = 0, - np_random: RandomState = None) -> RailGenerator: - - if num_agents > nr_start_goal: - num_agents = nr_start_goal - print("complex_rail_generator: num_agents > nr_start_goal, changing num_agents") - grid_map = GridTransitionMap(width=width, height=height, transitions=RailEnvTransitions()) - rail_array = grid_map.grid - rail_array.fill(0) - - # generate rail array - # step 1: - # - generate a start and goal position - # - validate min/max distance allowed - # - validate that start/goals are not placed too close to other start/goals - # - draw a rail from [start,goal] - # - if rail crosses existing rail then validate new connection - # - possibility that this fails to create a path to goal - # - on failure generate new start/goal - # - # step 2: - # - add more rails to map randomly between cells that have rails - # - validate all new rails, on failure don't add new rails - # - # step 3: - # - return transition map + list of [start_pos, start_dir, goal_pos] points - # - - rail_trans = grid_map.transitions - start_goal = [] - start_dir = [] - nr_created = 0 - created_sanity = 0 - sanity_max = 9000 - while nr_created < nr_start_goal and created_sanity < sanity_max: - all_ok = False - for _ in range(sanity_max): - start = (np_random.randint(0, height), np_random.randint(0, width)) - goal = (np_random.randint(0, height), np_random.randint(0, width)) - - # check to make sure start,goal pos is empty? - if rail_array[goal] != 0 or rail_array[start] != 0: - continue - # check min/max distance - dist_sg = distance_on_rail(start, goal) - if dist_sg < min_dist: - continue - if dist_sg > max_dist: - continue - # check distance to existing points - sg_new = [start, goal] - - def check_all_dist(sg_new): - """ - Function to check the distance betweens start and goal - :param sg_new: start and goal tuple - :return: True if distance is larger than 2, False otherwise - """ - for sg in start_goal: - for i in range(2): - for j in range(2): - dist = distance_on_rail(sg_new[i], sg[j]) - if dist < 2: - return False - return True - - if check_all_dist(sg_new): - all_ok = True - break - - if not all_ok: - # we might as well give up at this point - break - - new_path = connect_rail_in_grid_map(grid_map, start, goal, rail_trans, Vec2d.get_chebyshev_distance, - flip_start_node_trans=True, flip_end_node_trans=True, - respect_transition_validity=True, forbidden_cells=None) - if len(new_path) >= 2: - nr_created += 1 - start_goal.append([start, goal]) - start_dir.append(mirror(get_direction(new_path[0], new_path[1]))) - else: - # after too many failures we will give up - created_sanity += 1 - - # add extra connections between existing rail - created_sanity = 0 - nr_created = 0 - while nr_created < nr_extra and created_sanity < sanity_max: - all_ok = False - for _ in range(sanity_max): - start = (np_random.randint(0, height), np_random.randint(0, width)) - goal = (np_random.randint(0, height), np_random.randint(0, width)) - # check to make sure start,goal pos are not empty - if rail_array[goal] == 0 or rail_array[start] == 0: - continue - else: - all_ok = True - break - if not all_ok: - break - new_path = connect_rail_in_grid_map(grid_map, start, goal, rail_trans, Vec2d.get_chebyshev_distance, - flip_start_node_trans=True, flip_end_node_trans=True, - respect_transition_validity=True, forbidden_cells=None) - - if len(new_path) >= 2: - nr_created += 1 - else: - # after too many failures we will give up - created_sanity += 1 - - return grid_map, {'agents_hints': { - 'start_goal': start_goal, - 'start_dir': start_dir - }} - - return generator - + def rail_from_manual_specifications_generator(rail_spec): """ @@ -319,289 +178,6 @@ def rail_from_grid_transition_map_old(rail_map) -> RailGenerator: return generator -def random_rail_generator(cell_type_relative_proportion=[1.0] * 11, seed=1) -> RailGenerator: - """ - Dummy random level generator: - - fill in cells at random in [width-2, height-2] - - keep filling cells in among the unfilled ones, such that all transitions\ - are legit; if no cell can be filled in without violating some\ - transitions, pick one among those that can satisfy most transitions\ - (1,2,3 or 4), and delete (+mark to be re-filled) the cells that were\ - incompatible. - - keep trying for a total number of insertions\ - (e.g., (W-2)*(H-2)*MAX_REPETITIONS ); if no solution is found, empty the\ - board and try again from scratch. - - finally pad the border of the map with dead-ends to avoid border issues. - - Dead-ends are not allowed inside the grid, only at the border; however, if - no cell type can be inserted in a given cell (because of the neighboring - transitions), deadends are allowed if they solve the problem. This was - found to turn most un-genereatable levels into valid ones. - - Parameters - ---------- - width : int - The width (number of cells) of the grid to generate. - height : int - The height (number of cells) of the grid to generate. - - Returns - ------- - numpy.ndarray of type numpy.uint16 - The matrix with the correct 16-bit bitmaps for each cell. - """ - - def generator(width: int, height: int, num_agents: int, num_resets: int = 0, - np_random: RandomState = None) -> RailGenerator: - t_utils = RailEnvTransitions() - - transition_probability = cell_type_relative_proportion - - transitions_templates_ = [] - transition_probabilities = [] - for i in range(len(t_utils.transitions)): # don't include dead-ends - if t_utils.transitions[i] == int('0010000000000000', 2): - continue - - all_transitions = 0 - for dir_ in range(4): - trans = t_utils.get_transitions(t_utils.transitions[i], dir_) - all_transitions |= (trans[0] << 3) | \ - (trans[1] << 2) | \ - (trans[2] << 1) | \ - (trans[3]) - - template = [int(x) for x in bin(all_transitions)[2:]] - template = [0] * (4 - len(template)) + template - - # add all rotations - for rot in [0, 90, 180, 270]: - transitions_templates_.append((template, - t_utils.rotate_transition( - t_utils.transitions[i], - rot))) - transition_probabilities.append(transition_probability[i]) - template = [template[-1]] + template[:-1] - - def get_matching_templates(template): - """ - Returns a list of possible transition maps for a given template - - Parameters: - ------ - template:List[int] - - Returns: - ------ - List[int] - """ - ret = [] - for i in range(len(transitions_templates_)): - is_match = True - for j in range(4): - if template[j] >= 0 and template[j] != transitions_templates_[i][0][j]: - is_match = False - break - if is_match: - ret.append((transitions_templates_[i][1], transition_probabilities[i])) - return ret - - MAX_INSERTIONS = (width - 2) * (height - 2) * 10 - MAX_ATTEMPTS_FROM_SCRATCH = 10 - - attempt_number = 0 - while attempt_number < MAX_ATTEMPTS_FROM_SCRATCH: - cells_to_fill = [] - rail = [] - for r in range(height): - rail.append([None] * width) - if r > 0 and r < height - 1: - cells_to_fill = cells_to_fill + [(r, c) for c in range(1, width - 1)] - - num_insertions = 0 - while num_insertions < MAX_INSERTIONS and len(cells_to_fill) > 0: - cell = cells_to_fill[np_random.choice(len(cells_to_fill), 1)[0]] - cells_to_fill.remove(cell) - row = cell[0] - col = cell[1] - - # look at its neighbors and see what are the possible transitions - # that can be chosen from, if any. - valid_template = [-1, -1, -1, -1] - - for el in [(0, 2, (-1, 0)), - (1, 3, (0, 1)), - (2, 0, (1, 0)), - (3, 1, (0, -1))]: # N, E, S, W - neigh_trans = rail[row + el[2][0]][col + el[2][1]] - if neigh_trans is not None: - # select transition coming from facing direction el[1] and - # moving to direction el[1] - max_bit = 0 - for k in range(4): - max_bit |= t_utils.get_transition(neigh_trans, k, el[1]) - - if max_bit: - valid_template[el[0]] = 1 - else: - valid_template[el[0]] = 0 - - possible_cell_transitions = get_matching_templates(valid_template) - - if len(possible_cell_transitions) == 0: # NO VALID TRANSITIONS - # no cell can be filled in without violating some transitions - # can a dead-end solve the problem? - if valid_template.count(1) == 1: - for k in range(4): - if valid_template[k] == 1: - rot = 0 - if k == 0: - rot = 180 - elif k == 1: - rot = 270 - elif k == 2: - rot = 0 - elif k == 3: - rot = 90 - - rail[row][col] = t_utils.rotate_transition(int('0010000000000000', 2), rot) - num_insertions += 1 - - break - - else: - # can I get valid transitions by removing a single - # neighboring cell? - bestk = -1 - besttrans = [] - for k in range(4): - tmp_template = valid_template[:] - tmp_template[k] = -1 - possible_cell_transitions = get_matching_templates(tmp_template) - if len(possible_cell_transitions) > len(besttrans): - besttrans = possible_cell_transitions - bestk = k - - if bestk >= 0: - # Replace the corresponding cell with None, append it - # to cells to fill, fill in a transition in the current - # cell. - replace_row = row - 1 - replace_col = col - if bestk == 1: - replace_row = row - replace_col = col + 1 - elif bestk == 2: - replace_row = row + 1 - replace_col = col - elif bestk == 3: - replace_row = row - replace_col = col - 1 - - cells_to_fill.append((replace_row, replace_col)) - rail[replace_row][replace_col] = None - - possible_transitions, possible_probabilities = zip(*besttrans) - possible_probabilities = [p / sum(possible_probabilities) for p in possible_probabilities] - - rail[row][col] = np_random.choice(possible_transitions, - p=possible_probabilities) - num_insertions += 1 - - else: - print('WARNING: still nothing!') - rail[row][col] = int('0000000000000000', 2) - num_insertions += 1 - pass - - else: - possible_transitions, possible_probabilities = zip(*possible_cell_transitions) - possible_probabilities = [p / sum(possible_probabilities) for p in possible_probabilities] - - rail[row][col] = np_random.choice(possible_transitions, - p=possible_probabilities) - num_insertions += 1 - - if num_insertions == MAX_INSERTIONS: - # Failed to generate a valid level; try again for a number of times - attempt_number += 1 - else: - break - - if attempt_number == MAX_ATTEMPTS_FROM_SCRATCH: - print('ERROR: failed to generate level') - - # Finally pad the border of the map with dead-ends to avoid border issues; - # at most 1 transition in the neigh cell - for r in range(height): - # Check for transitions coming from [r][1] to WEST - max_bit = 0 - neigh_trans = rail[r][1] - if neigh_trans is not None: - for k in range(4): - neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2 ** 4 - 1) - max_bit = max_bit | (neigh_trans_from_direction & 1) - if max_bit: - rail[r][0] = t_utils.rotate_transition(int('0010000000000000', 2), 270) - else: - rail[r][0] = int('0000000000000000', 2) - - # Check for transitions coming from [r][-2] to EAST - max_bit = 0 - neigh_trans = rail[r][-2] - if neigh_trans is not None: - for k in range(4): - neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2 ** 4 - 1) - max_bit = max_bit | (neigh_trans_from_direction & (1 << 2)) - if max_bit: - rail[r][-1] = t_utils.rotate_transition(int('0010000000000000', 2), - 90) - else: - rail[r][-1] = int('0000000000000000', 2) - - for c in range(width): - # Check for transitions coming from [1][c] to NORTH - max_bit = 0 - neigh_trans = rail[1][c] - if neigh_trans is not None: - for k in range(4): - neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2 ** 4 - 1) - max_bit = max_bit | (neigh_trans_from_direction & (1 << 3)) - if max_bit: - rail[0][c] = int('0010000000000000', 2) - else: - rail[0][c] = int('0000000000000000', 2) - - # Check for transitions coming from [-2][c] to SOUTH - max_bit = 0 - neigh_trans = rail[-2][c] - if neigh_trans is not None: - for k in range(4): - neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2 ** 4 - 1) - max_bit = max_bit | (neigh_trans_from_direction & (1 << 1)) - if max_bit: - rail[-1][c] = t_utils.rotate_transition(int('0010000000000000', 2), 180) - else: - rail[-1][c] = int('0000000000000000', 2) - - # For display only, wrong levels - for r in range(height): - for c in range(width): - if rail[r][c] is None: - rail[r][c] = int('0000000000000000', 2) - - tmp_rail = np.asarray(rail, dtype=np.uint16) - - return_rail = GridTransitionMap(width=width, height=height, transitions=t_utils) - return_rail.grid = tmp_rail - - return return_rail, None - - return generator - - - - def sparse_rail_generator(*args, **kwargs): return SparseRailGen(*args, **kwargs) diff --git a/tests/test_flatland_line_from_file.py b/tests/test_flatland_line_from_file.py index b324af98a8550a47621be6258d2d559f128f45a5..e7e9738e3624dce4d843d7c0345807cbd6708159 100644 --- a/tests/test_flatland_line_from_file.py +++ b/tests/test_flatland_line_from_file.py @@ -1,10 +1,8 @@ from test_utils import create_and_save_env from flatland.envs.rail_env import RailEnv -from flatland.envs.rail_generators import sparse_rail_generator, random_rail_generator, complex_rail_generator, \ - rail_from_file -from flatland.envs.line_generators import sparse_line_generator, random_line_generator, \ - complex_line_generator, line_from_file +from flatland.envs.rail_generators import sparse_rail_generator, rail_from_file +from flatland.envs.line_generators import sparse_line_generator, line_from_file def test_line_from_file_sparse(): @@ -46,86 +44,4 @@ def test_line_from_file_sparse(): assert sparse_env_from_file.get_num_agents() == old_num_agents # Assert max steps is correct - assert sparse_env_from_file._max_episode_steps == old_num_steps - - - -def test_line_from_file_random(): - """ - Test to see that all parameters are loaded as expected - Returns - ------- - - """ - # Different agent types (trains) with different speeds. - speed_ration_map = {1.: 0.25, # Fast passenger train - 1. / 2.: 0.25, # Fast freight train - 1. / 3.: 0.25, # Slow commuter train - 1. / 4.: 0.25} # Slow freight train - - # Generate random test env - rail_generator = random_rail_generator() - line_generator = random_line_generator(speed_ration_map) - - env = create_and_save_env(file_name="./random_env_test.pkl", rail_generator=rail_generator, - line_generator=line_generator) - old_num_steps = env._max_episode_steps - old_num_agents = len(env.agents) - - - # Random generator - rail_generator = rail_from_file("./random_env_test.pkl") - line_generator = line_from_file("./random_env_test.pkl") - random_env_from_file = RailEnv(width=1, height=1, rail_generator=rail_generator, - line_generator=line_generator) - random_env_from_file.reset(True, True) - - # Assert loaded agent number is correct - assert random_env_from_file.get_num_agents() == old_num_agents - - # Assert max steps is correct - assert random_env_from_file._max_episode_steps == old_num_steps - - - - -def test_line_from_file_complex(): - """ - Test to see that all parameters are loaded as expected - Returns - ------- - - """ - # Different agent types (trains) with different speeds. - speed_ration_map = {1.: 0.25, # Fast passenger train - 1. / 2.: 0.25, # Fast freight train - 1. / 3.: 0.25, # Slow commuter train - 1. / 4.: 0.25} # Slow freight train - - # Generate complex test env - rail_generator = complex_rail_generator(nr_start_goal=10, - nr_extra=1, - min_dist=8, - max_dist=99999) - line_generator = complex_line_generator(speed_ration_map) - - env = create_and_save_env(file_name="./complex_env_test.pkl", rail_generator=rail_generator, - line_generator=line_generator) - old_num_steps = env._max_episode_steps - old_num_agents = len(env.agents) - - # Load the different envs and check the parameters - - - # Complex generator - rail_generator = rail_from_file("./complex_env_test.pkl") - line_generator = line_from_file("./complex_env_test.pkl") - complex_env_from_file = RailEnv(width=1, height=1, rail_generator=rail_generator, - line_generator=line_generator) - complex_env_from_file.reset(True, True) - - # Assert loaded agent number is correct - assert complex_env_from_file.get_num_agents() == old_num_agents - - # Assert max steps is correct - assert complex_env_from_file._max_episode_steps == old_num_steps + assert sparse_env_from_file._max_episode_steps == old_num_steps \ No newline at end of file diff --git a/tests/test_generators.py b/tests/test_generators.py index 61cf9523fb3f4e53dd29ac440960da8f866ad230..b58836051977ef7215c1e683b3c2a4d61feaffa4 100644 --- a/tests/test_generators.py +++ b/tests/test_generators.py @@ -6,16 +6,14 @@ import numpy as np from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv -from flatland.envs.rail_generators import rail_from_grid_transition_map, rail_from_file, complex_rail_generator, \ - random_rail_generator, empty_rail_generator -from flatland.envs.line_generators import random_line_generator, complex_line_generator, \ - line_from_file +from flatland.envs.rail_generators import rail_from_grid_transition_map, rail_from_file, empty_rail_generator +from flatland.envs.line_generators import sparse_line_generator, line_from_file from flatland.utils.simple_rail import make_simple_rail from flatland.envs.persistence import RailEnvPersister def test_empty_rail_generator(): - n_agents = 1 + n_agents = 2 x_dim = 5 y_dim = 10 @@ -30,61 +28,11 @@ def test_empty_rail_generator(): assert env.get_num_agents() == 0 -def test_random_rail_generator(): - n_agents = 1 - x_dim = 5 - y_dim = 10 - - # Check that a random level at with correct parameters is generated - env = RailEnv(width=x_dim, height=y_dim, rail_generator=random_rail_generator(), number_of_agents=n_agents) - env.reset() - assert env.rail.grid.shape == (y_dim, x_dim) - assert env.get_num_agents() == n_agents - - -def test_complex_rail_generator(): - n_agents = 10 - n_start = 2 - x_dim = 10 - y_dim = 10 - min_dist = 4 - - # Check that agent number is changed to fit generated level - env = RailEnv(width=x_dim, height=y_dim, - rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist), - line_generator=complex_line_generator(), number_of_agents=n_agents) - env.reset() - assert env.get_num_agents() == 2 - assert env.rail.grid.shape == (y_dim, x_dim) - - min_dist = 2 * x_dim - - # Check that no agents are generated when level cannot be generated - env = RailEnv(width=x_dim, height=y_dim, - rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist), - line_generator=complex_line_generator(), number_of_agents=n_agents) - env.reset() - assert env.get_num_agents() == 0 - assert env.rail.grid.shape == (y_dim, x_dim) - - # Check that everything stays the same when correct parameters are given - min_dist = 2 - n_start = 5 - n_agents = 5 - - env = RailEnv(width=x_dim, height=y_dim, - rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist), - line_generator=complex_line_generator(), number_of_agents=n_agents) - env.reset() - assert env.get_num_agents() == n_agents - assert env.rail.grid.shape == (y_dim, x_dim) - - def test_rail_from_grid_transition_map(): rail, rail_map = make_simple_rail() - n_agents = 3 + n_agents = 4 env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - line_generator=random_line_generator(), number_of_agents=n_agents) + line_generator=sparse_line_generator(), number_of_agents=n_agents) env.reset(False, False, True) nr_rail_elements = np.count_nonzero(env.rail.grid) @@ -106,7 +54,7 @@ def tests_rail_from_file(): rail, rail_map = make_simple_rail() env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - line_generator=random_line_generator(), number_of_agents=3, + line_generator=sparse_line_generator(), number_of_agents=3, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())) env.reset() #env.save(file_name) @@ -134,7 +82,7 @@ def tests_rail_from_file(): file_name_2 = "test_without_distance_map.pkl" env2 = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], - rail_generator=rail_from_grid_transition_map(rail), line_generator=random_line_generator(), + rail_generator=rail_from_grid_transition_map(rail), line_generator=sparse_line_generator(), number_of_agents=3, obs_builder_object=GlobalObsForRailEnv()) env2.reset() #env2.save(file_name_2) diff --git a/tests/test_multi_speed.py b/tests/test_multi_speed.py index 2664c5b4f8b18004b4a39095c16acee753832eb7..ad5d2e5cb28c67ecbaf956c776258a20e9a6dc44 100644 --- a/tests/test_multi_speed.py +++ b/tests/test_multi_speed.py @@ -4,13 +4,13 @@ from flatland.core.grid.grid4 import Grid4TransitionsEnum from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv, RailEnvActions -from flatland.envs.rail_generators import complex_rail_generator, rail_from_grid_transition_map -from flatland.envs.line_generators import complex_line_generator, random_line_generator +from flatland.envs.rail_generators import sparse_rail_generator, rail_from_grid_transition_map +from flatland.envs.line_generators import sparse_line_generator from flatland.utils.simple_rail import make_simple_rail from test_utils import ReplayConfig, Replay, run_replay_config, set_penalties_for_replay -# Use the complex_rail_generator to generate feasible network configurations with corresponding tasks +# Use the sparse_rail_generator to generate feasible network configurations with corresponding tasks # Training on simple small tasks is the best way to get familiar with the environment # @@ -48,9 +48,8 @@ class RandomAgent: def test_multi_speed_init(): env = RailEnv(width=50, height=50, - rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, - seed=1), line_generator=complex_line_generator(), - number_of_agents=5) + rail_generator=sparse_rail_generator(seed=1), line_generator=sparse_line_generator(), + number_of_agents=6) # Initialize the agent with the parameters corresponding to the environment and observation_builder agent = RandomAgent(218, 4) @@ -95,7 +94,7 @@ def test_multispeed_actions_no_malfunction_no_blocking(): """Test that actions are correctly performed on cell exit for a single agent.""" rail, rail_map = make_simple_rail() env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - line_generator=random_line_generator(), number_of_agents=1, + line_generator=sparse_line_generator(), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())) env.reset() @@ -195,7 +194,7 @@ def test_multispeed_actions_no_malfunction_blocking(): """The second agent blocks the first because it is slower.""" rail, rail_map = make_simple_rail() env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - line_generator=random_line_generator(), number_of_agents=2, + line_generator=sparse_line_generator(), number_of_agents=2, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())) env.reset() @@ -385,7 +384,7 @@ def test_multispeed_actions_malfunction_no_blocking(): """Test on a single agent whether action on cell exit work correctly despite malfunction.""" rail, rail_map = make_simple_rail() env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - line_generator=random_line_generator(), number_of_agents=1, + line_generator=sparse_line_generator(), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())) env.reset() @@ -523,7 +522,7 @@ def test_multispeed_actions_no_malfunction_invalid_actions(): """Test that actions are correctly performed on cell exit for a single agent.""" rail, rail_map = make_simple_rail() env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - line_generator=random_line_generator(), number_of_agents=1, + line_generator=sparse_line_generator(), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())) env.reset() diff --git a/tests/test_speed_classes.py b/tests/test_speed_classes.py index 008992ee20fb626bea272d1cb70d73c456ff6cff..3cfe1b1c7f58786cf0caacde629fa3a6c704230d 100644 --- a/tests/test_speed_classes.py +++ b/tests/test_speed_classes.py @@ -2,8 +2,8 @@ import numpy as np from flatland.envs.rail_env import RailEnv -from flatland.envs.rail_generators import complex_rail_generator -from flatland.envs.line_generators import speed_initialization_helper, complex_line_generator +from flatland.envs.rail_generators import sparse_rail_generator +from flatland.envs.line_generators import speed_initialization_helper, sparse_line_generator def test_speed_initialization_helper(): @@ -20,8 +20,7 @@ def test_rail_env_speed_intializer(): speed_ratio_map = {1: 0.3, 2: 0.4, 3: 0.1, 5: 0.2} env = RailEnv(width=50, height=50, - rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, - seed=1), line_generator=complex_line_generator(), + rail_generator=sparse_rail_generator(), line_generator=sparse_line_generator(), number_of_agents=10) env.reset() actual_speeds = list(map(lambda agent: agent.speed_data['speed'], env.agents))