diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index b9ace9f537aea14c233d55aeaf814e9620a8fbd0..5ece03e9c56d672b76a453e0036f6b89c3a6ee77 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -37,7 +37,7 @@ env = RailEnv(width=100, seed=14, # Random seed grid_mode=False, max_rails_between_cities=2, - max_rails_in_city=6, + max_rails_in_city=8, ), schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=100, diff --git a/flatland/envs/distance_map.py b/flatland/envs/distance_map.py index 2bc1a5117794959cca82d2edad821cb629397f78..c6e73b0bdbe752b8d5df9c4a0697bb621e5276ec 100644 --- a/flatland/envs/distance_map.py +++ b/flatland/envs/distance_map.py @@ -55,6 +55,7 @@ class DistanceMap: self.env_width = rail.width def _compute(self, agents: List[EnvAgent], rail: GridTransitionMap): + print("computing distance map") self.agents_previous_computation = self.agents self.distance_map = np.inf * np.ones(shape=(len(agents), self.env_height, diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 76ac04f15aec26e84c6d2b1df79f3a70b27a51f9..aa8e48023694c96941b59d6f78b3fd93edf81e9e 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -2,13 +2,13 @@ Definition of the RailEnv environment. """ # TODO: _ this is a global method --> utils or remove later -import warnings from enum import IntEnum -from typing import List, NamedTuple, Optional, Tuple, Dict +from typing import List, NamedTuple, Optional, Dict import msgpack import msgpack_numpy as m import numpy as np +from gym.utils import seeding from flatland.core.env import Environment from flatland.core.env_observation_builder import ObservationBuilder @@ -17,7 +17,7 @@ from flatland.core.grid.grid4_utils import get_new_position from flatland.core.transition_map import GridTransitionMap from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent, RailAgentStatus from flatland.envs.distance_map import DistanceMap -from flatland.envs.observations import TreeObsForRailEnv +from flatland.envs.observations import GlobalObsForRailEnv from flatland.envs.rail_generators import random_rail_generator, RailGenerator from flatland.envs.schedule_generators import random_schedule_generator, ScheduleGenerator @@ -114,10 +114,11 @@ class RailEnv(Environment): rail_generator: RailGenerator = random_rail_generator(), schedule_generator: ScheduleGenerator = random_schedule_generator(), number_of_agents=1, - obs_builder_object: ObservationBuilder = TreeObsForRailEnv(max_depth=2), + obs_builder_object: ObservationBuilder = GlobalObsForRailEnv(), max_episode_steps=None, stochastic_data=None, - remove_agents_at_target=False + remove_agents_at_target=False, + random_seed=None ): """ Environment init. @@ -151,6 +152,9 @@ class RailEnv(Environment): remove_agents_at_target : bool If remove_agents_at_target is set to true then the agents will be removed by placing to RailEnv.DEPOT_POSITION when the agent has reach it's target position. + random_seed : int or None + if None, then its ignored, else the random generators are seeded with this number to ensure + that stochastic operations are replicable across multiple operations """ super().__init__() @@ -184,6 +188,13 @@ class RailEnv(Environment): self.distance_map = DistanceMap(self.agents, self.height, self.width) self.action_space = [1] + + self._seed() + + self._seed() + self.random_seed = random_seed + if self.random_seed: + self._seed(seed=random_seed) # Stochastic train malfunctioning parameters if stochastic_data is not None: @@ -212,6 +223,10 @@ class RailEnv(Environment): self.valid_positions = None + def _seed(self, seed=None): + self.np_random, seed = seeding.np_random(seed) + return [seed] + # no more agent_handles def get_agent_handles(self): return range(self.get_num_agents()) @@ -240,26 +255,30 @@ class RailEnv(Environment): """ self.agents = EnvAgent.list_from_static(self.agents_static) - def reset(self, regen_rail=True, replace_agents=True, activate_agents=False): + def reset(self, regen_rail=True, replace_agents=True, activate_agents=False, random_seed=None): """ if regen_rail then regenerate the rails. if replace_agents then regenerate the agents static. Relies on the rail_generator returning agent_static lists (pos, dir, target) """ + if random_seed: + self._seed(random_seed) - # TODO https://gitlab.aicrowd.com/flatland/flatland/issues/172 - # can we not put 'self.rail_generator(..)' into 'if regen_rail or self.rail is None' condition? - rail, optionals = self.rail_generator(self.width, self.height, self.get_num_agents(), self.num_resets) - + optionals = {} if regen_rail or self.rail is None: + rail, optionals = self.rail_generator(self.width, self.height, self.get_num_agents(), self.num_resets) + self.rail = rail self.height, self.width = self.rail.grid.shape - for r in range(self.height): - for c in range(self.width): - rc_pos = (r, c) - check = self.rail.cell_neighbours_valid(rc_pos, True) - if not check: - print(self.rail.grid[rc_pos]) - warnings.warn("Invalid grid at {} -> {}".format(rc_pos, check)) + # NOTE : Ignore Validation on every reset. rail_generator should ensure that + # only valid grids are generated. + # + # for r in range(self.height): + # for c in range(self.width): + # rc_pos = (r, c) + # check = self.rail.cell_neighbours_valid(rc_pos, True) + # if not check: + # print(self.rail.grid[rc_pos]) + # warnings.warn("Invalid grid at {} -> {}".format(rc_pos, check)) # TODO https://gitlab.aicrowd.com/flatland/flatland/issues/172 # hacky: we must re-compute the distance map and not use the initial distance_map loaded from file by # rail_from_file!!! @@ -274,7 +293,7 @@ class RailEnv(Environment): # TODO https://gitlab.aicrowd.com/flatland/flatland/issues/185 # why do we need static agents? could we it more elegantly? self.agents_static = EnvAgentStatic.from_lists( - *self.schedule_generator(self.rail, self.get_num_agents(), agents_hints)) + *self.schedule_generator(self.rail, self.get_num_agents(), agents_hints, self.num_resets)) self.restart_agents() @@ -287,7 +306,7 @@ class RailEnv(Environment): # continue # A proportion of agent in the environment will receive a positive malfunction rate - if np.random.random() < self.proportion_malfunctioning_trains: + if self.np_random.rand() < self.proportion_malfunctioning_trains: agent.malfunction_data['malfunction_rate'] = self.mean_malfunction_rate agent.malfunction_data['malfunction'] = 0 @@ -335,17 +354,17 @@ class RailEnv(Environment): # If counter has come to zero --> Agent has malfunction # set next malfunction time and duration of current malfunction if agent.malfunction_data['malfunction_rate'] > 0 >= agent.malfunction_data['malfunction'] and \ - agent.malfunction_data['next_malfunction'] <= 0: + agent.malfunction_data['next_malfunction'] <= 0: # Increase number of malfunctions agent.malfunction_data['nr_malfunctions'] += 1 # Next malfunction in number of stops next_breakdown = int( - np.random.exponential(scale=agent.malfunction_data['malfunction_rate'])) + self.np_random.exponential(scale=agent.malfunction_data['malfunction_rate'])) agent.malfunction_data['next_malfunction'] = next_breakdown # Duration of current malfunction - num_broken_steps = np.random.randint(self.min_number_of_steps_broken, + num_broken_steps = self.np_random.randint(self.min_number_of_steps_broken, self.max_number_of_steps_broken + 1) + 1 agent.malfunction_data['malfunction'] = num_broken_steps agent.malfunction_data['moving_before_malfunction'] = agent.moving @@ -405,7 +424,8 @@ class RailEnv(Environment): info_dict = { 'action_required': { i: (agent.status == RailAgentStatus.READY_TO_DEPART or ( - agent.status == RailAgentStatus.ACTIVE and agent.speed_data['position_fraction'] == 0.0)) + agent.status == RailAgentStatus.ACTIVE and np.isclose(agent.speed_data['position_fraction'], 0.0, + rtol=1e-03))) for i, agent in enumerate(self.agents)}, 'malfunction': { i: self.agents[i].malfunction_data['malfunction'] for i in range(self.get_num_agents()) @@ -457,8 +477,9 @@ class RailEnv(Environment): return # Is the agent at the beginning of the cell? Then, it can take an action. - # As long as the agent is malfunctioning or stopped at the beginning of the cell, different actions may be taken! - if agent.speed_data['position_fraction'] == 0.0: + # As long as the agent is malfunctioning or stopped at the beginning of the cell, + # different actions may be taken! + if np.isclose(agent.speed_data['position_fraction'], 0.0, rtol=1e-03): # No action has been supplied for this agent -> set DO_NOTHING as default if action is None: action = RailEnvActions.DO_NOTHING @@ -479,7 +500,7 @@ class RailEnv(Environment): self.rewards_dict[i_agent] += self.stop_penalty if not agent.moving and not ( - action == RailEnvActions.DO_NOTHING or action == RailEnvActions.STOP_MOVING): + action == RailEnvActions.DO_NOTHING or action == RailEnvActions.STOP_MOVING): # Allow agent to start with any forward or direction action agent.moving = True self.rewards_dict[i_agent] += self.start_penalty diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index 1a046f74d39ca8450f88b043e9168de03c269e1e..242e11d73797b2dfdcea0efe7a079dcc066ff05e 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -272,7 +272,7 @@ def rail_from_grid_transition_map(rail_map) -> RailGenerator: return generator -def random_rail_generator(cell_type_relative_proportion=[1.0] * 11) -> RailGenerator: +def random_rail_generator(cell_type_relative_proportion=[1.0] * 11, seed=0) -> RailGenerator: """ Dummy random level generator: - fill in cells at random in [width-2, height-2] @@ -305,6 +305,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11) -> RailGener """ def generator(width: int, height: int, num_agents: int, num_resets: int = 0) -> RailGeneratorProduct: + np.random.seed(seed + num_resets) t_utils = RailEnvTransitions() transition_probability = cell_type_relative_proportion @@ -542,7 +543,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11) -> RailGener def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_rails_between_cities: int = 4, - max_rails_in_city: int = 4, seed: int = 0) -> RailGenerator: + max_rails_in_city: int = 4, seed: int = 1) -> RailGenerator: """ Generates railway networks with cities and inner city rails :param max_num_cities: Number of city centers in the map @@ -562,7 +563,7 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_ vector_field = np.zeros(shape=(height, width)) - 1. min_nr_rails_in_city = 2 - max_nr_rail_in_city = 6 + # max_nr_rail_in_city = 6 rails_in_city = min_nr_rails_in_city if max_rails_in_city < min_nr_rails_in_city else max_rails_in_city rails_between_cities = rails_in_city if max_rails_between_cities > rails_in_city else max_rails_between_cities @@ -588,9 +589,10 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_ warnings.warn("Initial parameters cannot generate valid railway") return # Set up connection points for all cities - inner_connection_points, outer_connection_points, connection_info, city_orientations = _generate_city_connection_points( - city_positions, city_radius, rails_between_cities, - rails_in_city) + inner_connection_points, outer_connection_points, connection_info, city_orientations = \ + _generate_city_connection_points( + city_positions, city_radius, rails_between_cities, + rails_in_city) # Connect the cities through the connection points inter_city_lines = _connect_cities(city_positions, outer_connection_points, city_cells, @@ -616,8 +618,8 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_ 'city_orientations': city_orientations }} - def _generate_random_city_positions(num_cities: int, city_radius: int, width: int, height: int, vector_field) -> ( - IntVector2DArray, IntVector2DArray): + def _generate_random_city_positions(num_cities: int, city_radius: int, width: int, + height: int, vector_field) -> (IntVector2DArray, IntVector2DArray): city_positions: IntVector2DArray = [] city_cells: IntVector2DArray = [] for city_idx in range(num_cities): @@ -640,7 +642,8 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_ tries += 1 if tries > 200: warnings.warn( - "Could not only set {} cities after {} tries, although {} of cities required to be generated!".format( + "Could only set {} cities after {} tries, although {} of cities required to be generated!".format( + # noqa len(city_positions), tries, num_cities)) break @@ -650,9 +653,9 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_ vector_field) -> (IntVector2DArray, IntVector2DArray): aspect_ratio = height / width cities_per_row = min(int(np.ceil(np.sqrt(num_cities * aspect_ratio))), - int((height - 2) / (2 * city_radius + 1))) + int((height - 2) / (2 * (city_radius + 1)))) cities_per_col = min(int(np.ceil(num_cities / cities_per_row)), - int((width - 2) / (2 * city_radius + 1))) + int((width - 2) / (2 * (city_radius + 1)))) num_build_cities = min(num_cities, cities_per_col * cities_per_row) row_positions = np.linspace(city_radius + 1, height - 2 * (city_radius + 1), cities_per_row, dtype=int) col_positions = np.linspace(city_radius + 1, width - 2 * (city_radius + 1), cities_per_col, dtype=int) @@ -829,6 +832,7 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_ source = inner_connection_points[current_city][boarder][track_id] target = inner_connection_points[current_city][opposite_boarder][track_id] current_track = connect_straight_line_in_grid_map(grid_map, source, target, rail_trans) + free_rails[current_city].append(current_track) for track_id in range(nr_of_connection_points): source = inner_connection_points[current_city][boarder][track_id] @@ -878,10 +882,10 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_ # Assign agents to slots for agent_idx in range(num_agents): avail_start_cities = [idx for idx, val in enumerate(city_available_start) if val > 0] - avail_target_cities = [idx for idx, val in enumerate(city_available_target) if val > 0] + # avail_target_cities = [idx for idx, val in enumerate(city_available_target) if val > 0] # Set probability to choose start and stop from trainstations sum_start = sum(np.array(city_available_start)[avail_start_cities]) - sum_target = sum(np.array(city_available_target)[avail_target_cities]) + # sum_target = sum(np.array(city_available_target)[avail_target_cities]) p_avail_start = [float(i) / sum_start for i in np.array(city_available_start)[avail_start_cities]] start_target_tuple = np.random.choice(avail_start_cities, p=p_avail_start, size=2, replace=False) diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py index f3594057f607497db7937a589719b101c8fe19cc..ef05733db7564dc6a9bfb132cf42c047b9928b7a 100644 --- a/flatland/envs/schedule_generators.py +++ b/flatland/envs/schedule_generators.py @@ -1,5 +1,4 @@ """Schedule generators (railway undertaking, "EVU").""" -import random import warnings from typing import Tuple, List, Callable, Mapping, Optional, Any @@ -15,7 +14,8 @@ ScheduleGeneratorProduct = Tuple[List[AgentPosition], List[AgentPosition], List[ ScheduleGenerator = Callable[[GridTransitionMap, int, Optional[Any]], ScheduleGeneratorProduct] -def speed_initialization_helper(nb_agents: int, speed_ratio_map: Mapping[float, float] = None) -> List[float]: +def speed_initialization_helper(nb_agents: int, speed_ratio_map: Mapping[float, float] = None, + seed: int = None) -> List[float]: """ Parameters ---------- @@ -29,6 +29,9 @@ def speed_initialization_helper(nb_agents: int, speed_ratio_map: Mapping[float, List[float] A list of size nb_agents of speeds with the corresponding probabilistic ratios. """ + if seed: + np.random.seed(seed) + if speed_ratio_map is None: return [1.0] * nb_agents @@ -39,8 +42,12 @@ def speed_initialization_helper(nb_agents: int, speed_ratio_map: Mapping[float, return list(map(lambda index: speeds[index], np.random.choice(nb_classes, nb_agents, p=speed_ratios))) -def complex_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> ScheduleGenerator: - def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None): +def complex_schedule_generator(speed_ratio_map: Mapping[float, float] = None, seed: int = 1) -> ScheduleGenerator: + def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None, num_resets: int = 0): + + _runtime_seed = seed + num_resets + np.random.seed(_runtime_seed) + start_goal = hints['start_goal'] start_dir = hints['start_dir'] agents_position = [sg[0] for sg in start_goal[:num_agents]] @@ -48,7 +55,7 @@ def complex_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> agents_direction = start_dir[:num_agents] if speed_ratio_map: - speeds = speed_initialization_helper(num_agents, speed_ratio_map) + speeds = speed_initialization_helper(num_agents, speed_ratio_map, seed=_runtime_seed) else: speeds = [1.0] * len(agents_position) @@ -57,13 +64,17 @@ def complex_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> return generator -def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> ScheduleGenerator: +def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None, seed: int = 1) -> ScheduleGenerator: + + def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None, num_resets: int = 0): + + _runtime_seed = seed + num_resets + np.random.seed(_runtime_seed) - def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None): train_stations = hints['train_stations'] agent_start_targets_cities = hints['agent_start_targets_cities'] max_num_agents = hints['num_agents'] - city_orientations = hints['city_orientations'] + # city_orientations = hints['city_orientations'] if num_agents > max_num_agents: num_agents = max_num_agents warnings.warn("Too many agents! Changes number of agents.") @@ -76,13 +87,18 @@ def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> # Set target for agent start_city = agent_start_targets_cities[agent_idx][0] target_city = agent_start_targets_cities[agent_idx][1] - start = random.choice(train_stations[start_city]) - target = random.choice(train_stations[target_city]) + + start_idx = np.random.choice(np.arange(len(train_stations[start_city]))) + target_idx = np.random.choice(np.arange(len(train_stations[target_city]))) + start = train_stations[start_city][start_idx] + target = train_stations[target_city][target_idx] while start[1] % 2 != 0: - start = random.choice(train_stations[start_city]) + start_idx = np.random.choice(np.arange(len(train_stations[start_city]))) + start = train_stations[start_city][start_idx] while target[1] % 2 != 1: - target = random.choice(train_stations[target_city]) + target_idx = np.random.choice(np.arange(len(train_stations[target_city]))) + target = train_stations[target_city][target_idx] agent_orientation = (agent_start_targets_cities[agent_idx][2] + 2 * start[1]) % 4 if not rail.check_path_exists(start[0], agent_orientation, target[0]): agent_orientation = (agent_orientation + 2) % 4 @@ -96,7 +112,7 @@ def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> # Orient the agent correctly if speed_ratio_map: - speeds = speed_initialization_helper(num_agents, speed_ratio_map) + speeds = speed_initialization_helper(num_agents, speed_ratio_map, seed=_runtime_seed) else: speeds = [1.0] * len(agents_position) @@ -105,7 +121,8 @@ def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> return generator -def random_schedule_generator(speed_ratio_map: Optional[Mapping[float, float]] = None) -> ScheduleGenerator: +def random_schedule_generator(speed_ratio_map: Optional[Mapping[float, float]] = None, + seed: int = 1) -> ScheduleGenerator: """ Given a `rail` GridTransitionMap, return a random placement of agents (initial position, direction and target). @@ -120,7 +137,11 @@ def random_schedule_generator(speed_ratio_map: Optional[Mapping[float, float]] = initial positions, directions, targets speeds """ - def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None) -> ScheduleGeneratorProduct: + def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None, + num_resets: int = 0) -> ScheduleGeneratorProduct: + _runtime_seed = seed + num_resets + + np.random.seed(_runtime_seed) valid_positions = [] for r in range(rail.height): @@ -189,7 +210,7 @@ def random_schedule_generator(speed_ratio_map: Optional[Mapping[float, float]] = agents_direction[i] = valid_starting_directions[ np.random.choice(len(valid_starting_directions), 1)[0]] - agents_speed = speed_initialization_helper(num_agents, speed_ratio_map) + agents_speed = speed_initialization_helper(num_agents, speed_ratio_map, seed=_runtime_seed) return agents_position, agents_direction, agents_target, agents_speed, None return generator @@ -209,7 +230,8 @@ def schedule_from_file(filename, load_from_package=None) -> ScheduleGenerator: initial positions, directions, targets speeds """ - def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None) -> ScheduleGeneratorProduct: + def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None, + num_resets: int = 0) -> ScheduleGeneratorProduct: if load_from_package is not None: from importlib_resources import read_binary load_data = read_binary(load_from_package, filename) diff --git a/flatland/evaluators/client.py b/flatland/evaluators/client.py index 92f1a49d777552b2108aff3963aa5c1bc84fdbfc..d9c9ae9915da6521c39733cce649e2d64f41f1e6 100644 --- a/flatland/evaluators/client.py +++ b/flatland/evaluators/client.py @@ -166,6 +166,8 @@ class FlatlandRemoteClient(object): _request['payload'] = {} _response = self._blocking_request(_request) observation = _response['payload']['observation'] + info = _response['payload']['info'] + random_seed = _response['payload']['random_seed'] if not observation: # If the observation is False, @@ -196,10 +198,18 @@ class FlatlandRemoteClient(object): self.env._max_episode_steps = \ int(1.5 * (self.env.width + self.env.height)) - local_observation = self.env.reset() + local_observation = self.env.reset(random_seed=random_seed) + + local_observation, info = self.env.reset( + regen_rail=False, + replace_agents=False, + activate_agents=False, + random_seed=random_seed + ) + # Use the local observation # as the remote server uses a dummy observation builder - return local_observation + return local_observation, info def env_step(self, action, render=False): """ diff --git a/flatland/evaluators/service.py b/flatland/evaluators/service.py index 4f273be466a1c2f95b55cafece4783bad91e0d2c..023730dce5ddce411b7aa9951e9e8aa8f2f04f8a 100644 --- a/flatland/evaluators/service.py +++ b/flatland/evaluators/service.py @@ -332,14 +332,22 @@ class FlatlandRemoteEvaluationService: self.simulation_steps.append(0) self.current_step = 0 - - _observation = self.env.reset() + + RANDOM_SEED = 1001 + _observation, _info = self.env.reset( + regen_rail=False, + replace_agents=False, + activate_agents=False, + random_seed=RANDOM_SEED + ) _command_response = {} _command_response['type'] = messages.FLATLAND_RL.ENV_CREATE_RESPONSE _command_response['payload'] = {} _command_response['payload']['observation'] = _observation _command_response['payload']['env_file_path'] = self.env_file_paths[self.simulation_count] + _command_response['payload']['info'] = _info + _command_response['payload']['random_seed'] = RANDOM_SEED else: """ All test env evaluations are complete @@ -349,6 +357,8 @@ class FlatlandRemoteEvaluationService: _command_response['payload'] = {} _command_response['payload']['observation'] = False _command_response['payload']['env_file_path'] = False + _command_response['payload']['info'] = False + _command_response['payload']['random_seed'] = RANDOM_SEED self.send_response(_command_response, command) ##################################################################### diff --git a/requirements_dev.txt b/requirements_dev.txt index e02da21600b9cf0574457e915c69474afe0eff35..b71ad94971c59c7a2ef18f038f9c498330a92b1a 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -20,3 +20,4 @@ six>=1.12.0 timeout-decorator>=0.4.1 attrs ushlex +gym==0.14.0 diff --git a/tests/test_flatland_envs_sparse_rail_generator.py b/tests/test_flatland_envs_sparse_rail_generator.py index 131afa7ec850699b21fc4e2c74b3f54f8a851c63..7bd6334b877478c19205e56431dd95111a97caee 100644 --- a/tests/test_flatland_envs_sparse_rail_generator.py +++ b/tests/test_flatland_envs_sparse_rail_generator.py @@ -536,10 +536,10 @@ def test_sparse_rail_generator_deterministic(): stochastic_data=stochastic_data, # Malfunction data generator ) # for r in range(env.height): - # for c in range(env.width): - # print("assert env.rail.get_full_transitions({}, {}) == {}, \"[{}][{}]\"".format(r, c, - # env.rail.get_full_transitions( - # r, c), r, c)) + # for c in range(env.width): + # print("assert env.rail.get_full_transitions({}, {}) == {}, \"[{}][{}]\"".format(r, c, + # env.rail.get_full_transitions( + # r, c), r, c)) assert env.rail.get_full_transitions(0, 1) == 0, "[0][1]" assert env.rail.get_full_transitions(0, 2) == 0, "[0][2]" @@ -676,11 +676,11 @@ def test_sparse_rail_generator_deterministic(): assert env.rail.get_full_transitions(5, 8) == 0, "[5][8]" assert env.rail.get_full_transitions(5, 9) == 0, "[5][9]" assert env.rail.get_full_transitions(5, 10) == 0, "[5][10]" - assert env.rail.get_full_transitions(5, 11) == 16386, "[5][11]" - assert env.rail.get_full_transitions(5, 12) == 1025, "[5][12]" - assert env.rail.get_full_transitions(5, 13) == 1025, "[5][13]" - assert env.rail.get_full_transitions(5, 14) == 1025, "[5][14]" - assert env.rail.get_full_transitions(5, 15) == 4608, "[5][15]" + assert env.rail.get_full_transitions(5, 11) == 0, "[5][11]" + assert env.rail.get_full_transitions(5, 12) == 0, "[5][12]" + assert env.rail.get_full_transitions(5, 13) == 0, "[5][13]" + assert env.rail.get_full_transitions(5, 14) == 0, "[5][14]" + assert env.rail.get_full_transitions(5, 15) == 0, "[5][15]" assert env.rail.get_full_transitions(5, 16) == 0, "[5][16]" assert env.rail.get_full_transitions(5, 17) == 0, "[5][17]" assert env.rail.get_full_transitions(5, 18) == 0, "[5][18]" @@ -700,16 +700,16 @@ def test_sparse_rail_generator_deterministic(): assert env.rail.get_full_transitions(6, 7) == 1025, "[6][7]" assert env.rail.get_full_transitions(6, 8) == 1025, "[6][8]" assert env.rail.get_full_transitions(6, 9) == 5633, "[6][9]" - assert env.rail.get_full_transitions(6, 10) == 5633, "[6][10]" - assert env.rail.get_full_transitions(6, 11) == 3089, "[6][11]" - assert env.rail.get_full_transitions(6, 12) == 1025, "[6][12]" - assert env.rail.get_full_transitions(6, 13) == 1025, "[6][13]" - assert env.rail.get_full_transitions(6, 14) == 1025, "[6][14]" - assert env.rail.get_full_transitions(6, 15) == 1097, "[6][15]" - assert env.rail.get_full_transitions(6, 16) == 5633, "[6][16]" - assert env.rail.get_full_transitions(6, 17) == 17411, "[6][17]" - assert env.rail.get_full_transitions(6, 18) == 1025, "[6][18]" - assert env.rail.get_full_transitions(6, 19) == 4608, "[6][19]" + assert env.rail.get_full_transitions(6, 10) == 17411, "[6][10]" + assert env.rail.get_full_transitions(6, 11) == 1025, "[6][11]" + assert env.rail.get_full_transitions(6, 12) == 4608, "[6][12]" + assert env.rail.get_full_transitions(6, 13) == 0, "[6][13]" + assert env.rail.get_full_transitions(6, 14) == 0, "[6][14]" + assert env.rail.get_full_transitions(6, 15) == 0, "[6][15]" + assert env.rail.get_full_transitions(6, 16) == 0, "[6][16]" + assert env.rail.get_full_transitions(6, 17) == 0, "[6][17]" + assert env.rail.get_full_transitions(6, 18) == 0, "[6][18]" + assert env.rail.get_full_transitions(6, 19) == 0, "[6][19]" assert env.rail.get_full_transitions(6, 20) == 0, "[6][20]" assert env.rail.get_full_transitions(6, 21) == 0, "[6][21]" assert env.rail.get_full_transitions(6, 22) == 0, "[6][22]" @@ -724,17 +724,17 @@ def test_sparse_rail_generator_deterministic(): assert env.rail.get_full_transitions(7, 6) == 1025, "[7][6]" assert env.rail.get_full_transitions(7, 7) == 1025, "[7][7]" assert env.rail.get_full_transitions(7, 8) == 1025, "[7][8]" - assert env.rail.get_full_transitions(7, 9) == 3089, "[7][9]" - assert env.rail.get_full_transitions(7, 10) == 3089, "[7][10]" - assert env.rail.get_full_transitions(7, 11) == 5633, "[7][11]" - assert env.rail.get_full_transitions(7, 12) == 1025, "[7][12]" - assert env.rail.get_full_transitions(7, 13) == 1025, "[7][13]" - assert env.rail.get_full_transitions(7, 14) == 1025, "[7][14]" - assert env.rail.get_full_transitions(7, 15) == 17411, "[7][15]" - assert env.rail.get_full_transitions(7, 16) == 1097, "[7][16]" - assert env.rail.get_full_transitions(7, 17) == 2064, "[7][17]" + assert env.rail.get_full_transitions(7, 9) == 1097, "[7][9]" + assert env.rail.get_full_transitions(7, 10) == 2064, "[7][10]" + assert env.rail.get_full_transitions(7, 11) == 0, "[7][11]" + assert env.rail.get_full_transitions(7, 12) == 32800, "[7][12]" + assert env.rail.get_full_transitions(7, 13) == 0, "[7][13]" + assert env.rail.get_full_transitions(7, 14) == 0, "[7][14]" + assert env.rail.get_full_transitions(7, 15) == 0, "[7][15]" + assert env.rail.get_full_transitions(7, 16) == 0, "[7][16]" + assert env.rail.get_full_transitions(7, 17) == 0, "[7][17]" assert env.rail.get_full_transitions(7, 18) == 0, "[7][18]" - assert env.rail.get_full_transitions(7, 19) == 32800, "[7][19]" + assert env.rail.get_full_transitions(7, 19) == 0, "[7][19]" assert env.rail.get_full_transitions(7, 20) == 0, "[7][20]" assert env.rail.get_full_transitions(7, 21) == 0, "[7][21]" assert env.rail.get_full_transitions(7, 22) == 0, "[7][22]" @@ -751,15 +751,15 @@ def test_sparse_rail_generator_deterministic(): assert env.rail.get_full_transitions(8, 8) == 0, "[8][8]" assert env.rail.get_full_transitions(8, 9) == 0, "[8][9]" assert env.rail.get_full_transitions(8, 10) == 0, "[8][10]" - assert env.rail.get_full_transitions(8, 11) == 72, "[8][11]" - assert env.rail.get_full_transitions(8, 12) == 1025, "[8][12]" - assert env.rail.get_full_transitions(8, 13) == 1025, "[8][13]" - assert env.rail.get_full_transitions(8, 14) == 1025, "[8][14]" - assert env.rail.get_full_transitions(8, 15) == 2064, "[8][15]" + assert env.rail.get_full_transitions(8, 11) == 0, "[8][11]" + assert env.rail.get_full_transitions(8, 12) == 32800, "[8][12]" + assert env.rail.get_full_transitions(8, 13) == 0, "[8][13]" + assert env.rail.get_full_transitions(8, 14) == 0, "[8][14]" + assert env.rail.get_full_transitions(8, 15) == 0, "[8][15]" assert env.rail.get_full_transitions(8, 16) == 0, "[8][16]" assert env.rail.get_full_transitions(8, 17) == 0, "[8][17]" assert env.rail.get_full_transitions(8, 18) == 0, "[8][18]" - assert env.rail.get_full_transitions(8, 19) == 32800, "[8][19]" + assert env.rail.get_full_transitions(8, 19) == 0, "[8][19]" assert env.rail.get_full_transitions(8, 20) == 0, "[8][20]" assert env.rail.get_full_transitions(8, 21) == 0, "[8][21]" assert env.rail.get_full_transitions(8, 22) == 0, "[8][22]" @@ -777,14 +777,14 @@ def test_sparse_rail_generator_deterministic(): assert env.rail.get_full_transitions(9, 9) == 0, "[9][9]" assert env.rail.get_full_transitions(9, 10) == 0, "[9][10]" assert env.rail.get_full_transitions(9, 11) == 0, "[9][11]" - assert env.rail.get_full_transitions(9, 12) == 0, "[9][12]" + assert env.rail.get_full_transitions(9, 12) == 32800, "[9][12]" assert env.rail.get_full_transitions(9, 13) == 0, "[9][13]" assert env.rail.get_full_transitions(9, 14) == 0, "[9][14]" assert env.rail.get_full_transitions(9, 15) == 0, "[9][15]" assert env.rail.get_full_transitions(9, 16) == 0, "[9][16]" assert env.rail.get_full_transitions(9, 17) == 0, "[9][17]" assert env.rail.get_full_transitions(9, 18) == 0, "[9][18]" - assert env.rail.get_full_transitions(9, 19) == 32800, "[9][19]" + assert env.rail.get_full_transitions(9, 19) == 0, "[9][19]" assert env.rail.get_full_transitions(9, 20) == 0, "[9][20]" assert env.rail.get_full_transitions(9, 21) == 0, "[9][21]" assert env.rail.get_full_transitions(9, 22) == 0, "[9][22]" @@ -802,14 +802,14 @@ def test_sparse_rail_generator_deterministic(): assert env.rail.get_full_transitions(10, 9) == 0, "[10][9]" assert env.rail.get_full_transitions(10, 10) == 0, "[10][10]" assert env.rail.get_full_transitions(10, 11) == 0, "[10][11]" - assert env.rail.get_full_transitions(10, 12) == 0, "[10][12]" + assert env.rail.get_full_transitions(10, 12) == 32800, "[10][12]" assert env.rail.get_full_transitions(10, 13) == 0, "[10][13]" assert env.rail.get_full_transitions(10, 14) == 0, "[10][14]" assert env.rail.get_full_transitions(10, 15) == 0, "[10][15]" assert env.rail.get_full_transitions(10, 16) == 0, "[10][16]" assert env.rail.get_full_transitions(10, 17) == 0, "[10][17]" assert env.rail.get_full_transitions(10, 18) == 0, "[10][18]" - assert env.rail.get_full_transitions(10, 19) == 32800, "[10][19]" + assert env.rail.get_full_transitions(10, 19) == 0, "[10][19]" assert env.rail.get_full_transitions(10, 20) == 0, "[10][20]" assert env.rail.get_full_transitions(10, 21) == 0, "[10][21]" assert env.rail.get_full_transitions(10, 22) == 0, "[10][22]" @@ -827,14 +827,14 @@ def test_sparse_rail_generator_deterministic(): assert env.rail.get_full_transitions(11, 9) == 0, "[11][9]" assert env.rail.get_full_transitions(11, 10) == 0, "[11][10]" assert env.rail.get_full_transitions(11, 11) == 0, "[11][11]" - assert env.rail.get_full_transitions(11, 12) == 0, "[11][12]" + assert env.rail.get_full_transitions(11, 12) == 32800, "[11][12]" assert env.rail.get_full_transitions(11, 13) == 0, "[11][13]" assert env.rail.get_full_transitions(11, 14) == 0, "[11][14]" assert env.rail.get_full_transitions(11, 15) == 0, "[11][15]" assert env.rail.get_full_transitions(11, 16) == 0, "[11][16]" assert env.rail.get_full_transitions(11, 17) == 0, "[11][17]" assert env.rail.get_full_transitions(11, 18) == 0, "[11][18]" - assert env.rail.get_full_transitions(11, 19) == 32800, "[11][19]" + assert env.rail.get_full_transitions(11, 19) == 0, "[11][19]" assert env.rail.get_full_transitions(11, 20) == 0, "[11][20]" assert env.rail.get_full_transitions(11, 21) == 0, "[11][21]" assert env.rail.get_full_transitions(11, 22) == 0, "[11][22]" @@ -852,14 +852,14 @@ def test_sparse_rail_generator_deterministic(): assert env.rail.get_full_transitions(12, 9) == 0, "[12][9]" assert env.rail.get_full_transitions(12, 10) == 0, "[12][10]" assert env.rail.get_full_transitions(12, 11) == 0, "[12][11]" - assert env.rail.get_full_transitions(12, 12) == 0, "[12][12]" + assert env.rail.get_full_transitions(12, 12) == 32800, "[12][12]" assert env.rail.get_full_transitions(12, 13) == 0, "[12][13]" assert env.rail.get_full_transitions(12, 14) == 0, "[12][14]" assert env.rail.get_full_transitions(12, 15) == 0, "[12][15]" assert env.rail.get_full_transitions(12, 16) == 0, "[12][16]" assert env.rail.get_full_transitions(12, 17) == 0, "[12][17]" assert env.rail.get_full_transitions(12, 18) == 0, "[12][18]" - assert env.rail.get_full_transitions(12, 19) == 32800, "[12][19]" + assert env.rail.get_full_transitions(12, 19) == 0, "[12][19]" assert env.rail.get_full_transitions(12, 20) == 0, "[12][20]" assert env.rail.get_full_transitions(12, 21) == 0, "[12][21]" assert env.rail.get_full_transitions(12, 22) == 0, "[12][22]" @@ -877,14 +877,14 @@ def test_sparse_rail_generator_deterministic(): assert env.rail.get_full_transitions(13, 9) == 0, "[13][9]" assert env.rail.get_full_transitions(13, 10) == 0, "[13][10]" assert env.rail.get_full_transitions(13, 11) == 0, "[13][11]" - assert env.rail.get_full_transitions(13, 12) == 0, "[13][12]" + assert env.rail.get_full_transitions(13, 12) == 32800, "[13][12]" assert env.rail.get_full_transitions(13, 13) == 0, "[13][13]" assert env.rail.get_full_transitions(13, 14) == 0, "[13][14]" assert env.rail.get_full_transitions(13, 15) == 0, "[13][15]" assert env.rail.get_full_transitions(13, 16) == 0, "[13][16]" assert env.rail.get_full_transitions(13, 17) == 0, "[13][17]" assert env.rail.get_full_transitions(13, 18) == 0, "[13][18]" - assert env.rail.get_full_transitions(13, 19) == 32800, "[13][19]" + assert env.rail.get_full_transitions(13, 19) == 0, "[13][19]" assert env.rail.get_full_transitions(13, 20) == 0, "[13][20]" assert env.rail.get_full_transitions(13, 21) == 0, "[13][21]" assert env.rail.get_full_transitions(13, 22) == 0, "[13][22]" @@ -902,14 +902,14 @@ def test_sparse_rail_generator_deterministic(): assert env.rail.get_full_transitions(14, 9) == 0, "[14][9]" assert env.rail.get_full_transitions(14, 10) == 0, "[14][10]" assert env.rail.get_full_transitions(14, 11) == 0, "[14][11]" - assert env.rail.get_full_transitions(14, 12) == 0, "[14][12]" + assert env.rail.get_full_transitions(14, 12) == 32800, "[14][12]" assert env.rail.get_full_transitions(14, 13) == 0, "[14][13]" assert env.rail.get_full_transitions(14, 14) == 0, "[14][14]" assert env.rail.get_full_transitions(14, 15) == 0, "[14][15]" assert env.rail.get_full_transitions(14, 16) == 0, "[14][16]" assert env.rail.get_full_transitions(14, 17) == 0, "[14][17]" assert env.rail.get_full_transitions(14, 18) == 0, "[14][18]" - assert env.rail.get_full_transitions(14, 19) == 32800, "[14][19]" + assert env.rail.get_full_transitions(14, 19) == 0, "[14][19]" assert env.rail.get_full_transitions(14, 20) == 0, "[14][20]" assert env.rail.get_full_transitions(14, 21) == 0, "[14][21]" assert env.rail.get_full_transitions(14, 22) == 0, "[14][22]" @@ -927,14 +927,14 @@ def test_sparse_rail_generator_deterministic(): assert env.rail.get_full_transitions(15, 9) == 0, "[15][9]" assert env.rail.get_full_transitions(15, 10) == 0, "[15][10]" assert env.rail.get_full_transitions(15, 11) == 0, "[15][11]" - assert env.rail.get_full_transitions(15, 12) == 0, "[15][12]" + assert env.rail.get_full_transitions(15, 12) == 32800, "[15][12]" assert env.rail.get_full_transitions(15, 13) == 0, "[15][13]" assert env.rail.get_full_transitions(15, 14) == 0, "[15][14]" assert env.rail.get_full_transitions(15, 15) == 0, "[15][15]" assert env.rail.get_full_transitions(15, 16) == 0, "[15][16]" assert env.rail.get_full_transitions(15, 17) == 0, "[15][17]" assert env.rail.get_full_transitions(15, 18) == 0, "[15][18]" - assert env.rail.get_full_transitions(15, 19) == 32800, "[15][19]" + assert env.rail.get_full_transitions(15, 19) == 0, "[15][19]" assert env.rail.get_full_transitions(15, 20) == 0, "[15][20]" assert env.rail.get_full_transitions(15, 21) == 0, "[15][21]" assert env.rail.get_full_transitions(15, 22) == 0, "[15][22]" @@ -952,14 +952,14 @@ def test_sparse_rail_generator_deterministic(): assert env.rail.get_full_transitions(16, 9) == 0, "[16][9]" assert env.rail.get_full_transitions(16, 10) == 0, "[16][10]" assert env.rail.get_full_transitions(16, 11) == 0, "[16][11]" - assert env.rail.get_full_transitions(16, 12) == 0, "[16][12]" + assert env.rail.get_full_transitions(16, 12) == 32800, "[16][12]" assert env.rail.get_full_transitions(16, 13) == 0, "[16][13]" assert env.rail.get_full_transitions(16, 14) == 0, "[16][14]" assert env.rail.get_full_transitions(16, 15) == 0, "[16][15]" assert env.rail.get_full_transitions(16, 16) == 0, "[16][16]" assert env.rail.get_full_transitions(16, 17) == 0, "[16][17]" assert env.rail.get_full_transitions(16, 18) == 0, "[16][18]" - assert env.rail.get_full_transitions(16, 19) == 32800, "[16][19]" + assert env.rail.get_full_transitions(16, 19) == 0, "[16][19]" assert env.rail.get_full_transitions(16, 20) == 0, "[16][20]" assert env.rail.get_full_transitions(16, 21) == 0, "[16][21]" assert env.rail.get_full_transitions(16, 22) == 0, "[16][22]" @@ -977,14 +977,14 @@ def test_sparse_rail_generator_deterministic(): assert env.rail.get_full_transitions(17, 9) == 0, "[17][9]" assert env.rail.get_full_transitions(17, 10) == 0, "[17][10]" assert env.rail.get_full_transitions(17, 11) == 0, "[17][11]" - assert env.rail.get_full_transitions(17, 12) == 0, "[17][12]" + assert env.rail.get_full_transitions(17, 12) == 32800, "[17][12]" assert env.rail.get_full_transitions(17, 13) == 0, "[17][13]" assert env.rail.get_full_transitions(17, 14) == 0, "[17][14]" assert env.rail.get_full_transitions(17, 15) == 0, "[17][15]" assert env.rail.get_full_transitions(17, 16) == 0, "[17][16]" assert env.rail.get_full_transitions(17, 17) == 0, "[17][17]" assert env.rail.get_full_transitions(17, 18) == 0, "[17][18]" - assert env.rail.get_full_transitions(17, 19) == 32800, "[17][19]" + assert env.rail.get_full_transitions(17, 19) == 0, "[17][19]" assert env.rail.get_full_transitions(17, 20) == 0, "[17][20]" assert env.rail.get_full_transitions(17, 21) == 0, "[17][21]" assert env.rail.get_full_transitions(17, 22) == 0, "[17][22]" @@ -1000,16 +1000,16 @@ def test_sparse_rail_generator_deterministic(): assert env.rail.get_full_transitions(18, 7) == 1025, "[18][7]" assert env.rail.get_full_transitions(18, 8) == 1025, "[18][8]" assert env.rail.get_full_transitions(18, 9) == 5633, "[18][9]" - assert env.rail.get_full_transitions(18, 10) == 5633, "[18][10]" + assert env.rail.get_full_transitions(18, 10) == 17411, "[18][10]" assert env.rail.get_full_transitions(18, 11) == 1025, "[18][11]" - assert env.rail.get_full_transitions(18, 12) == 1025, "[18][12]" - assert env.rail.get_full_transitions(18, 13) == 1025, "[18][13]" - assert env.rail.get_full_transitions(18, 14) == 1025, "[18][14]" - assert env.rail.get_full_transitions(18, 15) == 1025, "[18][15]" - assert env.rail.get_full_transitions(18, 16) == 5633, "[18][16]" - assert env.rail.get_full_transitions(18, 17) == 17411, "[18][17]" - assert env.rail.get_full_transitions(18, 18) == 1025, "[18][18]" - assert env.rail.get_full_transitions(18, 19) == 34864, "[18][19]" + assert env.rail.get_full_transitions(18, 12) == 2064, "[18][12]" + assert env.rail.get_full_transitions(18, 13) == 0, "[18][13]" + assert env.rail.get_full_transitions(18, 14) == 0, "[18][14]" + assert env.rail.get_full_transitions(18, 15) == 0, "[18][15]" + assert env.rail.get_full_transitions(18, 16) == 0, "[18][16]" + assert env.rail.get_full_transitions(18, 17) == 0, "[18][17]" + assert env.rail.get_full_transitions(18, 18) == 0, "[18][18]" + assert env.rail.get_full_transitions(18, 19) == 0, "[18][19]" assert env.rail.get_full_transitions(18, 20) == 0, "[18][20]" assert env.rail.get_full_transitions(18, 21) == 0, "[18][21]" assert env.rail.get_full_transitions(18, 22) == 0, "[18][22]" @@ -1024,17 +1024,17 @@ def test_sparse_rail_generator_deterministic(): assert env.rail.get_full_transitions(19, 6) == 1025, "[19][6]" assert env.rail.get_full_transitions(19, 7) == 1025, "[19][7]" assert env.rail.get_full_transitions(19, 8) == 1025, "[19][8]" - assert env.rail.get_full_transitions(19, 9) == 3089, "[19][9]" - assert env.rail.get_full_transitions(19, 10) == 3089, "[19][10]" - assert env.rail.get_full_transitions(19, 11) == 1025, "[19][11]" - assert env.rail.get_full_transitions(19, 12) == 1025, "[19][12]" - assert env.rail.get_full_transitions(19, 13) == 1025, "[19][13]" - assert env.rail.get_full_transitions(19, 14) == 1025, "[19][14]" - assert env.rail.get_full_transitions(19, 15) == 1025, "[19][15]" - assert env.rail.get_full_transitions(19, 16) == 1097, "[19][16]" - assert env.rail.get_full_transitions(19, 17) == 3089, "[19][17]" - assert env.rail.get_full_transitions(19, 18) == 1025, "[19][18]" - assert env.rail.get_full_transitions(19, 19) == 2064, "[19][19]" + assert env.rail.get_full_transitions(19, 9) == 1097, "[19][9]" + assert env.rail.get_full_transitions(19, 10) == 2064, "[19][10]" + assert env.rail.get_full_transitions(19, 11) == 0, "[19][11]" + assert env.rail.get_full_transitions(19, 12) == 0, "[19][12]" + assert env.rail.get_full_transitions(19, 13) == 0, "[19][13]" + assert env.rail.get_full_transitions(19, 14) == 0, "[19][14]" + assert env.rail.get_full_transitions(19, 15) == 0, "[19][15]" + assert env.rail.get_full_transitions(19, 16) == 0, "[19][16]" + assert env.rail.get_full_transitions(19, 17) == 0, "[19][17]" + assert env.rail.get_full_transitions(19, 18) == 0, "[19][18]" + assert env.rail.get_full_transitions(19, 19) == 0, "[19][19]" assert env.rail.get_full_transitions(19, 20) == 0, "[19][20]" assert env.rail.get_full_transitions(19, 21) == 0, "[19][21]" assert env.rail.get_full_transitions(19, 22) == 0, "[19][22]" @@ -1290,6 +1290,8 @@ def test_sparse_rail_generator_deterministic(): assert env.rail.get_full_transitions(29, 22) == 0, "[29][22]" assert env.rail.get_full_transitions(29, 23) == 0, "[29][23]" assert env.rail.get_full_transitions(29, 24) == 0, "[29][24]" + + def test_rail_env_action_required_info(): np.random.seed(0) random.seed(0) diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index d5fce5d6c8fa4cf37b4c72f6264726f6d502f77d..73c831a426d55abf45217d66cba57481f6cc63ea 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -70,14 +70,6 @@ def test_malfunction_process(): 'malfunction_rate': 1000, 'min_duration': 3, 'max_duration': 3} - random.seed(0) - np.random.seed(0) - - stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents - 'malfunction_rate': 70, # Rate of malfunction occurence - 'min_duration': 2, # Minimal duration of malfunction - 'max_duration': 5 # Max duration of malfunction - } rail, rail_map = make_simple_rail2() @@ -90,9 +82,7 @@ def test_malfunction_process(): obs_builder_object=SingleAgentNavigationObs() ) # reset to initialize agents_static - env.reset() - - obs = env.reset(False, False, True) + obs, info = env.reset(False, False, True, random_seed=10) # Check that a initial duration for malfunction was assigned assert env.agents[0].malfunction_data['next_malfunction'] > 0 @@ -102,6 +92,9 @@ def test_malfunction_process(): agent_halts = 0 total_down_time = 0 agent_old_position = env.agents[0].position + + # Move target to unreachable position in order to not interfere with test + env.agents[0].target = (0, 0) for step in range(100): actions = {} @@ -147,10 +140,6 @@ def test_malfunction_process_statistically(): 'min_duration': 3, 'max_duration': 3} - random.seed(0) - np.random.seed(0) - - rail, rail_map = make_simple_rail2() env = RailEnv(width=25, @@ -162,7 +151,8 @@ def test_malfunction_process_statistically(): obs_builder_object=SingleAgentNavigationObs() ) # reset to initialize agents_static - env.reset(False, False, False) + env.reset(True, True, False, random_seed=10) + env.agents[0].target = (0, 0) nb_malfunction = 0 for step in range(20): @@ -172,56 +162,69 @@ def test_malfunction_process_statistically(): action_dict[agent.handle] = RailEnvActions(np.random.randint(4)) env.step(action_dict) - # check that generation of malfunctions works as expected - assert env.agents[0].malfunction_data["nr_malfunctions"] == 4 + assert env.agents[0].malfunction_data["nr_malfunctions"] == 5 def test_malfunction_before_entry(): - """Tests hat malfunctions are produced by stochastic_data!""" + """Tests that malfunctions are produced by stochastic_data!""" # Set fixed malfunction duration for this test stochastic_data = {'prop_malfunction': 1., 'malfunction_rate': 2, 'min_duration': 10, 'max_duration': 10} - random.seed(0) - np.random.seed(0) - rail, rail_map = make_simple_rail2() env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), - number_of_agents=1, + schedule_generator=random_schedule_generator(seed=2), # seed 12 + number_of_agents=10, + random_seed=1, stochastic_data=stochastic_data, # Malfunction data generator - obs_builder_object=SingleAgentNavigationObs() ) # reset to initialize agents_static - env.reset(False, False, False) + env.reset(False, False, False, random_seed=10) env.agents[0].target = (0, 0) - nb_malfunction = 0 + assert env.agents[1].malfunction_data['malfunction'] == 11 + assert env.agents[2].malfunction_data['malfunction'] == 11 + assert env.agents[3].malfunction_data['malfunction'] == 11 + assert env.agents[4].malfunction_data['malfunction'] == 11 + assert env.agents[5].malfunction_data['malfunction'] == 0 + assert env.agents[6].malfunction_data['malfunction'] == 11 + assert env.agents[7].malfunction_data['malfunction'] == 11 + assert env.agents[8].malfunction_data['malfunction'] == 11 + assert env.agents[9].malfunction_data['malfunction'] == 0 + for step in range(20): action_dict: Dict[int, RailEnvActions] = {} for agent in env.agents: # We randomly select an action + action_dict[agent.handle] = RailEnvActions(2) if step < 10: action_dict[agent.handle] = RailEnvActions(0) - assert env.agents[0].malfunction_data['malfunction'] == 0 - else: - action_dict[agent.handle] = RailEnvActions(2) - print(env.agents[0].malfunction_data) env.step(action_dict) - assert env.agents[0].malfunction_data['malfunction'] > 0 + assert env.agents[1].malfunction_data['malfunction'] == 1 + assert env.agents[2].malfunction_data['malfunction'] == 1 + assert env.agents[3].malfunction_data['malfunction'] == 1 + assert env.agents[4].malfunction_data['malfunction'] == 1 + assert env.agents[5].malfunction_data['malfunction'] == 2 + assert env.agents[6].malfunction_data['malfunction'] == 1 + assert env.agents[7].malfunction_data['malfunction'] == 1 + assert env.agents[8].malfunction_data['malfunction'] == 1 + assert env.agents[9].malfunction_data['malfunction'] == 3 + + # Print for test generation + # for a in range(env.get_num_agents()): + # print("assert env.agents[{}].malfunction_data['malfunction'] == {}".format(a, + # env.agents[a].malfunction_data[ + # 'malfunction'])) def test_initial_malfunction(): - random.seed(0) - np.random.seed(0) - stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents 'malfunction_rate': 70, # Rate of malfunction occurence 'min_duration': 2, # Minimal duration of malfunction @@ -240,8 +243,8 @@ def test_initial_malfunction(): ) # reset to initialize agents_static - env.reset(False, False, True) - + env.reset(False, False, True, random_seed=10) + env.agents[0].target = (0, 5) set_penalties_for_replay(env) replay_config = ReplayConfig( replay=[ @@ -294,9 +297,6 @@ def test_initial_malfunction(): def test_initial_malfunction_stop_moving(): - random.seed(0) - np.random.seed(0) - stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents 'malfunction_rate': 70, # Rate of malfunction occurence 'min_duration': 2, # Minimal duration of malfunction diff --git a/tests/test_random_seeding.py b/tests/test_random_seeding.py new file mode 100644 index 0000000000000000000000000000000000000000..3a2fe8bc1775653e15842ac27c262aab3e2b9ded --- /dev/null +++ b/tests/test_random_seeding.py @@ -0,0 +1,35 @@ +from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators import rail_from_grid_transition_map +from flatland.envs.schedule_generators import random_schedule_generator +from flatland.utils.simple_rail import make_simple_rail2 + + +def test_random_seeding(): + # Set fixed malfunction duration for this test + rail, rail_map = make_simple_rail2() + + # Move target to unreachable position in order to not interfere with test + for idx in range(100): + env = RailEnv(width=25, + height=30, + rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(seed=12), + number_of_agents=10 + ) + env.reset(True, True, False, random_seed=1) + # Test generation print + + env.agents[0].target = (0, 0) + for step in range(10): + actions = {} + actions[0] = 2 + env.step(actions) + agent_positions = [] + for a in range(env.get_num_agents()): + agent_positions += env.agents[a].initial_position + # print(agent_positions) + assert agent_positions == [3, 2, 3, 5, 3, 6, 5, 6, 3, 4, 3, 1, 3, 9, 4, 6, 0, 3, 3, 7] + # Test generation print + assert env.agents[0].position == (3, 6) + # print("env.agents[0].initial_position == {}".format(env.agents[0].initial_position)) + #print("assert env.agents[0].position == {}".format(env.agents[0].position))