From dece6c1673f53ce3cc40a8b2dceeaca1d7772e6f Mon Sep 17 00:00:00 2001 From: u214892 <u214892@sbb.ch> Date: Tue, 27 Aug 2019 11:22:26 +0200 Subject: [PATCH] #141 different agent classes --- examples/complex_rail_benchmark.py | 2 + examples/custom_observation_example.py | 9 +- examples/debugging_example_DELETE.py | 11 +- examples/simple_example_3.py | 2 + examples/training_example.py | 3 + flatland/cli.py | 46 +++--- flatland/envs/agent_generators.py | 182 +++++++++++++++++++++++ flatland/envs/generators.py | 94 +++--------- flatland/envs/grid4_generators_utils.py | 82 +--------- flatland/envs/rail_env.py | 37 +++-- tests/test_distance_map.py | 2 + tests/test_flatland_envs_observations.py | 4 + tests/test_flatland_envs_predictions.py | 4 + tests/test_flatland_envs_rail_env.py | 5 + tests/test_flatland_malfunction.py | 2 + tests/test_multi_speed.py | 3 + tests/test_speed_classes.py | 9 +- tests/tests_generators.py | 18 ++- 18 files changed, 311 insertions(+), 204 deletions(-) create mode 100644 flatland/envs/agent_generators.py diff --git a/examples/complex_rail_benchmark.py b/examples/complex_rail_benchmark.py index 44e4b534..624ad669 100644 --- a/examples/complex_rail_benchmark.py +++ b/examples/complex_rail_benchmark.py @@ -3,6 +3,7 @@ import random import numpy as np +from flatland.envs.agent_generators import complex_rail_generator_agents_placer from flatland.envs.generators import complex_rail_generator from flatland.envs.rail_env import RailEnv @@ -15,6 +16,7 @@ def run_benchmark(): # Example generate a random rail env = RailEnv(width=15, height=15, rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=20, min_dist=12), + agent_generator=complex_rail_generator_agents_placer(), number_of_agents=5) n_trials = 20 diff --git a/examples/custom_observation_example.py b/examples/custom_observation_example.py index 723bb110..401ff94a 100644 --- a/examples/custom_observation_example.py +++ b/examples/custom_observation_example.py @@ -5,6 +5,7 @@ import numpy as np from flatland.core.env_observation_builder import ObservationBuilder from flatland.core.grid.grid_utils import coordinate_to_position +from flatland.envs.agent_generators import complex_rail_generator_agents_placer from flatland.envs.generators import random_rail_generator, complex_rail_generator from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv @@ -20,6 +21,7 @@ class SimpleObs(ObservationBuilder): Simplest observation builder. The object returns observation vectors with 5 identical components, all equal to the ID of the respective agent. """ + def __init__(self): self.observation_space = [5] @@ -53,6 +55,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv): E.g., if taking the Left branch (if available) is the shortest route to the agent's target, the observation vector will be [1, 0, 0]. """ + def __init__(self): super().__init__(max_depth=0) self.observation_space = [3] @@ -90,6 +93,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv): env = RailEnv(width=7, height=7, rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999, seed=0), + agent_generator=complex_rail_generator_agents_placer(), number_of_agents=1, obs_builder_object=SingleAgentNavigationObs()) @@ -97,8 +101,8 @@ obs = env.reset() env_renderer = RenderTool(env, gl="PILSVG") env_renderer.render_env(show=True, frames=True, show_observations=True) for step in range(100): - action = np.argmax(obs[0])+1 - obs, all_rewards, done, _ = env.step({0:action}) + action = np.argmax(obs[0]) + 1 + obs, all_rewards, done, _ = env.step({0: action}) print("Rewards: ", all_rewards, " [done=", done, "]") env_renderer.render_env(show=True, frames=True, show_observations=True) time.sleep(0.1) @@ -200,6 +204,7 @@ CustomObsBuilder = ObservePredictions(CustomPredictor) env = RailEnv(width=10, height=10, rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=1, min_dist=8, max_dist=99999, seed=0), + agent_generator=complex_rail_generator_agents_placer(), number_of_agents=3, obs_builder_object=CustomObsBuilder) diff --git a/examples/debugging_example_DELETE.py b/examples/debugging_example_DELETE.py index 2c0f8145..68fdc8ab 100644 --- a/examples/debugging_example_DELETE.py +++ b/examples/debugging_example_DELETE.py @@ -3,6 +3,7 @@ import time import numpy as np +from flatland.envs.agent_generators import complex_rail_generator_agents_placer from flatland.envs.generators import complex_rail_generator from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.rail_env import RailEnv @@ -11,6 +12,7 @@ from flatland.utils.rendertools import RenderTool random.seed(1) np.random.seed(1) + class SingleAgentNavigationObs(TreeObsForRailEnv): """ We derive our bbservation builder from TreeObsForRailEnv, to exploit the existing implementation to compute @@ -21,6 +23,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv): E.g., if taking the Left branch (if available) is the shortest route to the agent's target, the observation vector will be [1, 0, 0]. """ + def __init__(self): super().__init__(max_depth=0) self.observation_space = [3] @@ -58,6 +61,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv): env = RailEnv(width=14, height=14, rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999, seed=0), + agent_generator=complex_rail_generator_agents_placer(), number_of_agents=2, obs_builder_object=SingleAgentNavigationObs()) @@ -67,11 +71,11 @@ env_renderer.render_env(show=True, frames=True, show_observations=False) for step in range(100): actions = {} for i in range(len(obs)): - actions[i] = np.argmax(obs[i])+1 + actions[i] = np.argmax(obs[i]) + 1 - if step%5 == 0: + if step % 5 == 0: print("Agent halts") - actions[0] = 4 # Halt + actions[0] = 4 # Halt obs, all_rewards, done, _ = env.step(actions) if env.agents[0].malfunction_data['malfunction'] > 0: @@ -82,4 +86,3 @@ for step in range(100): if done["__all__"]: break env_renderer.close_window() - diff --git a/examples/simple_example_3.py b/examples/simple_example_3.py index 5aa03d8f..1e20fcca 100644 --- a/examples/simple_example_3.py +++ b/examples/simple_example_3.py @@ -2,6 +2,7 @@ import random import numpy as np +from flatland.envs.agent_generators import complex_rail_generator_agents_placer from flatland.envs.generators import complex_rail_generator from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.rail_env import RailEnv @@ -13,6 +14,7 @@ np.random.seed(1) env = RailEnv(width=7, height=7, rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0), + agent_generator=complex_rail_generator_agents_placer(), number_of_agents=2, obs_builder_object=TreeObsForRailEnv(max_depth=2)) diff --git a/examples/training_example.py b/examples/training_example.py index d125be15..f339d329 100644 --- a/examples/training_example.py +++ b/examples/training_example.py @@ -1,5 +1,6 @@ import numpy as np +from flatland.envs.agent_generators import complex_rail_generator_agents_placer from flatland.envs.generators import complex_rail_generator from flatland.envs.observations import TreeObsForRailEnv, LocalObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv @@ -16,11 +17,13 @@ LocalGridObs = LocalObsForRailEnv(view_height=10, view_width=2, center=2) env = RailEnv(width=20, height=20, rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0), + agent_generator=complex_rail_generator_agents_placer(), obs_builder_object=TreeObservation, number_of_agents=3) env_renderer = RenderTool(env, gl="PILSVG", ) + # Import your own Agent or use RLlib to train agents on Flatland # As an example we use a random agent here diff --git a/flatland/cli.py b/flatland/cli.py index 32e8d9dc..56b2feab 100644 --- a/flatland/cli.py +++ b/flatland/cli.py @@ -2,29 +2,33 @@ """Console script for flatland.""" import sys +import time + import click import numpy as np -import time +import redis + +from flatland.envs.agent_generators import complex_rail_generator_agents_placer from flatland.envs.generators import complex_rail_generator from flatland.envs.rail_env import RailEnv -from flatland.utils.rendertools import RenderTool from flatland.evaluators.service import FlatlandRemoteEvaluationService -import redis +from flatland.utils.rendertools import RenderTool @click.command() def demo(args=None): """Demo script to check installation""" env = RailEnv( - width=15, - height=15, - rail_generator=complex_rail_generator( - nr_start_goal=10, - nr_extra=1, - min_dist=8, - max_dist=99999), - number_of_agents=5) - + width=15, + height=15, + rail_generator=complex_rail_generator( + nr_start_goal=10, + nr_extra=1, + min_dist=8, + max_dist=99999), + agent_generator=complex_rail_generator_agents_placer(), + number_of_agents=5) + env._max_episode_steps = int(15 * (env.width + env.height)) env_renderer = RenderTool(env) @@ -52,12 +56,12 @@ def demo(args=None): @click.command() -@click.option('--tests', +@click.option('--tests', type=click.Path(exists=True), help="Path to folder containing Flatland tests", required=True ) -@click.option('--service_id', +@click.option('--service_id', default="FLATLAND_RL_SERVICE_ID", help="Evaluation Service ID. This has to match the service id on the client.", required=False @@ -70,14 +74,14 @@ def evaluator(tests, service_id): raise Exception( "\nRedis server does not seem to be running on your localhost.\n" "Please ensure that you have a redis server running on your localhost" - ) - + ) + grader = FlatlandRemoteEvaluationService( - test_env_folder=tests, - flatland_rl_service_id=service_id, - visualize=False, - verbose=False - ) + test_env_folder=tests, + flatland_rl_service_id=service_id, + visualize=False, + verbose=False + ) grader.run() diff --git a/flatland/envs/agent_generators.py b/flatland/envs/agent_generators.py new file mode 100644 index 00000000..1f769b7d --- /dev/null +++ b/flatland/envs/agent_generators.py @@ -0,0 +1,182 @@ +"""Agent generators (railway undertaking, "EVU").""" +from typing import Tuple, List, Callable, Mapping, Optional, Any + +import msgpack +import numpy as np + +from flatland.core.grid.grid4_utils import get_new_position +from flatland.core.transition_map import GridTransitionMap +from flatland.envs.agent_utils import EnvAgentStatic + +AgentPosition = Tuple[int, int] +AgentGeneratorProduct = Tuple[List[AgentPosition], List[AgentPosition], List[AgentPosition], List[float]] +AgentGenerator = Callable[[GridTransitionMap, int, Optional[Any]], AgentGeneratorProduct] + + +def speed_initialization_helper(nb_agents: int, speed_ratio_map: Mapping[float, float] = None) -> List[float]: + """ + Parameters + ------- + nb_agents : int + The number of agents to generate a speed for + speed_ratio_map : Mapping[float,float] + A map of speeds mappint to their ratio of appearance. The ratios must sum up to 1. + + Returns + ------- + List[float] + A list of size nb_agents of speeds with the corresponding probabilistic ratios. + """ + if speed_ratio_map is None: + return [1.0] * nb_agents + + nb_classes = len(speed_ratio_map.keys()) + speed_ratio_map_as_list: List[Tuple[float, float]] = list(speed_ratio_map.items()) + speed_ratios = list(map(lambda t: t[1], speed_ratio_map_as_list)) + speeds = list(map(lambda t: t[0], speed_ratio_map_as_list)) + return list(map(lambda index: speeds[index], np.random.choice(nb_classes, nb_agents, p=speed_ratios))) + + +def complex_rail_generator_agents_placer(speed_ratio_map: Mapping[float, float] = None) -> AgentGenerator: + def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None): + start_goal = hints['start_goal'] + start_dir = hints['start_dir'] + agents_position = [sg[0] for sg in start_goal[:num_agents]] + agents_target = [sg[1] for sg in start_goal[:num_agents]] + agents_direction = start_dir[:num_agents] + + if speed_ratio_map: + speeds = speed_initialization_helper(num_agents, speed_ratio_map) + else: + speeds = [1.0] * len(agents_position) + + return agents_position, agents_direction, agents_target, speeds + + return generator + + +def get_rnd_agents_pos_tgt_dir_on_rail(speed_ratio_map: Mapping[float, float] = None) -> AgentGenerator: + """ + Given a `rail' GridTransitionMap, return a random placement of agents (initial position, direction and target). + + Parameters + ------- + rail : GridTransitionMap + The railway to place agents on. + num_agents : int + The number of agents to generate a speed for + speed_ratio_map : Mapping[float,float] + A map of speeds mappint to their ratio of appearance. The ratios must sum up to 1. + Returns + ------- + Tuple[List[Tuple[int,int]], List[Tuple[int,int]], List[Tuple[int,int]], List[float]] + initial positions, directions, targets speeds + """ + + def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None): + def _path_exists(rail, start, direction, end): + # BFS - Check if a path exists between the 2 nodes + + visited = set() + stack = [(start, direction)] + while stack: + node = stack.pop() + if node[0][0] == end[0] and node[0][1] == end[1]: + return 1 + if node not in visited: + visited.add(node) + moves = rail.get_transitions(node[0][0], node[0][1], node[1]) + for move_index in range(4): + if moves[move_index]: + stack.append((get_new_position(node[0], move_index), + move_index)) + + # If cell is a dead-end, append previous node with reversed + # orientation! + nbits = 0 + tmp = rail.get_full_transitions(node[0][0], node[0][1]) + while tmp > 0: + nbits += (tmp & 1) + tmp = tmp >> 1 + if nbits == 1: + stack.append((node[0], (node[1] + 2) % 4)) + + return 0 + + valid_positions = [] + for r in range(rail.height): + for c in range(rail.width): + if rail.get_full_transitions(r, c) > 0: + valid_positions.append((r, c)) + if len(valid_positions) == 0: + return [], [], [], [] + re_generate = True + while re_generate: + agents_position = [ + valid_positions[i] for i in + np.random.choice(len(valid_positions), num_agents)] + agents_target = [ + valid_positions[i] for i in + np.random.choice(len(valid_positions), num_agents)] + + # agents_direction must be a direction for which a solution is + # guaranteed. + agents_direction = [0] * num_agents + re_generate = False + for i in range(num_agents): + valid_movements = [] + for direction in range(4): + position = agents_position[i] + moves = rail.get_transitions(position[0], position[1], direction) + for move_index in range(4): + if moves[move_index]: + valid_movements.append((direction, move_index)) + + valid_starting_directions = [] + for m in valid_movements: + new_position = get_new_position(agents_position[i], m[1]) + if m[0] not in valid_starting_directions and _path_exists(rail, new_position, m[0], + agents_target[i]): + valid_starting_directions.append(m[0]) + + if len(valid_starting_directions) == 0: + re_generate = True + else: + agents_direction[i] = valid_starting_directions[ + np.random.choice(len(valid_starting_directions), 1)[0]] + + agents_speed = speed_initialization_helper(num_agents, speed_ratio_map) + return agents_position, agents_direction, agents_target, agents_speed + + return generator + + +def agents_from_file(filename) -> AgentGenerator: + """ + Utility to load pickle file + + Parameters + ------- + input_file : Pickle file generated by env.save() or editor + + Returns + ------- + Tuple[List[Tuple[int,int]], List[Tuple[int,int]], List[Tuple[int,int]], List[float]] + initial positions, directions, targets speeds + """ + + def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None): + with open(filename, "rb") as file_in: + load_data = file_in.read() + data = msgpack.unpackb(load_data, use_list=False) + + # agents are always reset as not moving + agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data[b"agents_static"]] + # setup with loaded data + agents_position = [a.position for a in agents_static] + agents_direction = [a.direction for a in agents_static] + agents_target = [a.target for a in agents_static] + + return agents_position, agents_direction, agents_target, [1.0] * len(agents_position) + + return generator diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py index 79e0ac7d..380bf37f 100644 --- a/flatland/envs/generators.py +++ b/flatland/envs/generators.py @@ -1,4 +1,5 @@ -from typing import Mapping, Tuple, List, Callable +"""Rail generators (infrastructure manager, "Infrastrukturbetreiber").""" +from typing import Callable, Tuple, Any, Optional import msgpack import numpy as np @@ -7,12 +8,12 @@ from flatland.core.grid.grid4_utils import get_direction, mirror from flatland.core.grid.grid_utils import distance_on_rail from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.core.transition_map import GridTransitionMap -from flatland.envs.agent_utils import EnvAgentStatic from flatland.envs.grid4_generators_utils import connect_rail -from flatland.envs.grid4_generators_utils import get_rnd_agents_pos_tgt_dir_on_rail +RailGenerator = Callable[[int, int, int, int], Tuple[GridTransitionMap, Optional[Any]]] -def empty_rail_generator(): + +def empty_rail_generator() -> RailGenerator: """ Returns a generator which returns an empty rail mail with no agents. Primarily used by the editor @@ -24,7 +25,7 @@ def empty_rail_generator(): rail_array = grid_map.grid rail_array.fill(0) - return grid_map, [], [], [], [] + return [grid_map, None] return generator @@ -33,8 +34,7 @@ def complex_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dist=99999, - seed=0, - speed_initializer: Callable[[int], List[float]] = None): + seed=0) -> RailGenerator: """ Parameters ------- @@ -42,8 +42,6 @@ def complex_rail_generator(nr_start_goal=1, The width (number of cells) of the grid to generate. height : int The height (number of cells) of the grid to generate. - speed_initializer : Callable[[int], List[float]] - Function that returns a list of speeds for the numer of agents given as argument. Returns ------- @@ -56,8 +54,7 @@ def complex_rail_generator(nr_start_goal=1, if num_agents > nr_start_goal: num_agents = nr_start_goal print("complex_rail_generator: num_agents > nr_start_goal, changing num_agents") - rail_trans = RailEnvTransitions() - grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans) + grid_map = GridTransitionMap(width=width, height=height, transitions=RailEnvTransitions()) rail_array = grid_map.grid rail_array.fill(0) @@ -81,6 +78,7 @@ def complex_rail_generator(nr_start_goal=1, # - return transition map + list of [start_pos, start_dir, goal_pos] points # + rail_trans = grid_map.transitions start_goal = [] start_dir = [] nr_created = 0 @@ -150,15 +148,10 @@ def complex_rail_generator(nr_start_goal=1, if len(new_path) >= 2: nr_created += 1 - agents_position = [sg[0] for sg in start_goal[:num_agents]] - agents_target = [sg[1] for sg in start_goal[:num_agents]] - agents_direction = start_dir[:num_agents] - - if speed_initializer: - speeds = speed_initializer(num_agents) - else: - speeds = [1.0] * len(agents_position) - return grid_map, agents_position, agents_direction, agents_target, speeds + return grid_map, {'agents_hints': { + 'start_goal': start_goal, + 'start_dir': start_dir + }} return generator @@ -202,22 +195,18 @@ def rail_from_manual_specifications_generator(rail_spec): effective_transition_cell = rail_env_transitions.rotate_transition(basic_type_of_cell_, rotation_cell_) rail.set_transitions((r, c), effective_transition_cell) - agents_position, agents_direction, agents_target = get_rnd_agents_pos_tgt_dir_on_rail( - rail, - num_agents) - - return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position) + return [rail, None] return generator -def rail_from_file(filename): +def rail_from_file(filename) -> RailGenerator: """ Utility to load pickle file Parameters ------- - input_file : Pickle file generated by env.save() or editor + filename : Pickle file generated by env.save() or editor Returns ------- @@ -235,26 +224,16 @@ def rail_from_file(filename): grid = np.array(data[b"grid"]) rail = GridTransitionMap(width=np.shape(grid)[1], height=np.shape(grid)[0], transitions=rail_env_transitions) rail.grid = grid - # agents are always reset as not moving - agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data[b"agents_static"]] - # setup with loaded data - agents_position = [a.position for a in agents_static] - agents_direction = [a.direction for a in agents_static] - agents_target = [a.target for a in agents_static] if b"distance_maps" in data.keys(): distance_maps = data[b"distance_maps"] if len(distance_maps) > 0: - return rail, agents_position, agents_direction, agents_target, [1.0] * len( - agents_position), distance_maps - else: - return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position) - else: - return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position) + return rail, {'distance_maps': distance_maps} + return [rail, None] return generator -def rail_from_grid_transition_map(rail_map): +def rail_from_grid_transition_map(rail_map) -> RailGenerator: """ Utility to convert a rail given by a GridTransitionMap map with the correct 16-bit transitions specifications. @@ -271,16 +250,12 @@ def rail_from_grid_transition_map(rail_map): """ def generator(width, height, num_agents, num_resets=0): - agents_position, agents_direction, agents_target = get_rnd_agents_pos_tgt_dir_on_rail( - rail_map, - num_agents) - - return rail_map, agents_position, agents_direction, agents_target, [1.0] * len(agents_position) + return [rail_map, None] return generator -def random_rail_generator(cell_type_relative_proportion=[1.0] * 11): +def random_rail_generator(cell_type_relative_proportion=[1.0] * 11) -> RailGenerator: """ Dummy random level generator: - fill in cells at random in [width-2, height-2] @@ -544,31 +519,6 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11): return_rail = GridTransitionMap(width=width, height=height, transitions=t_utils) return_rail.grid = tmp_rail - agents_position, agents_direction, agents_target = get_rnd_agents_pos_tgt_dir_on_rail( - return_rail, - num_agents) - - return return_rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position) + return [return_rail, None] return generator - - -def speed_initialization_helper(nb_agents: int, speed_ratio_map: Mapping[float, float]) -> List[float]: - """ - Parameters - ------- - nb_agents : int - The number of agents to generate a speed for - speed_ratio_map : Mapping[float,float] - A map of speeds mappint to their ratio of appearance. The ratios must sum up to 1. - - Returns - ------- - List[float] - A list of size nb_agents of speeds with the corresponding probabilistic ratios. - """ - nb_classes = len(speed_ratio_map.keys()) - speed_ratio_map_as_list: List[Tuple[float, float]] = list(speed_ratio_map.items()) - speed_ratios = list(map(lambda t: t[1], speed_ratio_map_as_list)) - speeds = list(map(lambda t: t[0], speed_ratio_map_as_list)) - return list(map(lambda index: speeds[index], np.random.choice(nb_classes, nb_agents, p=speed_ratios))) diff --git a/flatland/envs/grid4_generators_utils.py b/flatland/envs/grid4_generators_utils.py index dedd76b6..8bbd0df7 100644 --- a/flatland/envs/grid4_generators_utils.py +++ b/flatland/envs/grid4_generators_utils.py @@ -5,10 +5,8 @@ Generator functions are functions that take width, height and num_resets as argu a GridTransitionMap object. """ -import numpy as np - from flatland.core.grid.grid4_astar import a_star -from flatland.core.grid.grid4_utils import get_direction, mirror, get_new_position +from flatland.core.grid.grid4_utils import get_direction, mirror def connect_rail(rail_trans, rail_array, start, end): @@ -55,81 +53,3 @@ def connect_rail(rail_trans, rail_array, start, end): current_dir = new_dir return path - - -def get_rnd_agents_pos_tgt_dir_on_rail(rail, num_agents): - """ - Given a `rail' GridTransitionMap, return a random placement of agents (initial position, direction and target). - - TODO: add extensive documentation, as users may need this function to simplify their custom level generators. - """ - - def _path_exists(rail, start, direction, end): - # BFS - Check if a path exists between the 2 nodes - - visited = set() - stack = [(start, direction)] - while stack: - node = stack.pop() - if node[0][0] == end[0] and node[0][1] == end[1]: - return 1 - if node not in visited: - visited.add(node) - moves = rail.get_transitions(node[0][0], node[0][1], node[1]) - for move_index in range(4): - if moves[move_index]: - stack.append((get_new_position(node[0], move_index), - move_index)) - - # If cell is a dead-end, append previous node with reversed - # orientation! - nbits = 0 - tmp = rail.get_full_transitions(node[0][0], node[0][1]) - while tmp > 0: - nbits += (tmp & 1) - tmp = tmp >> 1 - if nbits == 1: - stack.append((node[0], (node[1] + 2) % 4)) - - return 0 - - valid_positions = [] - for r in range(rail.height): - for c in range(rail.width): - if rail.get_full_transitions(r, c) > 0: - valid_positions.append((r, c)) - - re_generate = True - while re_generate: - agents_position = [ - valid_positions[i] for i in - np.random.choice(len(valid_positions), num_agents)] - agents_target = [ - valid_positions[i] for i in - np.random.choice(len(valid_positions), num_agents)] - - # agents_direction must be a direction for which a solution is - # guaranteed. - agents_direction = [0] * num_agents - re_generate = False - for i in range(num_agents): - valid_movements = [] - for direction in range(4): - position = agents_position[i] - moves = rail.get_transitions(position[0], position[1], direction) - for move_index in range(4): - if moves[move_index]: - valid_movements.append((direction, move_index)) - - valid_starting_directions = [] - for m in valid_movements: - new_position = get_new_position(agents_position[i], m[1]) - if m[0] not in valid_starting_directions and _path_exists(rail, new_position, m[0], agents_target[i]): - valid_starting_directions.append(m[0]) - - if len(valid_starting_directions) == 0: - re_generate = True - else: - agents_direction[i] = valid_starting_directions[np.random.choice(len(valid_starting_directions), 1)[0]] - - return agents_position, agents_direction, agents_target diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 22812829..27664403 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -11,8 +11,9 @@ import numpy as np from flatland.core.env import Environment from flatland.core.grid.grid4_utils import get_new_position +from flatland.envs.agent_generators import get_rnd_agents_pos_tgt_dir_on_rail, AgentGenerator from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent -from flatland.envs.generators import random_rail_generator +from flatland.envs.generators import random_rail_generator, RailGenerator from flatland.envs.observations import TreeObsForRailEnv m.patch() @@ -91,7 +92,8 @@ class RailEnv(Environment): def __init__(self, width, height, - rail_generator=random_rail_generator(), + rail_generator: RailGenerator = random_rail_generator(), + agent_generator: AgentGenerator = get_rnd_agents_pos_tgt_dir_on_rail(), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2), max_episode_steps=None, @@ -107,13 +109,12 @@ class RailEnv(Environment): height and agents handles of a rail environment, along with the number of times the env has been reset, and returns a GridTransitionMap object and a list of starting positions, targets, and initial orientations for agent handle. - Implemented functions are: - random_rail_generator : generate a random rail of given size - rail_from_grid_transition_map(rail_map) : generate a rail from - a GridTransitionMap object - rail_from_manual_sp ecifications_generator(rail_spec) : generate a rail from - a rail specifications array - TODO: generate_rail_from_saved_list or from list of ndarray bitmaps --- + The rail_generator can pass a distance map in the hints or information for specific agent_generators. + Implementations can be found in flatland/envs/generators.py + agent_generator : function + The agent_generator function is a function that takes the grid, the number of agents and optional hints + and returns a list of starting positions, targets, initial orientations and speed for all agent handles. + Implementations can be found in flatland/envs/agent_generators.py width : int The width of the rail map. Potentially in the future, a range of widths to sample from. @@ -131,7 +132,8 @@ class RailEnv(Environment): file_name: you can load a pickle file. """ - self.rail_generator = rail_generator + self.rail_generator: RailGenerator = rail_generator + self.agent_generator: AgentGenerator = agent_generator self.rail = None self.width = width self.height = height @@ -213,18 +215,21 @@ class RailEnv(Environment): if replace_agents then regenerate the agents static. Relies on the rail_generator returning agent_static lists (pos, dir, target) """ - tRailAgents = self.rail_generator(self.width, self.height, self.get_num_agents(), self.num_resets) + rail, optionals = self.rail_generator(self.width, self.height, self.get_num_agents(), self.num_resets) - # Check if generator provided a distance map TODO: Make this check safer! - if len(tRailAgents) > 5: - self.obs_builder.distance_map = tRailAgents[-1] + if optionals and 'distance_maps' in optionals: + self.obs_builder.distance_map = optionals['distance_maps'] if regen_rail or self.rail is None: - self.rail = tRailAgents[0] + self.rail = rail self.height, self.width = self.rail.grid.shape if replace_agents: - self.agents_static = EnvAgentStatic.from_lists(*tRailAgents[1:5]) + agents_hints = None + if optionals and 'agents_hints' in optionals: + agents_hints = optionals['agents_hints'] + self.agents_static = EnvAgentStatic.from_lists( + *self.agent_generator(self.rail, self.get_num_agents(), hints=agents_hints)) self.restart_agents() diff --git a/tests/test_distance_map.py b/tests/test_distance_map.py index 12e0c092..dbeb2fb4 100644 --- a/tests/test_distance_map.py +++ b/tests/test_distance_map.py @@ -2,6 +2,7 @@ import numpy as np from flatland.core.grid.grid4 import Grid4Transitions from flatland.core.transition_map import GridTransitionMap +from flatland.envs.agent_generators import get_rnd_agents_pos_tgt_dir_on_rail from flatland.envs.generators import rail_from_grid_transition_map from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv @@ -34,6 +35,7 @@ def test_walker(): env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv(max_depth=10)), diff --git a/tests/test_flatland_envs_observations.py b/tests/test_flatland_envs_observations.py index 574705c4..c0e28534 100644 --- a/tests/test_flatland_envs_observations.py +++ b/tests/test_flatland_envs_observations.py @@ -4,6 +4,7 @@ import numpy as np from flatland.core.grid.grid4 import Grid4TransitionsEnum +from flatland.envs.agent_generators import get_rnd_agents_pos_tgt_dir_on_rail from flatland.envs.agent_utils import EnvAgent from flatland.envs.generators import rail_from_grid_transition_map from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv @@ -21,6 +22,7 @@ def test_global_obs(): env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(), number_of_agents=1, obs_builder_object=GlobalObsForRailEnv()) @@ -90,6 +92,7 @@ def test_reward_function_conflict(rendering=False): env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(), number_of_agents=2, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), ) @@ -168,6 +171,7 @@ def test_reward_function_waiting(rendering=False): env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(), number_of_agents=2, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), ) diff --git a/tests/test_flatland_envs_predictions.py b/tests/test_flatland_envs_predictions.py index 7acd58ed..8c67a42a 100644 --- a/tests/test_flatland_envs_predictions.py +++ b/tests/test_flatland_envs_predictions.py @@ -5,6 +5,7 @@ import pprint import numpy as np from flatland.core.grid.grid4 import Grid4TransitionsEnum +from flatland.envs.agent_generators import get_rnd_agents_pos_tgt_dir_on_rail from flatland.envs.generators import rail_from_grid_transition_map from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv @@ -21,6 +22,7 @@ def test_dummy_predictor(rendering=False): env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv(max_depth=10)), ) @@ -111,6 +113,7 @@ def test_shortest_path_predictor(rendering=False): env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), ) @@ -230,6 +233,7 @@ def test_shortest_path_predictor_conflicts(rendering=False): env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(), number_of_agents=2, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), ) diff --git a/tests/test_flatland_envs_rail_env.py b/tests/test_flatland_envs_rail_env.py index 71dc87ce..9108d614 100644 --- a/tests/test_flatland_envs_rail_env.py +++ b/tests/test_flatland_envs_rail_env.py @@ -5,6 +5,7 @@ import numpy as np from flatland.core.grid.grid4 import Grid4Transitions from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.core.transition_map import GridTransitionMap +from flatland.envs.agent_generators import get_rnd_agents_pos_tgt_dir_on_rail, complex_rail_generator_agents_placer from flatland.envs.agent_utils import EnvAgent from flatland.envs.agent_utils import EnvAgentStatic from flatland.envs.generators import complex_rail_generator @@ -27,6 +28,7 @@ def test_load_env(): def test_save_load(): env = RailEnv(width=10, height=10, rail_generator=complex_rail_generator(nr_start_goal=2, nr_extra=5, min_dist=6, seed=0), + agent_generator=complex_rail_generator_agents_placer(), number_of_agents=2) env.reset() agent_1_pos = env.agents_static[0].position @@ -86,6 +88,7 @@ def test_rail_environment_single_agent(): rail_env = RailEnv(width=3, height=3, rail_generator=rail_from_grid_transition_map(rail), + agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(), number_of_agents=1, obs_builder_object=GlobalObsForRailEnv()) @@ -165,6 +168,7 @@ def test_dead_end(): rail_env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(), number_of_agents=1, obs_builder_object=GlobalObsForRailEnv()) @@ -209,6 +213,7 @@ def test_dead_end(): rail_env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(), number_of_agents=1, obs_builder_object=GlobalObsForRailEnv()) diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index 67dcd25c..60a15bb7 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -1,5 +1,6 @@ import numpy as np +from flatland.envs.agent_generators import complex_rail_generator_agents_placer from flatland.envs.generators import complex_rail_generator from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.rail_env import RailEnv @@ -62,6 +63,7 @@ def test_malfunction_process(): height=20, rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999, seed=0), + agent_generator=complex_rail_generator_agents_placer(), number_of_agents=2, obs_builder_object=SingleAgentNavigationObs(), stochastic_data=stochastic_data) diff --git a/tests/test_multi_speed.py b/tests/test_multi_speed.py index 47aadee7..8703800e 100644 --- a/tests/test_multi_speed.py +++ b/tests/test_multi_speed.py @@ -1,10 +1,12 @@ import numpy as np +from flatland.envs.agent_generators import complex_rail_generator_agents_placer from flatland.envs.generators import complex_rail_generator from flatland.envs.rail_env import RailEnv np.random.seed(1) + # Use the complex_rail_generator to generate feasible network configurations with corresponding tasks # Training on simple small tasks is the best way to get familiar with the environment # @@ -46,6 +48,7 @@ def test_multi_speed_init(): height=50, rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0), + agent_generator=complex_rail_generator_agents_placer(), number_of_agents=5) # Initialize the agent with the parameters corresponding to the environment and observation_builder agent = RandomAgent(218, 4) diff --git a/tests/test_speed_classes.py b/tests/test_speed_classes.py index 6ef600d9..67054c81 100644 --- a/tests/test_speed_classes.py +++ b/tests/test_speed_classes.py @@ -1,7 +1,8 @@ """Test speed initialization by a map of speeds and their corresponding ratios.""" import numpy as np -from flatland.envs.generators import speed_initialization_helper, complex_rail_generator +from flatland.envs.agent_generators import speed_initialization_helper, complex_rail_generator_agents_placer +from flatland.envs.generators import complex_rail_generator from flatland.envs.rail_env import RailEnv @@ -17,13 +18,11 @@ def test_speed_initialization_helper(): def test_rail_env_speed_intializer(): speed_ratio_map = {1: 0.3, 2: 0.4, 3: 0.1, 5: 0.2} - def my_speed_initializer(nb_agents): - return speed_initialization_helper(nb_agents, speed_ratio_map) - env = RailEnv(width=50, height=50, rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, - seed=0, speed_initializer=my_speed_initializer), + seed=0), + agent_generator=complex_rail_generator_agents_placer(), number_of_agents=10) env.reset() actual_speeds = list(map(lambda agent: agent.speed_data['speed'], env.agents)) diff --git a/tests/tests_generators.py b/tests/tests_generators.py index f97b071e..46109c55 100644 --- a/tests/tests_generators.py +++ b/tests/tests_generators.py @@ -3,6 +3,8 @@ import numpy as np +from flatland.envs.agent_generators import get_rnd_agents_pos_tgt_dir_on_rail, complex_rail_generator_agents_placer, \ + agents_from_file from flatland.envs.generators import rail_from_grid_transition_map, rail_from_file, complex_rail_generator, \ random_rail_generator, empty_rail_generator from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv @@ -58,7 +60,8 @@ def test_complex_rail_generator(): env = RailEnv(width=x_dim, height=y_dim, number_of_agents=n_agents, - rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist) + rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist), + agent_generator=complex_rail_generator_agents_placer() ) assert env.get_num_agents() == 2 assert env.rail.grid.shape == (y_dim, x_dim) @@ -69,7 +72,8 @@ def test_complex_rail_generator(): env = RailEnv(width=x_dim, height=y_dim, number_of_agents=n_agents, - rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist) + rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist), + agent_generator=complex_rail_generator_agents_placer() ) assert env.get_num_agents() == 0 assert env.rail.grid.shape == (y_dim, x_dim) @@ -82,7 +86,8 @@ def test_complex_rail_generator(): env = RailEnv(width=x_dim, height=y_dim, number_of_agents=n_agents, - rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist) + rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist), + agent_generator=complex_rail_generator_agents_placer() ) assert env.get_num_agents() == n_agents assert env.rail.grid.shape == (y_dim, x_dim) @@ -94,6 +99,7 @@ def test_rail_from_grid_transition_map(): env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(), number_of_agents=n_agents ) nr_rail_elements = np.count_nonzero(env.rail.grid) @@ -118,6 +124,7 @@ def tests_rail_from_file(): env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(), number_of_agents=3, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), ) @@ -130,6 +137,7 @@ def tests_rail_from_file(): env = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name), + agent_generator=agents_from_file(file_name), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), ) @@ -151,6 +159,7 @@ def tests_rail_from_file(): env2 = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(), number_of_agents=3, obs_builder_object=GlobalObsForRailEnv(), ) @@ -164,6 +173,7 @@ def tests_rail_from_file(): env2 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name_2), + agent_generator=agents_from_file(file_name_2), number_of_agents=1, obs_builder_object=GlobalObsForRailEnv(), ) @@ -180,6 +190,7 @@ def tests_rail_from_file(): env3 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name), + agent_generator=agents_from_file(file_name), number_of_agents=1, obs_builder_object=GlobalObsForRailEnv(), ) @@ -197,6 +208,7 @@ def tests_rail_from_file(): env4 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name_2), + agent_generator=agents_from_file(file_name_2), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2), ) -- GitLab