diff --git a/examples/complex_rail_benchmark.py b/examples/complex_rail_benchmark.py index 44e4b534c2f2dfa63cab385d009c9afd92285f48..49e550b195ffb15e6554413069369378e80e5f82 100644 --- a/examples/complex_rail_benchmark.py +++ b/examples/complex_rail_benchmark.py @@ -3,8 +3,9 @@ import random import numpy as np -from flatland.envs.generators import complex_rail_generator from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators import complex_rail_generator +from flatland.envs.schedule_generators import complex_schedule_generator def run_benchmark(): @@ -15,6 +16,7 @@ def run_benchmark(): # Example generate a random rail env = RailEnv(width=15, height=15, rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=20, min_dist=12), + schedule_generator=complex_schedule_generator(), number_of_agents=5) n_trials = 20 diff --git a/examples/custom_observation_example.py b/examples/custom_observation_example.py index 723bb1102092c7d48bd938bbd60d0c5213ffecf6..8b1de6aa4e303469d30983d30333fbfda89c1d1e 100644 --- a/examples/custom_observation_example.py +++ b/examples/custom_observation_example.py @@ -5,10 +5,11 @@ import numpy as np from flatland.core.env_observation_builder import ObservationBuilder from flatland.core.grid.grid_utils import coordinate_to_position -from flatland.envs.generators import random_rail_generator, complex_rail_generator from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators import random_rail_generator, complex_rail_generator +from flatland.envs.schedule_generators import complex_schedule_generator from flatland.utils.rendertools import RenderTool random.seed(100) @@ -20,6 +21,7 @@ class SimpleObs(ObservationBuilder): Simplest observation builder. The object returns observation vectors with 5 identical components, all equal to the ID of the respective agent. """ + def __init__(self): self.observation_space = [5] @@ -53,6 +55,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv): E.g., if taking the Left branch (if available) is the shortest route to the agent's target, the observation vector will be [1, 0, 0]. """ + def __init__(self): super().__init__(max_depth=0) self.observation_space = [3] @@ -90,6 +93,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv): env = RailEnv(width=7, height=7, rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999, seed=0), + schedule_generator=complex_schedule_generator(), number_of_agents=1, obs_builder_object=SingleAgentNavigationObs()) @@ -97,8 +101,8 @@ obs = env.reset() env_renderer = RenderTool(env, gl="PILSVG") env_renderer.render_env(show=True, frames=True, show_observations=True) for step in range(100): - action = np.argmax(obs[0])+1 - obs, all_rewards, done, _ = env.step({0:action}) + action = np.argmax(obs[0]) + 1 + obs, all_rewards, done, _ = env.step({0: action}) print("Rewards: ", all_rewards, " [done=", done, "]") env_renderer.render_env(show=True, frames=True, show_observations=True) time.sleep(0.1) @@ -200,6 +204,7 @@ CustomObsBuilder = ObservePredictions(CustomPredictor) env = RailEnv(width=10, height=10, rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=1, min_dist=8, max_dist=99999, seed=0), + schedule_generator=complex_schedule_generator(), number_of_agents=3, obs_builder_object=CustomObsBuilder) diff --git a/examples/custom_railmap_example.py b/examples/custom_railmap_example.py index 515d6c1b0469b7fbd9bad8cd82a40db7766f6219..04da66904fda1a58847a4acc510d7fc4e4e86887 100644 --- a/examples/custom_railmap_example.py +++ b/examples/custom_railmap_example.py @@ -1,30 +1,41 @@ import random +from typing import Any import numpy as np from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.core.transition_map import GridTransitionMap from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators import RailGenerator, RailGeneratorProduct +from flatland.envs.schedule_generators import ScheduleGenerator, ScheduleGeneratorProduct from flatland.utils.rendertools import RenderTool random.seed(100) np.random.seed(100) -def custom_rail_generator(): - def generator(width, height, num_agents=0, num_resets=0): +def custom_rail_generator() -> RailGenerator: + def generator(width: int, height: int, num_agents: int = 0, num_resets: int = 0) -> RailGeneratorProduct: rail_trans = RailEnvTransitions() grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans) rail_array = grid_map.grid rail_array.fill(0) new_tran = rail_trans.set_transition(1, 1, 1, 1) print(new_tran) + rail_array[0, 0] = new_tran + rail_array[0, 1] = new_tran + return grid_map, None + + return generator + + +def custom_schedule_generator() -> ScheduleGenerator: + def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None) -> ScheduleGeneratorProduct: agents_positions = [] agents_direction = [] agents_target = [] - rail_array[0, 0] = new_tran - rail_array[0, 1] = new_tran - return grid_map, agents_positions, agents_direction, agents_target + speeds = [] + return agents_positions, agents_direction, agents_target, speeds return generator diff --git a/examples/debugging_example_DELETE.py b/examples/debugging_example_DELETE.py index 2c0f814576caef84471d20c91dd92d23d4db02ac..50ea74b84ac9851e88e48bcd32b914e69bc7dd34 100644 --- a/examples/debugging_example_DELETE.py +++ b/examples/debugging_example_DELETE.py @@ -3,14 +3,16 @@ import time import numpy as np -from flatland.envs.generators import complex_rail_generator from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators import complex_rail_generator +from flatland.envs.schedule_generators import complex_schedule_generator from flatland.utils.rendertools import RenderTool random.seed(1) np.random.seed(1) + class SingleAgentNavigationObs(TreeObsForRailEnv): """ We derive our bbservation builder from TreeObsForRailEnv, to exploit the existing implementation to compute @@ -21,6 +23,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv): E.g., if taking the Left branch (if available) is the shortest route to the agent's target, the observation vector will be [1, 0, 0]. """ + def __init__(self): super().__init__(max_depth=0) self.observation_space = [3] @@ -58,6 +61,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv): env = RailEnv(width=14, height=14, rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999, seed=0), + schedule_generator=complex_schedule_generator(), number_of_agents=2, obs_builder_object=SingleAgentNavigationObs()) @@ -67,11 +71,11 @@ env_renderer.render_env(show=True, frames=True, show_observations=False) for step in range(100): actions = {} for i in range(len(obs)): - actions[i] = np.argmax(obs[i])+1 + actions[i] = np.argmax(obs[i]) + 1 - if step%5 == 0: + if step % 5 == 0: print("Agent halts") - actions[0] = 4 # Halt + actions[0] = 4 # Halt obs, all_rewards, done, _ = env.step(actions) if env.agents[0].malfunction_data['malfunction'] > 0: @@ -82,4 +86,3 @@ for step in range(100): if done["__all__"]: break env_renderer.close_window() - diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index 916e50b20b10a02c43c5b1da8bc0728930b8c535..71a185c765bcab831e7b104124a164bcf2398b14 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -1,9 +1,10 @@ import numpy as np +from flatland.envs.rail_generators import sparse_rail_generator -from flatland.envs.generators import sparse_rail_generator from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv +from flatland.envs.schedule_generators import sparse_schedule_generator from flatland.utils.rendertools import RenderTool np.random.seed(1) @@ -31,6 +32,7 @@ env = RailEnv(width=20, realistic_mode=True, enhance_intersection=True ), + schedule_generator=sparse_schedule_generator(), number_of_agents=5, stochastic_data=stochastic_data, # Malfunction data generator obs_builder_object=TreeObservation) @@ -75,7 +77,6 @@ class RandomAgent: # Set action space to 4 to remove stop action agent = RandomAgent(218, 4) - # Empty dictionary for all agent action action_dict = dict() diff --git a/examples/simple_example_1.py b/examples/simple_example_1.py index 7956c34fd4a5b94859a4b64441450afe2114133c..fbadbd657c36fa1dadf0bca65cff3e9cccd269ea 100644 --- a/examples/simple_example_1.py +++ b/examples/simple_example_1.py @@ -1,5 +1,5 @@ -from flatland.envs.generators import rail_from_manual_specifications_generator from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators import rail_from_manual_specifications_generator from flatland.utils.rendertools import RenderTool # Example generate a rail given a manual specification, diff --git a/examples/simple_example_2.py b/examples/simple_example_2.py index 994c7deda1569b77d4adac8a17fa9ebe14b27ef6..6db9ba5abbd0999ef3896e733516ed6b3e498bae 100644 --- a/examples/simple_example_2.py +++ b/examples/simple_example_2.py @@ -2,8 +2,8 @@ import random import numpy as np -from flatland.envs.generators import random_rail_generator from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators import random_rail_generator from flatland.utils.rendertools import RenderTool random.seed(100) diff --git a/examples/simple_example_3.py b/examples/simple_example_3.py index 5aa03d8f95a7079b708baea1e2ddce27e9a46554..6df6d4af3076b3d9659aadbd55296b667dc7d6db 100644 --- a/examples/simple_example_3.py +++ b/examples/simple_example_3.py @@ -2,9 +2,10 @@ import random import numpy as np -from flatland.envs.generators import complex_rail_generator from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators import complex_rail_generator +from flatland.envs.schedule_generators import complex_schedule_generator from flatland.utils.rendertools import RenderTool random.seed(1) @@ -13,6 +14,7 @@ np.random.seed(1) env = RailEnv(width=7, height=7, rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0), + schedule_generator=complex_schedule_generator(), number_of_agents=2, obs_builder_object=TreeObsForRailEnv(max_depth=2)) diff --git a/examples/training_example.py b/examples/training_example.py index d125be1587a56025ba1cd3f78b28ba3976f01fbf..df93479f5a5ee05abfcb1a98b07ef052bffc2bd4 100644 --- a/examples/training_example.py +++ b/examples/training_example.py @@ -1,9 +1,10 @@ import numpy as np -from flatland.envs.generators import complex_rail_generator from flatland.envs.observations import TreeObsForRailEnv, LocalObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators import complex_rail_generator +from flatland.envs.schedule_generators import complex_schedule_generator from flatland.utils.rendertools import RenderTool np.random.seed(1) @@ -16,11 +17,13 @@ LocalGridObs = LocalObsForRailEnv(view_height=10, view_width=2, center=2) env = RailEnv(width=20, height=20, rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0), + schedule_generator=complex_schedule_generator(), obs_builder_object=TreeObservation, number_of_agents=3) env_renderer = RenderTool(env, gl="PILSVG", ) + # Import your own Agent or use RLlib to train agents on Flatland # As an example we use a random agent here diff --git a/flatland/cli.py b/flatland/cli.py index 32e8d9dc786b0412795694fc985c90aa55fc2e91..47c450dba803fac17bc13979663ef04e4c0db899 100644 --- a/flatland/cli.py +++ b/flatland/cli.py @@ -2,29 +2,33 @@ """Console script for flatland.""" import sys +import time + import click import numpy as np -import time -from flatland.envs.generators import complex_rail_generator +import redis + from flatland.envs.rail_env import RailEnv -from flatland.utils.rendertools import RenderTool +from flatland.envs.rail_generators import complex_rail_generator +from flatland.envs.schedule_generators import complex_schedule_generator from flatland.evaluators.service import FlatlandRemoteEvaluationService -import redis +from flatland.utils.rendertools import RenderTool @click.command() def demo(args=None): """Demo script to check installation""" env = RailEnv( - width=15, - height=15, - rail_generator=complex_rail_generator( - nr_start_goal=10, - nr_extra=1, - min_dist=8, - max_dist=99999), - number_of_agents=5) - + width=15, + height=15, + rail_generator=complex_rail_generator( + nr_start_goal=10, + nr_extra=1, + min_dist=8, + max_dist=99999), + schedule_generator=complex_schedule_generator(), + number_of_agents=5) + env._max_episode_steps = int(15 * (env.width + env.height)) env_renderer = RenderTool(env) @@ -52,12 +56,12 @@ def demo(args=None): @click.command() -@click.option('--tests', +@click.option('--tests', type=click.Path(exists=True), help="Path to folder containing Flatland tests", required=True ) -@click.option('--service_id', +@click.option('--service_id', default="FLATLAND_RL_SERVICE_ID", help="Evaluation Service ID. This has to match the service id on the client.", required=False @@ -70,14 +74,14 @@ def evaluator(tests, service_id): raise Exception( "\nRedis server does not seem to be running on your localhost.\n" "Please ensure that you have a redis server running on your localhost" - ) - + ) + grader = FlatlandRemoteEvaluationService( - test_env_folder=tests, - flatland_rl_service_id=service_id, - visualize=False, - verbose=False - ) + test_env_folder=tests, + flatland_rl_service_id=service_id, + visualize=False, + verbose=False + ) grader.run() diff --git a/flatland/envs/grid4_generators_utils.py b/flatland/envs/grid4_generators_utils.py index 0055b243668a4f3cd562958a59f52ba830af1c86..996bd73ad9de598eb162a937c135681675119ad3 100644 --- a/flatland/envs/grid4_generators_utils.py +++ b/flatland/envs/grid4_generators_utils.py @@ -5,10 +5,8 @@ Generator functions are functions that take width, height and num_resets as argu a GridTransitionMap object. """ -import numpy as np - from flatland.core.grid.grid4_astar import a_star -from flatland.core.grid.grid4_utils import get_direction, mirror, get_new_position +from flatland.core.grid.grid4_utils import get_direction, mirror def connect_rail(rail_trans, rail_array, start, end): @@ -195,81 +193,3 @@ def connect_to_nodes(rail_trans, rail_array, start, end): current_dir = new_dir return path - - -def get_rnd_agents_pos_tgt_dir_on_rail(rail, num_agents): - """ - Given a `rail' GridTransitionMap, return a random placement of agents (initial position, direction and target). - - TODO: add extensive documentation, as users may need this function to simplify their custom level generators. - """ - - def _path_exists(rail, start, direction, end): - # BFS - Check if a path exists between the 2 nodes - - visited = set() - stack = [(start, direction)] - while stack: - node = stack.pop() - if node[0][0] == end[0] and node[0][1] == end[1]: - return 1 - if node not in visited: - visited.add(node) - moves = rail.get_transitions(node[0][0], node[0][1], node[1]) - for move_index in range(4): - if moves[move_index]: - stack.append((get_new_position(node[0], move_index), - move_index)) - - # If cell is a dead-end, append previous node with reversed - # orientation! - nbits = 0 - tmp = rail.get_full_transitions(node[0][0], node[0][1]) - while tmp > 0: - nbits += (tmp & 1) - tmp = tmp >> 1 - if nbits == 1: - stack.append((node[0], (node[1] + 2) % 4)) - - return 0 - - valid_positions = [] - for r in range(rail.height): - for c in range(rail.width): - if rail.get_full_transitions(r, c) > 0: - valid_positions.append((r, c)) - - re_generate = True - while re_generate: - agents_position = [ - valid_positions[i] for i in - np.random.choice(len(valid_positions), num_agents)] - agents_target = [ - valid_positions[i] for i in - np.random.choice(len(valid_positions), num_agents)] - - # agents_direction must be a direction for which a solution is - # guaranteed. - agents_direction = [0] * num_agents - re_generate = False - for i in range(num_agents): - valid_movements = [] - for direction in range(4): - position = agents_position[i] - moves = rail.get_transitions(position[0], position[1], direction) - for move_index in range(4): - if moves[move_index]: - valid_movements.append((direction, move_index)) - - valid_starting_directions = [] - for m in valid_movements: - new_position = get_new_position(agents_position[i], m[1]) - if m[0] not in valid_starting_directions and _path_exists(rail, new_position, m[0], agents_target[i]): - valid_starting_directions.append(m[0]) - - if len(valid_starting_directions) == 0: - re_generate = True - else: - agents_direction[i] = valid_starting_directions[np.random.choice(len(valid_starting_directions), 1)[0]] - - return agents_position, agents_direction, agents_target diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 6e6665af88c9e8b31e1a689815edb7aaada342f9..a61ef02207174d04489b5311dc042b7c06db1412 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -13,8 +13,9 @@ from flatland.core.env import Environment from flatland.core.grid.grid4_utils import get_new_position from flatland.core.transition_map import GridTransitionMap from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent -from flatland.envs.generators import random_rail_generator from flatland.envs.observations import TreeObsForRailEnv +from flatland.envs.rail_generators import random_rail_generator, RailGenerator +from flatland.envs.schedule_generators import random_schedule_generator, ScheduleGenerator m.patch() @@ -92,7 +93,8 @@ class RailEnv(Environment): def __init__(self, width, height, - rail_generator=random_rail_generator(), + rail_generator: RailGenerator = random_rail_generator(), + schedule_generator: ScheduleGenerator = random_schedule_generator(), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2), max_episode_steps=None, @@ -108,13 +110,12 @@ class RailEnv(Environment): height and agents handles of a rail environment, along with the number of times the env has been reset, and returns a GridTransitionMap object and a list of starting positions, targets, and initial orientations for agent handle. - Implemented functions are: - random_rail_generator : generate a random rail of given size - rail_from_grid_transition_map(rail_map) : generate a rail from - a GridTransitionMap object - rail_from_manual_sp ecifications_generator(rail_spec) : generate a rail from - a rail specifications array - TODO: generate_rail_from_saved_list or from list of ndarray bitmaps --- + The rail_generator can pass a distance map in the hints or information for specific schedule_generators. + Implementations can be found in flatland/envs/rail_generators.py + schedule_generator : function + The schedule_generator function is a function that takes the grid, the number of agents and optional hints + and returns a list of starting positions, targets, initial orientations and speed for all agent handles. + Implementations can be found in flatland/envs/schedule_generators.py width : int The width of the rail map. Potentially in the future, a range of widths to sample from. @@ -132,6 +133,8 @@ class RailEnv(Environment): file_name: you can load a pickle file. """ + self.rail_generator: RailGenerator = rail_generator + self.schedule_generator: ScheduleGenerator = schedule_generator self.rail_generator = rail_generator self.rail: GridTransitionMap = None self.width = width @@ -214,14 +217,13 @@ class RailEnv(Environment): if replace_agents then regenerate the agents static. Relies on the rail_generator returning agent_static lists (pos, dir, target) """ - tRailAgents = self.rail_generator(self.width, self.height, self.get_num_agents(), self.num_resets) + rail, optionals = self.rail_generator(self.width, self.height, self.get_num_agents(), self.num_resets) - # Check if generator provided a distance map TODO: Make this check safer! - if len(tRailAgents) > 5: - self.obs_builder.distance_map = tRailAgents[-1] + if optionals and 'distance_maps' in optionals: + self.obs_builder.distance_map = optionals['distance_maps'] if regen_rail or self.rail is None: - self.rail = tRailAgents[0] + self.rail = rail self.height, self.width = self.rail.grid.shape for r in range(self.height): for c in range(self.width): @@ -231,7 +233,11 @@ class RailEnv(Environment): warnings.warn("Invalid grid at {} -> {}".format(rcPos, check)) if replace_agents: - self.agents_static = EnvAgentStatic.from_lists(*tRailAgents[1:5]) + agents_hints = None + if optionals and 'agents_hints' in optionals: + agents_hints = optionals['agents_hints'] + self.agents_static = EnvAgentStatic.from_lists( + *self.schedule_generator(self.rail, self.get_num_agents(), hints=agents_hints)) self.restart_agents() diff --git a/flatland/envs/generators.py b/flatland/envs/rail_generators.py similarity index 87% rename from flatland/envs/generators.py rename to flatland/envs/rail_generators.py index 525db36e8c5a09a451c0c59c1d03f1352f66e827..ed507dca9de9b3e90d412e77a7204037a5a20975 100644 --- a/flatland/envs/generators.py +++ b/flatland/envs/rail_generators.py @@ -1,4 +1,6 @@ +"""Rail generators (infrastructure manager, "Infrastrukturbetreiber").""" import warnings +from typing import Callable, Tuple, Any, Optional import msgpack import numpy as np @@ -7,29 +9,34 @@ from flatland.core.grid.grid4_utils import get_direction, mirror from flatland.core.grid.grid_utils import distance_on_rail from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.core.transition_map import GridTransitionMap -from flatland.envs.agent_utils import EnvAgentStatic from flatland.envs.grid4_generators_utils import connect_rail, connect_nodes, connect_from_nodes -from flatland.envs.grid4_generators_utils import get_rnd_agents_pos_tgt_dir_on_rail +RailGeneratorProduct = Tuple[GridTransitionMap, Optional[Any]] +RailGenerator = Callable[[int, int, int, int], RailGeneratorProduct] -def empty_rail_generator(): + +def empty_rail_generator() -> RailGenerator: """ Returns a generator which returns an empty rail mail with no agents. Primarily used by the editor """ - def generator(width, height, num_agents=0, num_resets=0): + def generator(width: int, height: int, num_agents: int = 0, num_resets: int = 0) -> RailGeneratorProduct: rail_trans = RailEnvTransitions() grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans) rail_array = grid_map.grid rail_array.fill(0) - return grid_map, [], [], [], [] + return grid_map, None return generator -def complex_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dist=99999, seed=0): +def complex_rail_generator(nr_start_goal=1, + nr_extra=100, + min_dist=20, + max_dist=99999, + seed=0) -> RailGenerator: """ Parameters ------- @@ -49,8 +56,7 @@ def complex_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dist= if num_agents > nr_start_goal: num_agents = nr_start_goal print("complex_rail_generator: num_agents > nr_start_goal, changing num_agents") - rail_trans = RailEnvTransitions() - grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans) + grid_map = GridTransitionMap(width=width, height=height, transitions=RailEnvTransitions()) rail_array = grid_map.grid rail_array.fill(0) @@ -74,6 +80,7 @@ def complex_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dist= # - return transition map + list of [start_pos, start_dir, goal_pos] points # + rail_trans = grid_map.transitions start_goal = [] start_dir = [] nr_created = 0 @@ -143,11 +150,10 @@ def complex_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dist= if len(new_path) >= 2: nr_created += 1 - agents_position = [sg[0] for sg in start_goal[:num_agents]] - agents_target = [sg[1] for sg in start_goal[:num_agents]] - agents_direction = start_dir[:num_agents] - - return grid_map, agents_position, agents_direction, agents_target, [1.0] * len(agents_position) + return grid_map, {'agents_hints': { + 'start_goal': start_goal, + 'start_dir': start_dir + }} return generator @@ -191,22 +197,18 @@ def rail_from_manual_specifications_generator(rail_spec): effective_transition_cell = rail_env_transitions.rotate_transition(basic_type_of_cell_, rotation_cell_) rail.set_transitions((r, c), effective_transition_cell) - agents_position, agents_direction, agents_target = get_rnd_agents_pos_tgt_dir_on_rail( - rail, - num_agents) - - return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position) + return [rail, None] return generator -def rail_from_file(filename): +def rail_from_file(filename) -> RailGenerator: """ Utility to load pickle file Parameters ------- - input_file : Pickle file generated by env.save() or editor + filename : Pickle file generated by env.save() or editor Returns ------- @@ -224,26 +226,16 @@ def rail_from_file(filename): grid = np.array(data[b"grid"]) rail = GridTransitionMap(width=np.shape(grid)[1], height=np.shape(grid)[0], transitions=rail_env_transitions) rail.grid = grid - # agents are always reset as not moving - agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data[b"agents_static"]] - # setup with loaded data - agents_position = [a.position for a in agents_static] - agents_direction = [a.direction for a in agents_static] - agents_target = [a.target for a in agents_static] if b"distance_maps" in data.keys(): distance_maps = data[b"distance_maps"] if len(distance_maps) > 0: - return rail, agents_position, agents_direction, agents_target, [1.0] * len( - agents_position), distance_maps - else: - return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position) - else: - return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position) + return rail, {'distance_maps': distance_maps} + return [rail, None] return generator -def rail_from_grid_transition_map(rail_map): +def rail_from_grid_transition_map(rail_map) -> RailGenerator: """ Utility to convert a rail given by a GridTransitionMap map with the correct 16-bit transitions specifications. @@ -259,17 +251,13 @@ def rail_from_grid_transition_map(rail_map): Generator function that always returns the given `rail_map' object. """ - def generator(width, height, num_agents, num_resets=0): - agents_position, agents_direction, agents_target = get_rnd_agents_pos_tgt_dir_on_rail( - rail_map, - num_agents) - - return rail_map, agents_position, agents_direction, agents_target, [1.0] * len(agents_position) + def generator(width: int, height: int, num_agents: int, num_resets: int = 0) -> RailGeneratorProduct: + return rail_map, None return generator -def random_rail_generator(cell_type_relative_proportion=[1.0] * 11): +def random_rail_generator(cell_type_relative_proportion=[1.0] * 11) -> RailGenerator: """ Dummy random level generator: - fill in cells at random in [width-2, height-2] @@ -301,7 +289,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11): The matrix with the correct 16-bit bitmaps for each cell. """ - def generator(width, height, num_agents, num_resets=0): + def generator(width: int, height: int, num_agents: int, num_resets: int = 0) -> RailGeneratorProduct: t_utils = RailEnvTransitions() transition_probability = cell_type_relative_proportion @@ -533,11 +521,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11): return_rail = GridTransitionMap(width=width, height=height, transitions=t_utils) return_rail.grid = tmp_rail - agents_position, agents_direction, agents_target = get_rnd_agents_pos_tgt_dir_on_rail( - return_rail, - num_agents) - - return return_rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position) + return return_rail, None return generator @@ -802,48 +786,9 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 else: num_agents -= 1 - # Place agents and targets within available train stations - agents_position = [] - agents_target = [] - agents_direction = [] - - for agent_idx in range(num_agents): - # Set target for agent - current_target_node = agent_start_targets_nodes[agent_idx][1] - target_station_idx = np.random.randint(len(train_stations[current_target_node])) - target = train_stations[current_target_node][target_station_idx] - tries = 0 - while (target[0], target[1]) in agents_target: - target_station_idx = np.random.randint(len(train_stations[current_target_node])) - target = train_stations[current_target_node][target_station_idx] - tries += 1 - if tries > 100: - warnings.warn("Could not set target position, removing an agent") - break - agents_target.append((target[0], target[1])) - - # Set start for agent - current_start_node = agent_start_targets_nodes[agent_idx][0] - start_station_idx = np.random.randint(len(train_stations[current_start_node])) - start = train_stations[current_start_node][start_station_idx] - tries = 0 - while (start[0], start[1]) in agents_position: - tries += 1 - if tries > 100: - warnings.warn("Could not set start position, please change initial parameters!!!!") - break - start_station_idx = np.random.randint(len(train_stations[current_start_node])) - start = train_stations[current_start_node][start_station_idx] - - agents_position.append((start[0], start[1])) - - # Orient the agent correctly - for orientation in range(4): - transitions = grid_map.get_transitions(start[0], start[1], orientation) - if any(transitions) > 0: - agents_direction.append(orientation) - continue - - return grid_map, agents_position, agents_direction, agents_target, [1.0] * len(agents_position) + return grid_map, {'agents_hints': { + 'agent_start_targets_nodes': agent_start_targets_nodes, + 'train_stations': train_stations + }} return generator diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py new file mode 100644 index 0000000000000000000000000000000000000000..0ebc6c71c17db308789a4baf0ec99729ec9991e8 --- /dev/null +++ b/flatland/envs/schedule_generators.py @@ -0,0 +1,238 @@ +"""Schedule generators (railway undertaking, "EVU").""" +import warnings +from typing import Tuple, List, Callable, Mapping, Optional, Any + +import msgpack +import numpy as np + +from flatland.core.grid.grid4_utils import get_new_position +from flatland.core.transition_map import GridTransitionMap +from flatland.envs.agent_utils import EnvAgentStatic + +AgentPosition = Tuple[int, int] +ScheduleGeneratorProduct = Tuple[List[AgentPosition], List[AgentPosition], List[AgentPosition], List[float]] +ScheduleGenerator = Callable[[GridTransitionMap, int, Optional[Any]], ScheduleGeneratorProduct] + + +def speed_initialization_helper(nb_agents: int, speed_ratio_map: Mapping[float, float] = None) -> List[float]: + """ + Parameters + ------- + nb_agents : int + The number of agents to generate a speed for + speed_ratio_map : Mapping[float,float] + A map of speeds mappint to their ratio of appearance. The ratios must sum up to 1. + + Returns + ------- + List[float] + A list of size nb_agents of speeds with the corresponding probabilistic ratios. + """ + if speed_ratio_map is None: + return [1.0] * nb_agents + + nb_classes = len(speed_ratio_map.keys()) + speed_ratio_map_as_list: List[Tuple[float, float]] = list(speed_ratio_map.items()) + speed_ratios = list(map(lambda t: t[1], speed_ratio_map_as_list)) + speeds = list(map(lambda t: t[0], speed_ratio_map_as_list)) + return list(map(lambda index: speeds[index], np.random.choice(nb_classes, nb_agents, p=speed_ratios))) + + +def complex_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> ScheduleGenerator: + def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None): + start_goal = hints['start_goal'] + start_dir = hints['start_dir'] + agents_position = [sg[0] for sg in start_goal[:num_agents]] + agents_target = [sg[1] for sg in start_goal[:num_agents]] + agents_direction = start_dir[:num_agents] + + if speed_ratio_map: + speeds = speed_initialization_helper(num_agents, speed_ratio_map) + else: + speeds = [1.0] * len(agents_position) + + return agents_position, agents_direction, agents_target, speeds + + return generator + + +def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> ScheduleGenerator: + def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None): + train_stations = hints['train_stations'] + agent_start_targets_nodes = hints['agent_start_targets_nodes'] + # Place agents and targets within available train stations + agents_position = [] + agents_target = [] + agents_direction = [] + for agent_idx in range(num_agents): + # Set target for agent + current_target_node = agent_start_targets_nodes[agent_idx][1] + target_station_idx = np.random.randint(len(train_stations[current_target_node])) + target = train_stations[current_target_node][target_station_idx] + tries = 0 + while (target[0], target[1]) in agents_target: + target_station_idx = np.random.randint(len(train_stations[current_target_node])) + target = train_stations[current_target_node][target_station_idx] + tries += 1 + if tries > 100: + warnings.warn("Could not set target position, removing an agent") + break + agents_target.append((target[0], target[1])) + + # Set start for agent + current_start_node = agent_start_targets_nodes[agent_idx][0] + start_station_idx = np.random.randint(len(train_stations[current_start_node])) + start = train_stations[current_start_node][start_station_idx] + tries = 0 + while (start[0], start[1]) in agents_position: + tries += 1 + if tries > 100: + warnings.warn("Could not set start position, please change initial parameters!!!!") + break + start_station_idx = np.random.randint(len(train_stations[current_start_node])) + start = train_stations[current_start_node][start_station_idx] + + agents_position.append((start[0], start[1])) + + # Orient the agent correctly + for orientation in range(4): + transitions = rail.get_transitions(start[0], start[1], orientation) + if any(transitions) > 0: + agents_direction.append(orientation) + continue + + if speed_ratio_map: + speeds = speed_initialization_helper(num_agents, speed_ratio_map) + else: + speeds = [1.0] * len(agents_position) + + return agents_position, agents_direction, agents_target, speeds + + return generator + + +def random_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> ScheduleGenerator: + """ + Given a `rail' GridTransitionMap, return a random placement of agents (initial position, direction and target). + + Parameters + ------- + rail : GridTransitionMap + The railway to place agents on. + num_agents : int + The number of agents to generate a speed for + speed_ratio_map : Mapping[float,float] + A map of speeds mappint to their ratio of appearance. The ratios must sum up to 1. + Returns + ------- + Tuple[List[Tuple[int,int]], List[Tuple[int,int]], List[Tuple[int,int]], List[float]] + initial positions, directions, targets speeds + """ + + def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None) -> ScheduleGeneratorProduct: + def _path_exists(rail, start, direction, end): + # BFS - Check if a path exists between the 2 nodes + + visited = set() + stack = [(start, direction)] + while stack: + node = stack.pop() + if node[0][0] == end[0] and node[0][1] == end[1]: + return 1 + if node not in visited: + visited.add(node) + moves = rail.get_transitions(node[0][0], node[0][1], node[1]) + for move_index in range(4): + if moves[move_index]: + stack.append((get_new_position(node[0], move_index), + move_index)) + + # If cell is a dead-end, append previous node with reversed + # orientation! + nbits = 0 + tmp = rail.get_full_transitions(node[0][0], node[0][1]) + while tmp > 0: + nbits += (tmp & 1) + tmp = tmp >> 1 + if nbits == 1: + stack.append((node[0], (node[1] + 2) % 4)) + + return 0 + + valid_positions = [] + for r in range(rail.height): + for c in range(rail.width): + if rail.get_full_transitions(r, c) > 0: + valid_positions.append((r, c)) + if len(valid_positions) == 0: + return [], [], [], [] + re_generate = True + while re_generate: + agents_position = [ + valid_positions[i] for i in + np.random.choice(len(valid_positions), num_agents)] + agents_target = [ + valid_positions[i] for i in + np.random.choice(len(valid_positions), num_agents)] + + # agents_direction must be a direction for which a solution is + # guaranteed. + agents_direction = [0] * num_agents + re_generate = False + for i in range(num_agents): + valid_movements = [] + for direction in range(4): + position = agents_position[i] + moves = rail.get_transitions(position[0], position[1], direction) + for move_index in range(4): + if moves[move_index]: + valid_movements.append((direction, move_index)) + + valid_starting_directions = [] + for m in valid_movements: + new_position = get_new_position(agents_position[i], m[1]) + if m[0] not in valid_starting_directions and _path_exists(rail, new_position, m[0], + agents_target[i]): + valid_starting_directions.append(m[0]) + + if len(valid_starting_directions) == 0: + re_generate = True + else: + agents_direction[i] = valid_starting_directions[ + np.random.choice(len(valid_starting_directions), 1)[0]] + + agents_speed = speed_initialization_helper(num_agents, speed_ratio_map) + return agents_position, agents_direction, agents_target, agents_speed + + return generator + + +def agents_from_file(filename) -> ScheduleGenerator: + """ + Utility to load pickle file + + Parameters + ------- + input_file : Pickle file generated by env.save() or editor + + Returns + ------- + Tuple[List[Tuple[int,int]], List[Tuple[int,int]], List[Tuple[int,int]], List[float]] + initial positions, directions, targets speeds + """ + + def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None) -> ScheduleGeneratorProduct: + with open(filename, "rb") as file_in: + load_data = file_in.read() + data = msgpack.unpackb(load_data, use_list=False) + + # agents are always reset as not moving + agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data[b"agents_static"]] + # setup with loaded data + agents_position = [a.position for a in agents_static] + agents_direction = [a.direction for a in agents_static] + agents_target = [a.target for a in agents_static] + + return agents_position, agents_direction, agents_target, [1.0] * len(agents_position) + + return generator diff --git a/flatland/evaluators/client.py b/flatland/evaluators/client.py index a4968c0c8e827c060e0e3f7de0cf28cc0658089b..f2dac1c8705a651e2a7be026b4bd82a961efbbdd 100644 --- a/flatland/evaluators/client.py +++ b/flatland/evaluators/client.py @@ -1,18 +1,21 @@ -import redis +import hashlib import json +import logging import os -import numpy as np +import random +import time + import msgpack import msgpack_numpy as m -import hashlib -import random -from flatland.evaluators import messages -from flatland.envs.rail_env import RailEnv -from flatland.envs.generators import rail_from_file +import numpy as np +import redis + from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv -import time -import logging +from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators import rail_from_file +from flatland.evaluators import messages + logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) m.patch() @@ -22,8 +25,8 @@ def are_dicts_equal(d1, d2): """ return True if all keys and values are the same """ return all(k in d2 and d1[k] == d2[k] for k in d1) \ - and all(k in d1 and d1[k] == d2[k] - for k in d2) + and all(k in d1 and d1[k] == d2[k] + for k in d2) class FlatlandRemoteClient(object): @@ -41,39 +44,40 @@ class FlatlandRemoteClient(object): where `service_id` is either provided as an `env` variable or is instantiated to "flatland_rl_redis_service_id" """ - def __init__(self, - remote_host='127.0.0.1', - remote_port=6379, - remote_db=0, - remote_password=None, - test_envs_root=None, - verbose=False): + + def __init__(self, + remote_host='127.0.0.1', + remote_port=6379, + remote_db=0, + remote_password=None, + test_envs_root=None, + verbose=False): self.remote_host = remote_host self.remote_port = remote_port self.remote_db = remote_db self.remote_password = remote_password self.redis_pool = redis.ConnectionPool( - host=remote_host, - port=remote_port, - db=remote_db, - password=remote_password) + host=remote_host, + port=remote_port, + db=remote_db, + password=remote_password) self.namespace = "flatland-rl" self.service_id = os.getenv( - 'FLATLAND_RL_SERVICE_ID', - 'FLATLAND_RL_SERVICE_ID' - ) + 'FLATLAND_RL_SERVICE_ID', + 'FLATLAND_RL_SERVICE_ID' + ) self.command_channel = "{}::{}::commands".format( - self.namespace, - self.service_id - ) + self.namespace, + self.service_id + ) if test_envs_root: self.test_envs_root = test_envs_root else: self.test_envs_root = os.getenv( - 'AICROWD_TESTS_FOLDER', - '/tmp/flatland_envs' - ) + 'AICROWD_TESTS_FOLDER', + '/tmp/flatland_envs' + ) self.verbose = verbose @@ -85,12 +89,12 @@ class FlatlandRemoteClient(object): def _generate_response_channel(self): random_hash = hashlib.md5( - "{}".format( - random.randint(0, 10**10) - ).encode('utf-8')).hexdigest() + "{}".format( + random.randint(0, 10 ** 10) + ).encode('utf-8')).hexdigest() response_channel = "{}::{}::response::{}".format(self.namespace, - self.service_id, - random_hash) + self.service_id, + random_hash) return response_channel def _blocking_request(self, _request): @@ -124,9 +128,9 @@ class FlatlandRemoteClient(object): if self.verbose: print("Response : ", _response) _response = msgpack.unpackb( - _response, - object_hook=m.decode, - encoding="utf8") + _response, + object_hook=m.decode, + encoding="utf8") if _response['type'] == messages.FLATLAND_RL.ERROR: raise Exception(str(_response["payload"])) else: @@ -181,7 +185,7 @@ class FlatlandRemoteClient(object): "Did you remember to set the AICROWD_TESTS_FOLDER environment variable " "to point to the location of the Tests folder ? \n" "We are currently looking at `{}` for the tests".format(self.test_envs_root) - ) + ) print("Current env path : ", test_env_file_path) self.env = RailEnv( width=1, @@ -207,7 +211,7 @@ class FlatlandRemoteClient(object): _request['payload']['action'] = action _response = self._blocking_request(_request) _payload = _response['payload'] - + # remote_observation = _payload['observation'] remote_reward = _payload['reward'] remote_done = _payload['done'] @@ -216,14 +220,14 @@ class FlatlandRemoteClient(object): # Replicate the action in the local env local_observation, local_reward, local_done, local_info = \ self.env.step(action) - + print(local_reward) if not are_dicts_equal(remote_reward, local_reward): raise Exception("local and remote `reward` are diverging") print(remote_reward, local_reward) if not are_dicts_equal(remote_done, local_done): raise Exception("local and remote `done` are diverging") - + # Return local_observation instead of remote_observation # as the remote_observation is build using a dummy observation # builder @@ -250,21 +254,23 @@ class FlatlandRemoteClient(object): if __name__ == "__main__": remote_client = FlatlandRemoteClient() + def my_controller(obs, _env): _action = {} for _idx, _ in enumerate(_env.agents): _action[_idx] = np.random.randint(0, 5) return _action - + + my_observation_builder = TreeObsForRailEnv(max_depth=3, - predictor=ShortestPathPredictorForRailEnv()) + predictor=ShortestPathPredictorForRailEnv()) episode = 0 obs = True - while obs: + while obs: obs = remote_client.env_create( - obs_builder_object=my_observation_builder - ) + obs_builder_object=my_observation_builder + ) if not obs: """ The remote env returns False as the first obs @@ -285,7 +291,5 @@ if __name__ == "__main__": print("Reward : ", sum(list(all_rewards.values()))) break - print("Evaluation Complete...") + print("Evaluation Complete...") print(remote_client.submit()) - - diff --git a/flatland/evaluators/service.py b/flatland/evaluators/service.py index 3ad0a97598c8beb66fc164eb45b670c87f3c96f9..8967b52d9d6ee70a7eb8af257ef6b4e25b531314 100644 --- a/flatland/evaluators/service.py +++ b/flatland/evaluators/service.py @@ -1,24 +1,26 @@ #!/usr/bin/env python from __future__ import print_function -import redis -from flatland.envs.generators import rail_from_file -from flatland.envs.rail_env import RailEnv -from flatland.core.env_observation_builder import DummyObservationBuilder -from flatland.evaluators import messages -from flatland.evaluators import aicrowd_helpers -from flatland.utils.rendertools import RenderTool -import numpy as np -import msgpack -import msgpack_numpy as m -import os + import glob +import os +import random import shutil import time import traceback + import crowdai_api +import msgpack +import msgpack_numpy as m +import numpy as np +import redis import timeout_decorator -import random +from flatland.core.env_observation_builder import DummyObservationBuilder +from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators import rail_from_file +from flatland.evaluators import aicrowd_helpers +from flatland.evaluators import messages +from flatland.utils.rendertools import RenderTool use_signals_in_timeout = True if os.name == 'nt': @@ -35,7 +37,7 @@ m.patch() ######################################################## # CONSTANTS ######################################################## -PER_STEP_TIMEOUT = 10*60 # 5 minutes +PER_STEP_TIMEOUT = 10 * 60 # 5 minutes class FlatlandRemoteEvaluationService: @@ -59,17 +61,18 @@ class FlatlandRemoteEvaluationService: unpacked with `msgpack` (a patched version of msgpack which also supports numpy arrays). """ + def __init__(self, - test_env_folder="/tmp", - flatland_rl_service_id='FLATLAND_RL_SERVICE_ID', - remote_host='127.0.0.1', - remote_port=6379, - remote_db=0, - remote_password=None, - visualize=False, - video_generation_envs=[], - report=None, - verbose=False): + test_env_folder="/tmp", + flatland_rl_service_id='FLATLAND_RL_SERVICE_ID', + remote_host='127.0.0.1', + remote_port=6379, + remote_db=0, + remote_password=None, + visualize=False, + video_generation_envs=[], + report=None, + verbose=False): # Test Env folder Paths self.test_env_folder = test_env_folder @@ -83,15 +86,15 @@ class FlatlandRemoteEvaluationService: # Logging and Reporting related vars self.verbose = verbose self.report = report - + # Communication Protocol Related vars self.namespace = "flatland-rl" self.service_id = flatland_rl_service_id self.command_channel = "{}::{}::commands".format( - self.namespace, - self.service_id - ) - + self.namespace, + self.service_id + ) + # Message Broker related vars self.remote_host = remote_host self.remote_port = remote_port @@ -114,7 +117,7 @@ class FlatlandRemoteEvaluationService: "normalized_reward": 0.0 } } - + # RailEnv specific variables self.env = False self.env_renderer = False @@ -156,7 +159,7 @@ class FlatlandRemoteEvaluationService:   ├── .......   ├── ....... └── Level_99.pkl - """ + """ env_paths = sorted(glob.glob( os.path.join( self.test_env_folder, @@ -179,16 +182,16 @@ class FlatlandRemoteEvaluationService: """ if self.verbose or self.report: print("Attempting to connect to redis server at {}:{}/{}".format( - self.remote_host, - self.remote_port, - self.remote_db)) + self.remote_host, + self.remote_port, + self.remote_db)) self.redis_pool = redis.ConnectionPool( - host=self.remote_host, - port=self.remote_port, - db=self.remote_db, - password=self.remote_password - ) + host=self.remote_host, + port=self.remote_port, + db=self.remote_db, + password=self.remote_password + ) def get_redis_connection(self): """ @@ -200,13 +203,13 @@ class FlatlandRemoteEvaluationService: redis_conn.ping() except Exception as e: raise Exception( - "Unable to connect to redis server at {}:{} ." - "Are you sure there is a redis-server running at the " - "specified location ?".format( - self.remote_host, - self.remote_port - ) - ) + "Unable to connect to redis server at {}:{} ." + "Are you sure there is a redis-server running at the " + "specified location ?".format( + self.remote_host, + self.remote_port + ) + ) return redis_conn def _error_template(self, payload): @@ -220,8 +223,8 @@ class FlatlandRemoteEvaluationService: return _response @timeout_decorator.timeout( - PER_STEP_TIMEOUT, - use_signals=use_signals_in_timeout) # timeout for each command + PER_STEP_TIMEOUT, + use_signals=use_signals_in_timeout) # timeout for each command def _get_next_command(self, _redis): """ A low level wrapper for obtaining the next command from a @@ -231,7 +234,7 @@ class FlatlandRemoteEvaluationService: """ command = _redis.brpop(self.command_channel)[1] return command - + def get_next_command(self): """ A helper function to obtain the next command, which transparently @@ -246,18 +249,18 @@ class FlatlandRemoteEvaluationService: print("Command Service: ", command) except timeout_decorator.timeout_decorator.TimeoutError: raise Exception( - "Timeout in step {} of simulation {}".format( - self.current_step, - self.simulation_count - )) + "Timeout in step {} of simulation {}".format( + self.current_step, + self.simulation_count + )) command = msgpack.unpackb( - command, - object_hook=m.decode, - encoding="utf8" - ) + command, + object_hook=m.decode, + encoding="utf8" + ) if self.verbose: print("Received Request : ", command) - + return command def send_response(self, _command_response, command, suppress_logs=False): @@ -266,15 +269,15 @@ class FlatlandRemoteEvaluationService: if self.verbose and not suppress_logs: print("Responding with : ", _command_response) - + _redis.rpush( - command_response_channel, + command_response_channel, msgpack.packb( - _command_response, - default=m.encode, + _command_response, + default=m.encode, use_bin_type=True) ) - + def handle_ping(self, command): """ Handles PING command from the client. @@ -313,9 +316,9 @@ class FlatlandRemoteEvaluationService: ) if self.visualize: if self.env_renderer: - del self.env_renderer + del self.env_renderer self.env_renderer = RenderTool(self.env, gl="PILSVG", ) - + # Set max episode steps allowed self.env._max_episode_steps = \ int(1.5 * (self.env.width + self.env.height)) @@ -323,7 +326,7 @@ class FlatlandRemoteEvaluationService: if self.begin_simulation: # If begin simulation has already been initialized # atleast once - self.simulation_times.append(time.time()-self.begin_simulation) + self.simulation_times.append(time.time() - self.begin_simulation) self.begin_simulation = time.time() self.simulation_rewards.append(0) @@ -348,15 +351,15 @@ class FlatlandRemoteEvaluationService: _command_response['type'] = messages.FLATLAND_RL.ENV_CREATE_RESPONSE _command_response['payload'] = {} _command_response['payload']['observation'] = False - _command_response['payload']['env_file_path'] = False + _command_response['payload']['env_file_path'] = False self.send_response(_command_response, command) ##################################################################### # Update evaluation state ##################################################################### progress = np.clip( - self.simulation_count * 1.0 / len(self.env_file_paths), - 0, 1) + self.simulation_count * 1.0 / len(self.env_file_paths), + 0, 1) mean_reward = round(np.mean(self.simulation_rewards), 2) mean_normalized_reward = round(np.mean(self.simulation_rewards_normalized), 2) mean_percentage_complete = round(np.mean(self.simulation_percentage_complete), 3) @@ -399,9 +402,9 @@ class FlatlandRemoteEvaluationService: """ self.simulation_rewards_normalized[-1] += \ cumulative_reward / ( - self.env._max_episode_steps + - self.env.get_num_agents() - ) + self.env._max_episode_steps + + self.env.get_num_agents() + ) if done["__all__"]: # Compute percentage complete @@ -412,14 +415,14 @@ class FlatlandRemoteEvaluationService: complete += 1 percentage_complete = complete * 1.0 / self.env.get_num_agents() self.simulation_percentage_complete[-1] = percentage_complete - + # Record Frame if self.visualize: self.env_renderer.render_env( - show=False, - show_observations=False, - show_predictions=False - ) + show=False, + show_observations=False, + show_predictions=False + ) """ Only save the frames for environments which are separately provided in video_generation_indices param @@ -427,10 +430,10 @@ class FlatlandRemoteEvaluationService: current_env_path = self.env_file_paths[self.simulation_count] if current_env_path in self.video_generation_envs: self.env_renderer.gl.save_image( - os.path.join( - self.vizualization_folder_name, - "flatland_frame_{:04d}.png".format(self.record_frame_step) - )) + os.path.join( + self.vizualization_folder_name, + "flatland_frame_{:04d}.png".format(self.record_frame_step) + )) self.record_frame_step += 1 # Build and send response @@ -453,7 +456,7 @@ class FlatlandRemoteEvaluationService: _payload = command['payload'] # Register simulation time of the last episode - self.simulation_times.append(time.time()-self.begin_simulation) + self.simulation_times.append(time.time() - self.begin_simulation) if len(self.simulation_rewards) != len(self.env_file_paths): raise Exception( @@ -461,7 +464,7 @@ class FlatlandRemoteEvaluationService: to operate on all the test environments. """ ) - + mean_reward = round(np.mean(self.simulation_rewards), 2) mean_normalized_reward = round(np.mean(self.simulation_rewards_normalized), 2) mean_percentage_complete = round(np.mean(self.simulation_percentage_complete), 3) @@ -473,7 +476,7 @@ class FlatlandRemoteEvaluationService: # install it by : # # conda install -c conda-forge x264 ffmpeg - + print("Generating Video from thumbnails...") video_output_path, video_thumb_output_path = \ aicrowd_helpers.generate_movie_from_frames( @@ -518,14 +521,14 @@ class FlatlandRemoteEvaluationService: self.evaluation_state["score"]["score_secondary"] = mean_reward self.evaluation_state["meta"]["normalized_reward"] = mean_normalized_reward self.handle_aicrowd_success_event(self.evaluation_state) - print("#"*100) + print("#" * 100) print("EVALUATION COMPLETE !!") - print("#"*100) + print("#" * 100) print("# Mean Reward : {}".format(mean_reward)) print("# Mean Normalized Reward : {}".format(mean_normalized_reward)) print("# Mean Percentage Complete : {}".format(mean_percentage_complete)) - print("#"*100) - print("#"*100) + print("#" * 100) + print("#" * 100) def report_error(self, error_message, command_response_channel): """ @@ -536,16 +539,16 @@ class FlatlandRemoteEvaluationService: _command_response['type'] = messages.FLATLAND_RL.ERROR _command_response['payload'] = error_message _redis.rpush( - command_response_channel, + command_response_channel, msgpack.packb( - _command_response, - default=m.encode, + _command_response, + default=m.encode, use_bin_type=True) - ) + ) self.evaluation_state["state"] = "ERROR" self.evaluation_state["error"] = error_message self.handle_aicrowd_error_event(self.evaluation_state) - + def handle_aicrowd_info_event(self, payload): self.oracle_events.register_event( event_type=self.oracle_events.CROWDAI_EVENT_INFO, @@ -577,17 +580,17 @@ class FlatlandRemoteEvaluationService: print("Self.Reward : ", self.reward) print("Current Simulation : ", self.simulation_count) if self.env_file_paths and \ - self.simulation_count < len(self.env_file_paths): + self.simulation_count < len(self.env_file_paths): print("Current Env Path : ", - self.env_file_paths[self.simulation_count]) + self.env_file_paths[self.simulation_count]) - try: + try: if command['type'] == messages.FLATLAND_RL.PING: """ INITIAL HANDSHAKE : Respond with PONG """ self.handle_ping(command) - + elif command['type'] == messages.FLATLAND_RL.ENV_CREATE: """ ENV_CREATE @@ -612,8 +615,8 @@ class FlatlandRemoteEvaluationService: self.handle_env_submit(command) else: _error = self._error_template( - "UNKNOWN_REQUEST:{}".format( - str(command))) + "UNKNOWN_REQUEST:{}".format( + str(command))) if self.verbose: print("Responding with : ", _error) self.report_error( @@ -631,10 +634,11 @@ class FlatlandRemoteEvaluationService: if __name__ == "__main__": import argparse + parser = argparse.ArgumentParser(description='Submit the result to AIcrowd') - parser.add_argument('--service_id', - dest='service_id', - default='FLATLAND_RL_SERVICE_ID', + parser.add_argument('--service_id', + dest='service_id', + default='FLATLAND_RL_SERVICE_ID', required=False) parser.add_argument('--test_folder', dest='test_folder', @@ -642,16 +646,16 @@ if __name__ == "__main__": help="Folder containing the files for the test envs", required=False) args = parser.parse_args() - + test_folder = args.test_folder grader = FlatlandRemoteEvaluationService( - test_env_folder=test_folder, - flatland_rl_service_id=args.service_id, - verbose=True, - visualize=True, - video_generation_envs=["Test_0/Level_1.pkl"] - ) + test_env_folder=test_folder, + flatland_rl_service_id=args.service_id, + verbose=True, + visualize=True, + video_generation_envs=["Test_0/Level_1.pkl"] + ) result = grader.run() if result['type'] == messages.FLATLAND_RL.ENV_SUBMIT_RESPONSE: cumulative_results = result['payload'] diff --git a/flatland/utils/editor.py b/flatland/utils/editor.py index 69be59ae2a957f6a2aaa948d9830472d7824516a..af1aad222919b00b716dd9da0f3be9534d54e411 100644 --- a/flatland/utils/editor.py +++ b/flatland/utils/editor.py @@ -11,9 +11,9 @@ from numpy import array import flatland.utils.rendertools as rt from flatland.core.grid.grid4_utils import mirror from flatland.envs.agent_utils import EnvAgent, EnvAgentStatic -from flatland.envs.generators import complex_rail_generator, empty_rail_generator from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.rail_env import RailEnv, random_rail_generator +from flatland.envs.rail_generators import complex_rail_generator, empty_rail_generator class EditorMVC(object): diff --git a/flatland/utils/graphics_pil.py b/flatland/utils/graphics_pil.py index 6a0a9282614c0319338454f5b8ae97531b12e432..92a0f84f35fa942b03236c6add6e722475a2d842 100644 --- a/flatland/utils/graphics_pil.py +++ b/flatland/utils/graphics_pil.py @@ -172,7 +172,7 @@ class PILGL(GraphicsLayer): def text(self, xPx, yPx, strText, layer=RAIL_LAYER): xyPixLeftTop = (xPx, yPx) self.draws[layer].text(xyPixLeftTop, strText, font=self.font, fill=(0, 0, 0, 255)) - + def text_rowcol(self, rcTopLeft, strText, layer=AGENT_LAYER): print("Text:", "rc:", rcTopLeft, "text:", strText, "layer:", layer) xyPixLeftTop = tuple((array(rcTopLeft) * self.nPixCell)[[1, 0]]) @@ -500,9 +500,9 @@ class PILSVG(PILGL): False)[0] self.draw_image_row_col(colored_rail, (row, col), layer=PILGL.PREDICTION_PATH_LAYER) - def set_rail_at(self, row, col, binary_trans, target=None, is_selected=False, rail_grid=None, - show_debug=True): - + def set_rail_at(self, row, col, binary_trans, target=None, is_selected=False, rail_grid=None, + show_debug=True): + if binary_trans in self.pil_rail: pil_track = self.pil_rail[binary_trans] if target is not None: @@ -510,7 +510,7 @@ class PILSVG(PILGL): target_img = Image.alpha_composite(pil_track, target_img) self.draw_image_row_col(target_img, (row, col), layer=PILGL.TARGET_LAYER) if show_debug: - self.text_rowcol((row+0.8, col+0.0), strText=str(target), layer=PILGL.TARGET_LAYER) + self.text_rowcol((row + 0.8, col + 0.0), strText=str(target), layer=PILGL.TARGET_LAYER) if binary_trans == 0: if self.background_grid[col][row] <= 4: @@ -607,7 +607,7 @@ class PILSVG(PILGL): if show_debug: print("Call text:") - self.text_rowcol((row+0.2, col+0.2,), str(agent_idx)) + self.text_rowcol((row + 0.2, col + 0.2,), str(agent_idx)) def set_cell_occupied(self, agent_idx, row, col): occupied_im = self.cell_occupied[agent_idx % len(self.cell_occupied)] diff --git a/notebooks/simple_example1_env_from_tuple.ipynb b/notebooks/simple_example1_env_from_tuple.ipynb index 3fd55bc8fafabdd57eb43e012c5d98b37c73d496..0fcfe26325e5778cbbfdf66591346972edb2406d 100644 --- a/notebooks/simple_example1_env_from_tuple.ipynb +++ b/notebooks/simple_example1_env_from_tuple.ipynb @@ -14,7 +14,7 @@ "metadata": {}, "outputs": [], "source": [ - "from flatland.envs.generators import rail_from_manual_specifications_generator\n", + "from flatland.envs.rail_generators import rail_from_manual_specifications_generator\n", "from flatland.envs.observations import TreeObsForRailEnv\n", "from flatland.envs.rail_env import RailEnv\n", "from flatland.utils.rendertools import RenderTool\n", diff --git a/notebooks/simple_example2_generate_random_rail.ipynb b/notebooks/simple_example2_generate_random_rail.ipynb index b9d4a96c02d4cb63392654025c6857c8b6764d1a..19b854ee15d8dd1e19361f58a552eba617b19b67 100644 --- a/notebooks/simple_example2_generate_random_rail.ipynb +++ b/notebooks/simple_example2_generate_random_rail.ipynb @@ -15,7 +15,7 @@ "source": [ "import random\n", "import numpy as np\n", - "from flatland.envs.generators import random_rail_generator\n", + "from flatland.envs.rail_generators import random_rail_generator\n", "from flatland.envs.observations import TreeObsForRailEnv\n", "from flatland.envs.rail_env import RailEnv\n", "from flatland.utils.rendertools import RenderTool\n", diff --git a/notebooks/simple_example_3_manual_control.ipynb b/notebooks/simple_example_3_manual_control.ipynb index 50f228055b320c15e0411c2b086254a1f4d4ceef..cb2b377765f375e8d77d8374fdcc3bb67ce06444 100644 --- a/notebooks/simple_example_3_manual_control.ipynb +++ b/notebooks/simple_example_3_manual_control.ipynb @@ -40,7 +40,7 @@ "import random\n", "import numpy as np\n", "import time\n", - "from flatland.envs.generators import random_rail_generator\n", + "from flatland.envs.rail_generators import random_rail_generator\n", "from flatland.envs.observations import TreeObsForRailEnv\n", "from flatland.envs.rail_env import RailEnv\n", "from flatland.utils.rendertools import RenderTool" diff --git a/tests/test_distance_map.py b/tests/test_distance_map.py index 62df397d34b06755465ca0c9f664b9117c87243f..e5e89f76428bb881d0f72aa60aada97ab02167a5 100644 --- a/tests/test_distance_map.py +++ b/tests/test_distance_map.py @@ -2,10 +2,11 @@ import numpy as np from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.core.transition_map import GridTransitionMap -from flatland.envs.generators import rail_from_grid_transition_map from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators import rail_from_grid_transition_map +from flatland.envs.schedule_generators import random_schedule_generator def test_walker(): @@ -27,6 +28,7 @@ def test_walker(): env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv(max_depth=10)), diff --git a/tests/test_flatland_envs_observations.py b/tests/test_flatland_envs_observations.py index 574705c49501415f37149bd4b3d870665bf06e60..c96e8db00fe721f42667aed4833d034a47f19156 100644 --- a/tests/test_flatland_envs_observations.py +++ b/tests/test_flatland_envs_observations.py @@ -5,10 +5,11 @@ import numpy as np from flatland.core.grid.grid4 import Grid4TransitionsEnum from flatland.envs.agent_utils import EnvAgent -from flatland.envs.generators import rail_from_grid_transition_map from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv, RailEnvActions +from flatland.envs.rail_generators import rail_from_grid_transition_map +from flatland.envs.schedule_generators import random_schedule_generator from flatland.utils.rendertools import RenderTool from flatland.utils.simple_rail import make_simple_rail @@ -21,6 +22,7 @@ def test_global_obs(): env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), number_of_agents=1, obs_builder_object=GlobalObsForRailEnv()) @@ -90,6 +92,7 @@ def test_reward_function_conflict(rendering=False): env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), number_of_agents=2, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), ) @@ -168,6 +171,7 @@ def test_reward_function_waiting(rendering=False): env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), number_of_agents=2, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), ) diff --git a/tests/test_flatland_envs_predictions.py b/tests/test_flatland_envs_predictions.py index 98c276f894b51685ce0edf43f6bd1b1137d46eb0..09f7e5e67a15c55b5070ac8679e43ecc9a14b9da 100644 --- a/tests/test_flatland_envs_predictions.py +++ b/tests/test_flatland_envs_predictions.py @@ -5,10 +5,11 @@ import pprint import numpy as np from flatland.core.grid.grid4 import Grid4TransitionsEnum -from flatland.envs.generators import rail_from_grid_transition_map from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators import rail_from_grid_transition_map +from flatland.envs.schedule_generators import random_schedule_generator from flatland.utils.rendertools import RenderTool from flatland.utils.simple_rail import make_simple_rail, make_simple_rail2, make_invalid_simple_rail @@ -21,6 +22,7 @@ def test_dummy_predictor(rendering=False): env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv(max_depth=10)), ) @@ -111,6 +113,7 @@ def test_shortest_path_predictor(rendering=False): env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), ) @@ -230,6 +233,7 @@ def test_shortest_path_predictor_conflicts(rendering=False): env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), number_of_agents=2, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), ) diff --git a/tests/test_flatland_envs_rail_env.py b/tests/test_flatland_envs_rail_env.py index 7ebbbb1461e24aae4c2319f51a9bb4abb2d3b25c..d5dc3ac7af4be6ebd8c5cbeaf705bb710d36d138 100644 --- a/tests/test_flatland_envs_rail_env.py +++ b/tests/test_flatland_envs_rail_env.py @@ -6,10 +6,11 @@ from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.core.transition_map import GridTransitionMap from flatland.envs.agent_utils import EnvAgent from flatland.envs.agent_utils import EnvAgentStatic -from flatland.envs.generators import complex_rail_generator -from flatland.envs.generators import rail_from_grid_transition_map from flatland.envs.observations import GlobalObsForRailEnv from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators import complex_rail_generator +from flatland.envs.rail_generators import rail_from_grid_transition_map +from flatland.envs.schedule_generators import random_schedule_generator, complex_schedule_generator """Tests for `flatland` package.""" @@ -26,6 +27,7 @@ def test_load_env(): def test_save_load(): env = RailEnv(width=10, height=10, rail_generator=complex_rail_generator(nr_start_goal=2, nr_extra=5, min_dist=6, seed=0), + schedule_generator=complex_schedule_generator(), number_of_agents=2) env.reset() agent_1_pos = env.agents_static[0].position @@ -77,6 +79,7 @@ def test_rail_environment_single_agent(): rail_env = RailEnv(width=3, height=3, rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), number_of_agents=1, obs_builder_object=GlobalObsForRailEnv()) @@ -156,6 +159,7 @@ def test_dead_end(): rail_env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), number_of_agents=1, obs_builder_object=GlobalObsForRailEnv()) @@ -200,6 +204,7 @@ def test_dead_end(): rail_env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), number_of_agents=1, obs_builder_object=GlobalObsForRailEnv()) diff --git a/tests/test_flatland_env_sparse_rail_generator.py b/tests/test_flatland_envs_sparse_rail_generator.py similarity index 87% rename from tests/test_flatland_env_sparse_rail_generator.py rename to tests/test_flatland_envs_sparse_rail_generator.py index d59e684575e9410b2859bb011ecb835a267b1c36..db7cac61f4cf3bec4a330694c1864ef7d82bd076 100644 --- a/tests/test_flatland_env_sparse_rail_generator.py +++ b/tests/test_flatland_envs_sparse_rail_generator.py @@ -1,6 +1,7 @@ -from flatland.envs.generators import sparse_rail_generator from flatland.envs.observations import GlobalObsForRailEnv from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators import sparse_rail_generator +from flatland.envs.schedule_generators import sparse_schedule_generator from flatland.utils.rendertools import RenderTool @@ -16,6 +17,7 @@ def test_sparse_rail_generator(): seed=5, # Random seed realistic_mode=False # Ordered distribution of nodes ), + schedule_generator=sparse_schedule_generator(), number_of_agents=10, obs_builder_object=GlobalObsForRailEnv()) # reset to initialize agents_static diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index 67dcd25c0769e542fd9a03502c2a8c1b29333b2b..eaf782df3255ecfc6ebaa7078935f485497ed359 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -1,8 +1,9 @@ import numpy as np -from flatland.envs.generators import complex_rail_generator from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators import complex_rail_generator +from flatland.envs.schedule_generators import complex_schedule_generator class SingleAgentNavigationObs(TreeObsForRailEnv): @@ -62,6 +63,7 @@ def test_malfunction_process(): height=20, rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999, seed=0), + schedule_generator=complex_schedule_generator(), number_of_agents=2, obs_builder_object=SingleAgentNavigationObs(), stochastic_data=stochastic_data) diff --git a/tests/test_flatland_utils_rendertools.py b/tests/test_flatland_utils_rendertools.py index ac5d7f4132b5c224206af3602b6c1341fe026d8b..8248c675995fc5c906e82d8650a5b619e7b038f2 100644 --- a/tests/test_flatland_utils_rendertools.py +++ b/tests/test_flatland_utils_rendertools.py @@ -11,9 +11,9 @@ from importlib_resources import path import flatland.utils.rendertools as rt import images.test -from flatland.envs.generators import empty_rail_generator from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators import empty_rail_generator def checkFrozenImage(oRT, sFileImage, resave=False): diff --git a/tests/test_multi_speed.py b/tests/test_multi_speed.py index 47aadee73cfc0b45dc701fe914a586feb31b2597..8de36c81e4a13c0b7e7e5e556ad79234503ad31a 100644 --- a/tests/test_multi_speed.py +++ b/tests/test_multi_speed.py @@ -1,10 +1,12 @@ import numpy as np -from flatland.envs.generators import complex_rail_generator from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators import complex_rail_generator +from flatland.envs.schedule_generators import complex_schedule_generator np.random.seed(1) + # Use the complex_rail_generator to generate feasible network configurations with corresponding tasks # Training on simple small tasks is the best way to get familiar with the environment # @@ -46,6 +48,7 @@ def test_multi_speed_init(): height=50, rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0), + schedule_generator=complex_schedule_generator(), number_of_agents=5) # Initialize the agent with the parameters corresponding to the environment and observation_builder agent = RandomAgent(218, 4) diff --git a/tests/test_speed_classes.py b/tests/test_speed_classes.py new file mode 100644 index 0000000000000000000000000000000000000000..ff5ee56a308ce19559d079b716bde90ad65baf11 --- /dev/null +++ b/tests/test_speed_classes.py @@ -0,0 +1,36 @@ +"""Test speed initialization by a map of speeds and their corresponding ratios.""" +import numpy as np + +from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators import complex_rail_generator +from flatland.envs.schedule_generators import speed_initialization_helper, complex_schedule_generator + + +def test_speed_initialization_helper(): + np.random.seed(1) + speed_ratio_map = {1: 0.3, 2: 0.4, 3: 0.3} + actual_speeds = speed_initialization_helper(10, speed_ratio_map) + + # seed makes speed_initialization_helper deterministic -> check generated speeds. + assert actual_speeds == [2, 3, 1, 2, 1, 1, 1, 2, 2, 2] + + +def test_rail_env_speed_intializer(): + speed_ratio_map = {1: 0.3, 2: 0.4, 3: 0.1, 5: 0.2} + + env = RailEnv(width=50, + height=50, + rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, + seed=0), + schedule_generator=complex_schedule_generator(), + number_of_agents=10) + env.reset() + actual_speeds = list(map(lambda agent: agent.speed_data['speed'], env.agents)) + + expected_speed_set = set(speed_ratio_map.keys()) + + # check that the number of speeds generated is correct + assert len(actual_speeds) == env.get_num_agents() + + # check that only the speeds defined are generated + assert all({(actual_speed in expected_speed_set) for actual_speed in actual_speeds}) diff --git a/tests/tests_generators.py b/tests/tests_generators.py index f97b071e6b33c099efa5af36766e159e57716443..610022cafe12fccb2cbbd5da57006e61c89faf28 100644 --- a/tests/tests_generators.py +++ b/tests/tests_generators.py @@ -3,11 +3,13 @@ import numpy as np -from flatland.envs.generators import rail_from_grid_transition_map, rail_from_file, complex_rail_generator, \ - random_rail_generator, empty_rail_generator from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators import rail_from_grid_transition_map, rail_from_file, complex_rail_generator, \ + random_rail_generator, empty_rail_generator +from flatland.envs.schedule_generators import random_schedule_generator, complex_schedule_generator, \ + agents_from_file from flatland.utils.simple_rail import make_simple_rail @@ -58,7 +60,8 @@ def test_complex_rail_generator(): env = RailEnv(width=x_dim, height=y_dim, number_of_agents=n_agents, - rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist) + rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist), + schedule_generator=complex_schedule_generator() ) assert env.get_num_agents() == 2 assert env.rail.grid.shape == (y_dim, x_dim) @@ -69,7 +72,8 @@ def test_complex_rail_generator(): env = RailEnv(width=x_dim, height=y_dim, number_of_agents=n_agents, - rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist) + rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist), + schedule_generator=complex_schedule_generator() ) assert env.get_num_agents() == 0 assert env.rail.grid.shape == (y_dim, x_dim) @@ -82,7 +86,8 @@ def test_complex_rail_generator(): env = RailEnv(width=x_dim, height=y_dim, number_of_agents=n_agents, - rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist) + rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist), + schedule_generator=complex_schedule_generator() ) assert env.get_num_agents() == n_agents assert env.rail.grid.shape == (y_dim, x_dim) @@ -94,6 +99,7 @@ def test_rail_from_grid_transition_map(): env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), number_of_agents=n_agents ) nr_rail_elements = np.count_nonzero(env.rail.grid) @@ -118,6 +124,7 @@ def tests_rail_from_file(): env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), number_of_agents=3, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), ) @@ -130,6 +137,7 @@ def tests_rail_from_file(): env = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name), + schedule_generator=agents_from_file(file_name), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), ) @@ -151,6 +159,7 @@ def tests_rail_from_file(): env2 = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), number_of_agents=3, obs_builder_object=GlobalObsForRailEnv(), ) @@ -164,6 +173,7 @@ def tests_rail_from_file(): env2 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name_2), + schedule_generator=agents_from_file(file_name_2), number_of_agents=1, obs_builder_object=GlobalObsForRailEnv(), ) @@ -180,6 +190,7 @@ def tests_rail_from_file(): env3 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name), + schedule_generator=agents_from_file(file_name), number_of_agents=1, obs_builder_object=GlobalObsForRailEnv(), ) @@ -197,6 +208,7 @@ def tests_rail_from_file(): env4 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name_2), + schedule_generator=agents_from_file(file_name_2), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2), )