diff --git a/docs/flatland.rst b/docs/flatland.rst index 88e8ec93fd4f6c89c1f6e20c55defbddaa9b28fa..e09087a49b6df3572ac38b77e41ca739bcea8150 100644 --- a/docs/flatland.rst +++ b/docs/flatland.rst @@ -6,10 +6,10 @@ Subpackages .. toctree:: - flatland.core - flatland.envs - flatland.evaluators - flatland.utils + flatland.core + flatland.envs + flatland.evaluators + flatland.utils Submodules ---------- @@ -18,15 +18,15 @@ flatland.cli module ------------------- .. automodule:: flatland.cli - :members: - :undoc-members: - :show-inheritance: + :members: + :undoc-members: + :show-inheritance: Module contents --------------- .. automodule:: flatland - :members: - :undoc-members: - :show-inheritance: + :members: + :undoc-members: + :show-inheritance: diff --git a/examples/complex_rail_benchmark.py b/examples/complex_rail_benchmark.py index 44e4b534c2f2dfa63cab385d009c9afd92285f48..49e550b195ffb15e6554413069369378e80e5f82 100644 --- a/examples/complex_rail_benchmark.py +++ b/examples/complex_rail_benchmark.py @@ -3,8 +3,9 @@ import random import numpy as np -from flatland.envs.generators import complex_rail_generator from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators import complex_rail_generator +from flatland.envs.schedule_generators import complex_schedule_generator def run_benchmark(): @@ -15,6 +16,7 @@ def run_benchmark(): # Example generate a random rail env = RailEnv(width=15, height=15, rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=20, min_dist=12), + schedule_generator=complex_schedule_generator(), number_of_agents=5) n_trials = 20 diff --git a/examples/custom_observation_example.py b/examples/custom_observation_example.py index 723bb1102092c7d48bd938bbd60d0c5213ffecf6..8b1de6aa4e303469d30983d30333fbfda89c1d1e 100644 --- a/examples/custom_observation_example.py +++ b/examples/custom_observation_example.py @@ -5,10 +5,11 @@ import numpy as np from flatland.core.env_observation_builder import ObservationBuilder from flatland.core.grid.grid_utils import coordinate_to_position -from flatland.envs.generators import random_rail_generator, complex_rail_generator from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators import random_rail_generator, complex_rail_generator +from flatland.envs.schedule_generators import complex_schedule_generator from flatland.utils.rendertools import RenderTool random.seed(100) @@ -20,6 +21,7 @@ class SimpleObs(ObservationBuilder): Simplest observation builder. The object returns observation vectors with 5 identical components, all equal to the ID of the respective agent. """ + def __init__(self): self.observation_space = [5] @@ -53,6 +55,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv): E.g., if taking the Left branch (if available) is the shortest route to the agent's target, the observation vector will be [1, 0, 0]. """ + def __init__(self): super().__init__(max_depth=0) self.observation_space = [3] @@ -90,6 +93,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv): env = RailEnv(width=7, height=7, rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999, seed=0), + schedule_generator=complex_schedule_generator(), number_of_agents=1, obs_builder_object=SingleAgentNavigationObs()) @@ -97,8 +101,8 @@ obs = env.reset() env_renderer = RenderTool(env, gl="PILSVG") env_renderer.render_env(show=True, frames=True, show_observations=True) for step in range(100): - action = np.argmax(obs[0])+1 - obs, all_rewards, done, _ = env.step({0:action}) + action = np.argmax(obs[0]) + 1 + obs, all_rewards, done, _ = env.step({0: action}) print("Rewards: ", all_rewards, " [done=", done, "]") env_renderer.render_env(show=True, frames=True, show_observations=True) time.sleep(0.1) @@ -200,6 +204,7 @@ CustomObsBuilder = ObservePredictions(CustomPredictor) env = RailEnv(width=10, height=10, rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=1, min_dist=8, max_dist=99999, seed=0), + schedule_generator=complex_schedule_generator(), number_of_agents=3, obs_builder_object=CustomObsBuilder) diff --git a/examples/custom_railmap_example.py b/examples/custom_railmap_example.py index 515d6c1b0469b7fbd9bad8cd82a40db7766f6219..6162b918734eb311752675e75e203d90e5558c1c 100644 --- a/examples/custom_railmap_example.py +++ b/examples/custom_railmap_example.py @@ -1,30 +1,41 @@ import random +from typing import Any import numpy as np from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.core.transition_map import GridTransitionMap from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators import RailGenerator, RailGeneratorProduct +from flatland.envs.schedule_generators import ScheduleGenerator, ScheduleGeneratorProduct from flatland.utils.rendertools import RenderTool random.seed(100) np.random.seed(100) -def custom_rail_generator(): - def generator(width, height, num_agents=0, num_resets=0): +def custom_rail_generator() -> RailGenerator: + def generator(width: int, height: int, num_agents: int = 0, num_resets: int = 0) -> RailGeneratorProduct: rail_trans = RailEnvTransitions() grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans) rail_array = grid_map.grid rail_array.fill(0) new_tran = rail_trans.set_transition(1, 1, 1, 1) print(new_tran) + rail_array[0, 0] = new_tran + rail_array[0, 1] = new_tran + return grid_map, None + + return generator + + +def custom_schedule_generator() -> ScheduleGenerator: + def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None) -> ScheduleGeneratorProduct: agents_positions = [] agents_direction = [] agents_target = [] - rail_array[0, 0] = new_tran - rail_array[0, 1] = new_tran - return grid_map, agents_positions, agents_direction, agents_target + speeds = [] + return agents_positions, agents_direction, agents_target, speeds return generator @@ -32,6 +43,7 @@ def custom_rail_generator(): env = RailEnv(width=6, height=4, rail_generator=custom_rail_generator(), + schedule_generator=custom_schedule_generator(), number_of_agents=1) env.reset() diff --git a/examples/debugging_example_DELETE.py b/examples/debugging_example_DELETE.py index 2c0f814576caef84471d20c91dd92d23d4db02ac..50ea74b84ac9851e88e48bcd32b914e69bc7dd34 100644 --- a/examples/debugging_example_DELETE.py +++ b/examples/debugging_example_DELETE.py @@ -3,14 +3,16 @@ import time import numpy as np -from flatland.envs.generators import complex_rail_generator from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators import complex_rail_generator +from flatland.envs.schedule_generators import complex_schedule_generator from flatland.utils.rendertools import RenderTool random.seed(1) np.random.seed(1) + class SingleAgentNavigationObs(TreeObsForRailEnv): """ We derive our bbservation builder from TreeObsForRailEnv, to exploit the existing implementation to compute @@ -21,6 +23,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv): E.g., if taking the Left branch (if available) is the shortest route to the agent's target, the observation vector will be [1, 0, 0]. """ + def __init__(self): super().__init__(max_depth=0) self.observation_space = [3] @@ -58,6 +61,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv): env = RailEnv(width=14, height=14, rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999, seed=0), + schedule_generator=complex_schedule_generator(), number_of_agents=2, obs_builder_object=SingleAgentNavigationObs()) @@ -67,11 +71,11 @@ env_renderer.render_env(show=True, frames=True, show_observations=False) for step in range(100): actions = {} for i in range(len(obs)): - actions[i] = np.argmax(obs[i])+1 + actions[i] = np.argmax(obs[i]) + 1 - if step%5 == 0: + if step % 5 == 0: print("Agent halts") - actions[0] = 4 # Halt + actions[0] = 4 # Halt obs, all_rewards, done, _ = env.step(actions) if env.agents[0].malfunction_data['malfunction'] > 0: @@ -82,4 +86,3 @@ for step in range(100): if done["__all__"]: break env_renderer.close_window() - diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py new file mode 100644 index 0000000000000000000000000000000000000000..71a185c765bcab831e7b104124a164bcf2398b14 --- /dev/null +++ b/examples/flatland_2_0_example.py @@ -0,0 +1,120 @@ +import numpy as np +from flatland.envs.rail_generators import sparse_rail_generator + +from flatland.envs.observations import TreeObsForRailEnv +from flatland.envs.predictions import ShortestPathPredictorForRailEnv +from flatland.envs.rail_env import RailEnv +from flatland.envs.schedule_generators import sparse_schedule_generator +from flatland.utils.rendertools import RenderTool + +np.random.seed(1) + +# Use the new sparse_rail_generator to generate feasible network configurations with corresponding tasks +# Training on simple small tasks is the best way to get familiar with the environment + +# Use a the malfunction generator to break agents from time to time +stochastic_data = {'prop_malfunction': 0.5, # Percentage of defective agents + 'malfunction_rate': 30, # Rate of malfunction occurence + 'min_duration': 3, # Minimal duration of malfunction + 'max_duration': 10 # Max duration of malfunction + } + +TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()) +env = RailEnv(width=20, + height=20, + rail_generator=sparse_rail_generator(num_cities=2, # Number of cities in map (where train stations are) + num_intersections=1, # Number of intersections (no start / target) + num_trainstations=15, # Number of possible start/targets on map + min_node_dist=3, # Minimal distance of nodes + node_radius=3, # Proximity of stations to city center + num_neighb=2, # Number of connections to other cities/intersections + seed=15, # Random seed + realistic_mode=True, + enhance_intersection=True + ), + schedule_generator=sparse_schedule_generator(), + number_of_agents=5, + stochastic_data=stochastic_data, # Malfunction data generator + obs_builder_object=TreeObservation) + +env_renderer = RenderTool(env, gl="PILSVG", ) + + +# Import your own Agent or use RLlib to train agents on Flatland +# As an example we use a random agent instead +class RandomAgent: + + def __init__(self, state_size, action_size): + self.state_size = state_size + self.action_size = action_size + + def act(self, state): + """ + :param state: input is the observation of the agent + :return: returns an action + """ + return np.random.choice(np.arange(self.action_size)) + + def step(self, memories): + """ + Step function to improve agent by adjusting policy given the observations + + :param memories: SARS Tuple to be + :return: + """ + return + + def save(self, filename): + # Store the current policy + return + + def load(self, filename): + # Load a policy + return + + +# Initialize the agent with the parameters corresponding to the environment and observation_builder +# Set action space to 4 to remove stop action +agent = RandomAgent(218, 4) + +# Empty dictionary for all agent action +action_dict = dict() + +print("Start episode...") +# Reset environment and get initial observations for all agents +obs = env.reset() +# Update/Set agent's speed +for idx in range(env.get_num_agents()): + speed = 1.0 / ((idx % 5) + 1.0) + env.agents[idx].speed_data["speed"] = speed + +# Reset the rendering sytem +env_renderer.reset() + +# Here you can also further enhance the provided observation by means of normalization +# See training navigation example in the baseline repository + +score = 0 +# Run episode +frame_step = 0 +for step in range(500): + # Chose an action for each agent in the environment + for a in range(env.get_num_agents()): + action = agent.act(obs[a]) + action_dict.update({a: action}) + + # Environment step which returns the observations for all agents, their corresponding + # reward and whether their are done + next_obs, all_rewards, done, _ = env.step(action_dict) + env_renderer.render_env(show=True, show_observations=False, show_predictions=False) + frame_step += 1 + # Update replay buffer and train agent + for a in range(env.get_num_agents()): + agent.step((obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a])) + score += all_rewards[a] + + obs = next_obs.copy() + if done['__all__']: + break + +print('Episode: Steps {}\t Score = {}'.format(step, score)) diff --git a/examples/simple_example_1.py b/examples/simple_example_1.py index 7956c34fd4a5b94859a4b64441450afe2114133c..fbadbd657c36fa1dadf0bca65cff3e9cccd269ea 100644 --- a/examples/simple_example_1.py +++ b/examples/simple_example_1.py @@ -1,5 +1,5 @@ -from flatland.envs.generators import rail_from_manual_specifications_generator from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators import rail_from_manual_specifications_generator from flatland.utils.rendertools import RenderTool # Example generate a rail given a manual specification, diff --git a/examples/simple_example_2.py b/examples/simple_example_2.py index 994c7deda1569b77d4adac8a17fa9ebe14b27ef6..6db9ba5abbd0999ef3896e733516ed6b3e498bae 100644 --- a/examples/simple_example_2.py +++ b/examples/simple_example_2.py @@ -2,8 +2,8 @@ import random import numpy as np -from flatland.envs.generators import random_rail_generator from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators import random_rail_generator from flatland.utils.rendertools import RenderTool random.seed(100) diff --git a/examples/simple_example_3.py b/examples/simple_example_3.py index 5aa03d8f95a7079b708baea1e2ddce27e9a46554..6df6d4af3076b3d9659aadbd55296b667dc7d6db 100644 --- a/examples/simple_example_3.py +++ b/examples/simple_example_3.py @@ -2,9 +2,10 @@ import random import numpy as np -from flatland.envs.generators import complex_rail_generator from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators import complex_rail_generator +from flatland.envs.schedule_generators import complex_schedule_generator from flatland.utils.rendertools import RenderTool random.seed(1) @@ -13,6 +14,7 @@ np.random.seed(1) env = RailEnv(width=7, height=7, rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0), + schedule_generator=complex_schedule_generator(), number_of_agents=2, obs_builder_object=TreeObsForRailEnv(max_depth=2)) diff --git a/examples/training_example.py b/examples/training_example.py index d125be1587a56025ba1cd3f78b28ba3976f01fbf..df93479f5a5ee05abfcb1a98b07ef052bffc2bd4 100644 --- a/examples/training_example.py +++ b/examples/training_example.py @@ -1,9 +1,10 @@ import numpy as np -from flatland.envs.generators import complex_rail_generator from flatland.envs.observations import TreeObsForRailEnv, LocalObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators import complex_rail_generator +from flatland.envs.schedule_generators import complex_schedule_generator from flatland.utils.rendertools import RenderTool np.random.seed(1) @@ -16,11 +17,13 @@ LocalGridObs = LocalObsForRailEnv(view_height=10, view_width=2, center=2) env = RailEnv(width=20, height=20, rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0), + schedule_generator=complex_schedule_generator(), obs_builder_object=TreeObservation, number_of_agents=3) env_renderer = RenderTool(env, gl="PILSVG", ) + # Import your own Agent or use RLlib to train agents on Flatland # As an example we use a random agent here diff --git a/flatland/cli.py b/flatland/cli.py index 32e8d9dc786b0412795694fc985c90aa55fc2e91..47c450dba803fac17bc13979663ef04e4c0db899 100644 --- a/flatland/cli.py +++ b/flatland/cli.py @@ -2,29 +2,33 @@ """Console script for flatland.""" import sys +import time + import click import numpy as np -import time -from flatland.envs.generators import complex_rail_generator +import redis + from flatland.envs.rail_env import RailEnv -from flatland.utils.rendertools import RenderTool +from flatland.envs.rail_generators import complex_rail_generator +from flatland.envs.schedule_generators import complex_schedule_generator from flatland.evaluators.service import FlatlandRemoteEvaluationService -import redis +from flatland.utils.rendertools import RenderTool @click.command() def demo(args=None): """Demo script to check installation""" env = RailEnv( - width=15, - height=15, - rail_generator=complex_rail_generator( - nr_start_goal=10, - nr_extra=1, - min_dist=8, - max_dist=99999), - number_of_agents=5) - + width=15, + height=15, + rail_generator=complex_rail_generator( + nr_start_goal=10, + nr_extra=1, + min_dist=8, + max_dist=99999), + schedule_generator=complex_schedule_generator(), + number_of_agents=5) + env._max_episode_steps = int(15 * (env.width + env.height)) env_renderer = RenderTool(env) @@ -52,12 +56,12 @@ def demo(args=None): @click.command() -@click.option('--tests', +@click.option('--tests', type=click.Path(exists=True), help="Path to folder containing Flatland tests", required=True ) -@click.option('--service_id', +@click.option('--service_id', default="FLATLAND_RL_SERVICE_ID", help="Evaluation Service ID. This has to match the service id on the client.", required=False @@ -70,14 +74,14 @@ def evaluator(tests, service_id): raise Exception( "\nRedis server does not seem to be running on your localhost.\n" "Please ensure that you have a redis server running on your localhost" - ) - + ) + grader = FlatlandRemoteEvaluationService( - test_env_folder=tests, - flatland_rl_service_id=service_id, - visualize=False, - verbose=False - ) + test_env_folder=tests, + flatland_rl_service_id=service_id, + visualize=False, + verbose=False + ) grader.run() diff --git a/flatland/core/transition_map.py b/flatland/core/transition_map.py index 5e0f6cd72e8ca22c80e3576798e5214f8b036558..bb954998688772a7ce69e5228cff3e16d037f2af 100644 --- a/flatland/core/transition_map.py +++ b/flatland/core/transition_map.py @@ -7,6 +7,7 @@ from importlib_resources import path from numpy import array from flatland.core.grid.grid4 import Grid4Transitions +from flatland.core.grid.grid4_utils import get_new_position from flatland.core.transitions import Transitions @@ -298,6 +299,76 @@ class GridTransitionMap(TransitionMap): self.height = new_height self.grid = new_grid + def is_dead_end(self, rcPos): + """ + Check if the cell is a dead-end. + + Parameters + ---------- + rcPos: Tuple[int,int] + tuple(row, column) with grid coordinate + Returns + ------- + boolean + True if and only if the cell is a dead-end. + """ + nbits = 0 + tmp = self.get_full_transitions(rcPos[0], rcPos[1]) + while tmp > 0: + nbits += (tmp & 1) + tmp = tmp >> 1 + return nbits == 1 + + def is_simple_turn(self, rcPos): + """ + Check if the cell is a left/right simple turn + + Parameters + ---------- + rcPos: Tuple[int,int] + tuple(row, column) with grid coordinate + Returns + ------- + boolean + True if and only if the cell is a left/right simple turn. + """ + tmp = self.get_full_transitions(rcPos[0], rcPos[1]) + + def is_simple_turn(trans): + all_simple_turns = set() + for trans in [int('0100000000000010', 2), # Case 1b (8) - simple turn right + int('0001001000000000', 2) # Case 1c (9) - simple turn left]: + ]: + for _ in range(3): + trans = self.transitions.rotate_transition(trans, rotation=90) + all_simple_turns.add(trans) + return trans in all_simple_turns + + return is_simple_turn(tmp) + + def check_path_exists(self, start, direction, end): + # print("_path_exists({},{},{}".format(start, direction, end)) + # BFS - Check if a path exists between the 2 nodes + + visited = set() + stack = [(start, direction)] + while stack: + node = stack.pop() + node_position = node[0] + node_direction = node[1] + if node_position[0] == end[0] and node_position[1] == end[1]: + return True + if node not in visited: + visited.add(node) + + moves = self.get_transitions(node_position[0], node_position[1], node_direction) + for move_index in range(4): + if moves[move_index]: + stack.append((get_new_position(node_position, move_index), + move_index)) + + return False + def cell_neighbours_valid(self, rcPos, check_this_cell=False): """ Check validity of cell at rcPos = tuple(row, column) @@ -350,4 +421,124 @@ class GridTransitionMap(TransitionMap): return True + def fix_neighbours(self, rcPos, check_this_cell=False): + """ + Check validity of cell at rcPos = tuple(row, column) + Checks that: + - surrounding cells have inbound transitions for all the + outbound transitions of this cell. + + These are NOT checked - see transition.is_valid: + - all transitions have the mirror transitions (N->E <=> W->S) + - Reverse transitions (N -> S) only exist for a dead-end + - a cell contains either no dead-ends or exactly one + + Returns: True (valid) or False (invalid) + """ + cell_transition = self.grid[tuple(rcPos)] + + if check_this_cell: + if not self.transitions.is_valid(cell_transition): + return False + + gDir2dRC = self.transitions.gDir2dRC # [[-1,0] = N, [0,1]=E, etc] + grcPos = array(rcPos) + grcMax = self.grid.shape + + binTrans = self.get_full_transitions(*rcPos) # 16bit integer - all trans in/out + lnBinTrans = array([binTrans >> 8, binTrans & 0xff], dtype=np.uint8) # 2 x uint8 + g2binTrans = np.unpackbits(lnBinTrans).reshape(4, 4) # 4x4 x uint8 binary(0,1) + gDirOut = g2binTrans.any(axis=0) # outbound directions as boolean array (4) + giDirOut = np.argwhere(gDirOut)[:, 0] # valid outbound directions as array of int + + # loop over available outbound directions (indices) for rcPos + for iDirOut in giDirOut: + gdRC = gDir2dRC[iDirOut] # row,col increment + gPos2 = grcPos + gdRC # next cell in that direction + + # Check the adjacent cell is within bounds + # if not, then this transition is invalid! + if np.any(gPos2 < 0): + return False + if np.any(gPos2 >= grcMax): + return False + + # Get the transitions out of gPos2, using iDirOut as the inbound direction + # if there are no available transitions, ie (0,0,0,0), then rcPos is invalid + t4Trans2 = self.get_transitions(*gPos2, iDirOut) + if any(t4Trans2): + continue + else: + self.set_transition((gPos2[0], gPos2[1], iDirOut), mirror(iDirOut), 1) + return False + + return True + + def fix_transitions(self, rcPos): + """ + Fixes broken transitions + """ + gDir2dRC = self.transitions.gDir2dRC # [[-1,0] = N, [0,1]=E, etc] + grcPos = array(rcPos) + grcMax = self.grid.shape + + # loop over available outbound directions (indices) for rcPos + self.set_transitions(rcPos, 0) + + incoming_connections = np.zeros(4) + for iDirOut in np.arange(4): + gdRC = gDir2dRC[iDirOut] # row,col increment + gPos2 = grcPos + gdRC # next cell in that direction + + # Check the adjacent cell is within bounds + # if not, then ignore it for the count of incoming connections + if np.any(gPos2 < 0): + continue + if np.any(gPos2 >= grcMax): + continue + + # Get the transitions out of gPos2, using iDirOut as the inbound direction + # if there are no available transitions, ie (0,0,0,0), then rcPos is invalid + connected = 0 + for orientation in range(4): + connected += self.get_transition((gPos2[0], gPos2[1], orientation), mirror(iDirOut)) + if connected > 0: + incoming_connections[iDirOut] = 1 + + number_of_incoming = np.sum(incoming_connections) + # Only one incoming direction --> Straight line + if number_of_incoming == 1: + for direction in range(4): + if incoming_connections[direction] > 0: + self.set_transition((rcPos[0], rcPos[1], mirror(direction)), direction, 1) + # Connect all incoming connections + if number_of_incoming == 2: + connect_directions = np.argwhere(incoming_connections > 0) + self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[0])), connect_directions[1], 1) + self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[1])), connect_directions[0], 1) + + # Find feasible connection fro three entries + if number_of_incoming == 3: + hole = np.argwhere(incoming_connections < 1)[0][0] + connect_directions = [(hole + 1) % 4, (hole + 2) % 4, (hole + 3) % 4] + self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[0])), connect_directions[1], 1) + self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[0])), connect_directions[2], 1) + self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[1])), connect_directions[0], 1) + self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[2])), connect_directions[0], 1) + # Make a cross + if number_of_incoming == 4: + connect_directions = np.arange(4) + self.set_transition((rcPos[0], rcPos[1], connect_directions[0]), connect_directions[0], 1) + self.set_transition((rcPos[0], rcPos[1], connect_directions[0]), connect_directions[1], 1) + self.set_transition((rcPos[0], rcPos[1], connect_directions[1]), connect_directions[0], 1) + self.set_transition((rcPos[0], rcPos[1], connect_directions[1]), connect_directions[1], 1) + self.set_transition((rcPos[0], rcPos[1], connect_directions[2]), connect_directions[2], 1) + self.set_transition((rcPos[0], rcPos[1], connect_directions[2]), connect_directions[3], 1) + self.set_transition((rcPos[0], rcPos[1], connect_directions[3]), connect_directions[2], 1) + self.set_transition((rcPos[0], rcPos[1], connect_directions[3]), connect_directions[3], 1) + return True + + +def mirror(dir): + return (dir + 2) % 4 # TODO: improvement override __getitem__ and __setitem__ (cell contents, not transitions?) diff --git a/flatland/envs/grid4_generators_utils.py b/flatland/envs/grid4_generators_utils.py index dedd76b6bfd04c13ad59092adfefbde6ae98fc18..996bd73ad9de598eb162a937c135681675119ad3 100644 --- a/flatland/envs/grid4_generators_utils.py +++ b/flatland/envs/grid4_generators_utils.py @@ -5,10 +5,8 @@ Generator functions are functions that take width, height and num_resets as argu a GridTransitionMap object. """ -import numpy as np - from flatland.core.grid.grid4_astar import a_star -from flatland.core.grid.grid4_utils import get_direction, mirror, get_new_position +from flatland.core.grid.grid4_utils import get_direction, mirror def connect_rail(rail_trans, rail_array, start, end): @@ -57,79 +55,141 @@ def connect_rail(rail_trans, rail_array, start, end): return path -def get_rnd_agents_pos_tgt_dir_on_rail(rail, num_agents): +def connect_nodes(rail_trans, rail_array, start, end): + """ + Creates a new path [start,end] in rail_array, based on rail_trans. + """ + # in the worst case we will need to do a A* search, so we might as well set that up + path = a_star(rail_trans, rail_array, start, end) + if len(path) < 2: + return [] + current_dir = get_direction(path[0], path[1]) + end_pos = path[-1] + for index in range(len(path) - 1): + current_pos = path[index] + new_pos = path[index + 1] + new_dir = get_direction(current_pos, new_pos) + + new_trans = rail_array[current_pos] + if index == 0: + if new_trans == 0: + # end-point + # don't set any transition at node yet + new_trans = 0 + else: + # into existing rail + new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) + else: + # set the forward path + new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) + # set the backwards path + new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1) + rail_array[current_pos] = new_trans + + if new_pos == end_pos: + # setup end pos setup + new_trans_e = rail_array[end_pos] + if new_trans_e == 0: + # end-point + # don't set any transition at node yet + + new_trans_e = 0 + else: + # into existing rail + new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1) + rail_array[end_pos] = new_trans_e + + current_dir = new_dir + return path + + +def connect_from_nodes(rail_trans, rail_array, start, end): + """ + Creates a new path [start,end] in rail_array, based on rail_trans. """ - Given a `rail' GridTransitionMap, return a random placement of agents (initial position, direction and target). + # in the worst case we will need to do a A* search, so we might as well set that up + path = a_star(rail_trans, rail_array, start, end) + if len(path) < 2: + return [] + current_dir = get_direction(path[0], path[1]) + end_pos = path[-1] + for index in range(len(path) - 1): + current_pos = path[index] + new_pos = path[index + 1] + new_dir = get_direction(current_pos, new_pos) + + new_trans = rail_array[current_pos] + if index == 0: + if new_trans == 0: + # end-point + # need to flip direction because of how end points are defined + new_trans = 0 + else: + # into existing rail + new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) + else: + # set the forward path + new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) + # set the backwards path + new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1) + rail_array[current_pos] = new_trans - TODO: add extensive documentation, as users may need this function to simplify their custom level generators. + if new_pos == end_pos: + # setup end pos setup + new_trans_e = rail_array[end_pos] + if new_trans_e == 0: + # end-point + new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, mirror(new_dir), 1) + else: + # into existing rail + new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1) + rail_array[end_pos] = new_trans_e + + current_dir = new_dir + return path + + +def connect_to_nodes(rail_trans, rail_array, start, end): """ + Creates a new path [start,end] in rail_array, based on rail_trans. + """ + # in the worst case we will need to do a A* search, so we might as well set that up + path = a_star(rail_trans, rail_array, start, end) + if len(path) < 2: + return [] + current_dir = get_direction(path[0], path[1]) + end_pos = path[-1] + for index in range(len(path) - 1): + current_pos = path[index] + new_pos = path[index + 1] + new_dir = get_direction(current_pos, new_pos) - def _path_exists(rail, start, direction, end): - # BFS - Check if a path exists between the 2 nodes - - visited = set() - stack = [(start, direction)] - while stack: - node = stack.pop() - if node[0][0] == end[0] and node[0][1] == end[1]: - return 1 - if node not in visited: - visited.add(node) - moves = rail.get_transitions(node[0][0], node[0][1], node[1]) - for move_index in range(4): - if moves[move_index]: - stack.append((get_new_position(node[0], move_index), - move_index)) - - # If cell is a dead-end, append previous node with reversed - # orientation! - nbits = 0 - tmp = rail.get_full_transitions(node[0][0], node[0][1]) - while tmp > 0: - nbits += (tmp & 1) - tmp = tmp >> 1 - if nbits == 1: - stack.append((node[0], (node[1] + 2) % 4)) - - return 0 - - valid_positions = [] - for r in range(rail.height): - for c in range(rail.width): - if rail.get_full_transitions(r, c) > 0: - valid_positions.append((r, c)) - - re_generate = True - while re_generate: - agents_position = [ - valid_positions[i] for i in - np.random.choice(len(valid_positions), num_agents)] - agents_target = [ - valid_positions[i] for i in - np.random.choice(len(valid_positions), num_agents)] - - # agents_direction must be a direction for which a solution is - # guaranteed. - agents_direction = [0] * num_agents - re_generate = False - for i in range(num_agents): - valid_movements = [] - for direction in range(4): - position = agents_position[i] - moves = rail.get_transitions(position[0], position[1], direction) - for move_index in range(4): - if moves[move_index]: - valid_movements.append((direction, move_index)) - - valid_starting_directions = [] - for m in valid_movements: - new_position = get_new_position(agents_position[i], m[1]) - if m[0] not in valid_starting_directions and _path_exists(rail, new_position, m[0], agents_target[i]): - valid_starting_directions.append(m[0]) - - if len(valid_starting_directions) == 0: - re_generate = True + new_trans = rail_array[current_pos] + if index == 0: + if new_trans == 0: + # end-point + # need to flip direction because of how end points are defined + new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1) else: - agents_direction[i] = valid_starting_directions[np.random.choice(len(valid_starting_directions), 1)[0]] + # into existing rail + new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) + else: + # set the forward path + new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) + # set the backwards path + new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1) + rail_array[current_pos] = new_trans - return agents_position, agents_direction, agents_target + if new_pos == end_pos: + # setup end pos setup + new_trans_e = rail_array[end_pos] + if new_trans_e == 0: + # end-point + new_trans_e = 0 + else: + # into existing rail + new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1) + rail_array[end_pos] = new_trans_e + + current_dir = new_dir + return path diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 4158675cae63394ed768bfc36aaef9cd5f44da7e..c4fed2e07a97b00b786a7db8eb06af247d3ede8a 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -383,8 +383,9 @@ class TreeObsForRailEnv(ObservationBuilder): elif int_position in np.delete(self.predicted_pos[pre_step], handle, 0): conflicting_agent = np.where(self.predicted_pos[pre_step] == int_position) for ca in conflicting_agent[0]: - if direction != self.predicted_dir[pre_step][ca] and cell_transitions[self._reverse_dir( - self.predicted_dir[pre_step][ca])] == 1 and tot_dist < potential_conflict: + if direction != self.predicted_dir[pre_step][ca] \ + and cell_transitions[self._reverse_dir(self.predicted_dir[pre_step][ca])] == 1 \ + and tot_dist < potential_conflict: # noqa: E125 potential_conflict = tot_dist if self.env.dones[ca] and tot_dist < potential_conflict: potential_conflict = tot_dist @@ -394,7 +395,8 @@ class TreeObsForRailEnv(ObservationBuilder): conflicting_agent = np.where(self.predicted_pos[post_step] == int_position) for ca in conflicting_agent[0]: if direction != self.predicted_dir[post_step][ca] and cell_transitions[self._reverse_dir( - self.predicted_dir[post_step][ca])] == 1 and tot_dist < potential_conflict: + self.predicted_dir[post_step][ca])] == 1 \ + and tot_dist < potential_conflict: # noqa: E125 potential_conflict = tot_dist if self.env.dones[ca] and tot_dist < potential_conflict: potential_conflict = tot_dist diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 2281282977d8c9d972f13526efa9d96abaf84a52..62efbdc5bc0781c4c7482412dafd98710ed9d14e 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -2,7 +2,7 @@ Definition of the RailEnv environment. """ # TODO: _ this is a global method --> utils or remove later - +import warnings from enum import IntEnum import msgpack @@ -11,9 +11,11 @@ import numpy as np from flatland.core.env import Environment from flatland.core.grid.grid4_utils import get_new_position +from flatland.core.transition_map import GridTransitionMap from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent -from flatland.envs.generators import random_rail_generator from flatland.envs.observations import TreeObsForRailEnv +from flatland.envs.rail_generators import random_rail_generator, RailGenerator +from flatland.envs.schedule_generators import random_schedule_generator, ScheduleGenerator m.patch() @@ -91,7 +93,8 @@ class RailEnv(Environment): def __init__(self, width, height, - rail_generator=random_rail_generator(), + rail_generator: RailGenerator = random_rail_generator(), + schedule_generator: ScheduleGenerator = random_schedule_generator(), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2), max_episode_steps=None, @@ -107,13 +110,12 @@ class RailEnv(Environment): height and agents handles of a rail environment, along with the number of times the env has been reset, and returns a GridTransitionMap object and a list of starting positions, targets, and initial orientations for agent handle. - Implemented functions are: - random_rail_generator : generate a random rail of given size - rail_from_grid_transition_map(rail_map) : generate a rail from - a GridTransitionMap object - rail_from_manual_sp ecifications_generator(rail_spec) : generate a rail from - a rail specifications array - TODO: generate_rail_from_saved_list or from list of ndarray bitmaps --- + The rail_generator can pass a distance map in the hints or information for specific schedule_generators. + Implementations can be found in flatland/envs/rail_generators.py + schedule_generator : function + The schedule_generator function is a function that takes the grid, the number of agents and optional hints + and returns a list of starting positions, targets, initial orientations and speed for all agent handles. + Implementations can be found in flatland/envs/schedule_generators.py width : int The width of the rail map. Potentially in the future, a range of widths to sample from. @@ -131,8 +133,10 @@ class RailEnv(Environment): file_name: you can load a pickle file. """ + self.rail_generator: RailGenerator = rail_generator + self.schedule_generator: ScheduleGenerator = schedule_generator self.rail_generator = rail_generator - self.rail = None + self.rail: GridTransitionMap = None self.width = width self.height = height @@ -213,18 +217,27 @@ class RailEnv(Environment): if replace_agents then regenerate the agents static. Relies on the rail_generator returning agent_static lists (pos, dir, target) """ - tRailAgents = self.rail_generator(self.width, self.height, self.get_num_agents(), self.num_resets) + rail, optionals = self.rail_generator(self.width, self.height, self.get_num_agents(), self.num_resets) - # Check if generator provided a distance map TODO: Make this check safer! - if len(tRailAgents) > 5: - self.obs_builder.distance_map = tRailAgents[-1] + if optionals and 'distance_maps' in optionals: + self.obs_builder.distance_map = optionals['distance_maps'] if regen_rail or self.rail is None: - self.rail = tRailAgents[0] + self.rail = rail self.height, self.width = self.rail.grid.shape + for r in range(self.height): + for c in range(self.width): + rcPos = (r, c) + check = self.rail.cell_neighbours_valid(rcPos, True) + if not check: + warnings.warn("Invalid grid at {} -> {}".format(rcPos, check)) if replace_agents: - self.agents_static = EnvAgentStatic.from_lists(*tRailAgents[1:5]) + agents_hints = None + if optionals and 'agents_hints' in optionals: + agents_hints = optionals['agents_hints'] + self.agents_static = EnvAgentStatic.from_lists( + *self.schedule_generator(self.rail, self.get_num_agents(), hints=agents_hints)) self.restart_agents() @@ -258,8 +271,7 @@ class RailEnv(Environment): agent.malfunction_data['next_malfunction'] -= 1 # Only agents that have a positive rate for malfunctions and are not currently broken are considered - if agent.malfunction_data['malfunction_rate'] > 0 >= agent.malfunction_data[ - 'malfunction']: + if agent.malfunction_data['malfunction_rate'] > 0 >= agent.malfunction_data['malfunction']: # If counter has come to zero --> Agent has malfunction # set next malfunction time and duration of current malfunction diff --git a/flatland/envs/generators.py b/flatland/envs/rail_generators.py similarity index 57% rename from flatland/envs/generators.py rename to flatland/envs/rail_generators.py index 355f5502992a34f8d58d4dbd80028eb4dd71cc48..40ec2e0df89a3de48bf0a2a4430de3e30fa556e5 100644 --- a/flatland/envs/generators.py +++ b/flatland/envs/rail_generators.py @@ -1,3 +1,7 @@ +"""Rail generators (infrastructure manager, "Infrastrukturbetreiber").""" +import warnings +from typing import Callable, Tuple, Any, Optional + import msgpack import numpy as np @@ -5,29 +9,34 @@ from flatland.core.grid.grid4_utils import get_direction, mirror from flatland.core.grid.grid_utils import distance_on_rail from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.core.transition_map import GridTransitionMap -from flatland.envs.agent_utils import EnvAgentStatic -from flatland.envs.grid4_generators_utils import connect_rail -from flatland.envs.grid4_generators_utils import get_rnd_agents_pos_tgt_dir_on_rail +from flatland.envs.grid4_generators_utils import connect_rail, connect_nodes, connect_from_nodes + +RailGeneratorProduct = Tuple[GridTransitionMap, Optional[Any]] +RailGenerator = Callable[[int, int, int, int], RailGeneratorProduct] -def empty_rail_generator(): +def empty_rail_generator() -> RailGenerator: """ Returns a generator which returns an empty rail mail with no agents. Primarily used by the editor """ - def generator(width, height, num_agents=0, num_resets=0): + def generator(width: int, height: int, num_agents: int = 0, num_resets: int = 0) -> RailGeneratorProduct: rail_trans = RailEnvTransitions() grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans) rail_array = grid_map.grid rail_array.fill(0) - return grid_map, [], [], [], [] + return grid_map, None return generator -def complex_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dist=99999, seed=0): +def complex_rail_generator(nr_start_goal=1, + nr_extra=100, + min_dist=20, + max_dist=99999, + seed=0) -> RailGenerator: """ Parameters ------- @@ -47,8 +56,7 @@ def complex_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dist= if num_agents > nr_start_goal: num_agents = nr_start_goal print("complex_rail_generator: num_agents > nr_start_goal, changing num_agents") - rail_trans = RailEnvTransitions() - grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans) + grid_map = GridTransitionMap(width=width, height=height, transitions=RailEnvTransitions()) rail_array = grid_map.grid rail_array.fill(0) @@ -72,6 +80,7 @@ def complex_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dist= # - return transition map + list of [start_pos, start_dir, goal_pos] points # + rail_trans = grid_map.transitions start_goal = [] start_dir = [] nr_created = 0 @@ -141,11 +150,10 @@ def complex_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dist= if len(new_path) >= 2: nr_created += 1 - agents_position = [sg[0] for sg in start_goal[:num_agents]] - agents_target = [sg[1] for sg in start_goal[:num_agents]] - agents_direction = start_dir[:num_agents] - - return grid_map, agents_position, agents_direction, agents_target, [1.0] * len(agents_position) + return grid_map, {'agents_hints': { + 'start_goal': start_goal, + 'start_dir': start_dir + }} return generator @@ -189,22 +197,18 @@ def rail_from_manual_specifications_generator(rail_spec): effective_transition_cell = rail_env_transitions.rotate_transition(basic_type_of_cell_, rotation_cell_) rail.set_transitions((r, c), effective_transition_cell) - agents_position, agents_direction, agents_target = get_rnd_agents_pos_tgt_dir_on_rail( - rail, - num_agents) - - return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position) + return [rail, None] return generator -def rail_from_file(filename): +def rail_from_file(filename) -> RailGenerator: """ Utility to load pickle file Parameters ------- - input_file : Pickle file generated by env.save() or editor + filename : Pickle file generated by env.save() or editor Returns ------- @@ -222,26 +226,16 @@ def rail_from_file(filename): grid = np.array(data[b"grid"]) rail = GridTransitionMap(width=np.shape(grid)[1], height=np.shape(grid)[0], transitions=rail_env_transitions) rail.grid = grid - # agents are always reset as not moving - agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data[b"agents_static"]] - # setup with loaded data - agents_position = [a.position for a in agents_static] - agents_direction = [a.direction for a in agents_static] - agents_target = [a.target for a in agents_static] if b"distance_maps" in data.keys(): distance_maps = data[b"distance_maps"] if len(distance_maps) > 0: - return rail, agents_position, agents_direction, agents_target, [1.0] * len( - agents_position), distance_maps - else: - return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position) - else: - return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position) + return rail, {'distance_maps': distance_maps} + return [rail, None] return generator -def rail_from_grid_transition_map(rail_map): +def rail_from_grid_transition_map(rail_map) -> RailGenerator: """ Utility to convert a rail given by a GridTransitionMap map with the correct 16-bit transitions specifications. @@ -257,17 +251,13 @@ def rail_from_grid_transition_map(rail_map): Generator function that always returns the given `rail_map' object. """ - def generator(width, height, num_agents, num_resets=0): - agents_position, agents_direction, agents_target = get_rnd_agents_pos_tgt_dir_on_rail( - rail_map, - num_agents) - - return rail_map, agents_position, agents_direction, agents_target, [1.0] * len(agents_position) + def generator(width: int, height: int, num_agents: int, num_resets: int = 0) -> RailGeneratorProduct: + return rail_map, None return generator -def random_rail_generator(cell_type_relative_proportion=[1.0] * 11): +def random_rail_generator(cell_type_relative_proportion=[1.0] * 11) -> RailGenerator: """ Dummy random level generator: - fill in cells at random in [width-2, height-2] @@ -299,7 +289,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11): The matrix with the correct 16-bit bitmaps for each cell. """ - def generator(width, height, num_agents, num_resets=0): + def generator(width: int, height: int, num_agents: int, num_resets: int = 0) -> RailGeneratorProduct: t_utils = RailEnvTransitions() transition_probability = cell_type_relative_proportion @@ -531,10 +521,277 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11): return_rail = GridTransitionMap(width=width, height=height, transitions=t_utils) return_rail.grid = tmp_rail - agents_position, agents_direction, agents_target = get_rnd_agents_pos_tgt_dir_on_rail( - return_rail, - num_agents) + return return_rail, None + + return generator + + +def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2, min_node_dist=20, node_radius=2, + num_neighb=3, realistic_mode=False, enhance_intersection=False, seed=0): + """ + This is a level generator which generates complex sparse rail configurations + + :param num_cities: Number of city node (can hold trainstations) + :param num_intersections: Number of intersection that city nodes can connect to + :param num_trainstations: Total number of trainstations in env + :param min_node_dist: Minimal distance between nodes + :param node_radius: Proximity of trainstations to center of city node + :param num_neighb: Number of neighbouring nodes each node connects to + :param realistic_mode: True -> NOdes evenly distirbuted in env, False-> Random distribution of nodes + :param enhance_intersection: True -> Extra rail elements added at intersections + :param seed: Random Seed + :return: + ------- + numpy.ndarray of type numpy.uint16 + The matrix with the correct 16-bit bitmaps for each cell. + """ + + def generator(width, height, num_agents, num_resets=0): + + if num_agents > num_trainstations: + num_agents = num_trainstations + warnings.warn("sparse_rail_generator: num_agents > nr_start_goal, changing num_agents") + + rail_trans = RailEnvTransitions() + grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans) + rail_array = grid_map.grid + rail_array.fill(0) + np.random.seed(seed + num_resets) + + # Generate a set of nodes for the sparse network + # Try to connect cities to nodes first + node_positions = [] + city_positions = [] + intersection_positions = [] + + # Evenly distribute cities and intersections + if realistic_mode: + tot_num_node = num_intersections + num_cities + nodes_ratio = height / width + nodes_per_row = int(np.ceil(np.sqrt(tot_num_node * nodes_ratio))) + nodes_per_col = int(np.ceil(tot_num_node / nodes_per_row)) + x_positions = np.linspace(node_radius, height - node_radius, nodes_per_row, dtype=int) + y_positions = np.linspace(node_radius, width - node_radius, nodes_per_col, dtype=int) + + for node_idx in range(num_cities + num_intersections): + to_close = True + tries = 0 + if not realistic_mode: + while to_close: + x_tmp = node_radius + np.random.randint(height - node_radius) + y_tmp = node_radius + np.random.randint(width - node_radius) + to_close = False + + # Check distance to cities + for node_pos in city_positions: + if distance_on_rail((x_tmp, y_tmp), node_pos) < min_node_dist: + to_close = True + + # CHeck distance to intersections + for node_pos in intersection_positions: + if distance_on_rail((x_tmp, y_tmp), node_pos) < min_node_dist: + to_close = True + + if not to_close: + node_positions.append((x_tmp, y_tmp)) + if node_idx < num_cities: + city_positions.append((x_tmp, y_tmp)) + else: + intersection_positions.append((x_tmp, y_tmp)) + tries += 1 + if tries > 100: + warnings.warn("Could not set nodes, please change initial parameters!!!!") + break + else: + x_tmp = x_positions[node_idx % nodes_per_row] + y_tmp = y_positions[node_idx // nodes_per_row] + if len(city_positions) < num_cities and (node_idx % (tot_num_node // num_cities)) == 0: + city_positions.append((x_tmp, y_tmp)) + else: + intersection_positions.append((x_tmp, y_tmp)) + + node_positions = city_positions + intersection_positions + + # Chose node connection + # Set up list of available nodes to connect to + available_nodes_full = np.arange(num_cities + num_intersections) + available_cities = np.arange(num_cities) + available_intersections = np.arange(num_cities, num_cities + num_intersections) + + # Start at some node + current_node = np.random.randint(len(available_nodes_full)) + node_stack = [current_node] + allowed_connections = num_neighb + first_node = True + while len(node_stack) > 0: + current_node = node_stack[0] + delete_idx = np.where(available_nodes_full == current_node) + available_nodes_full = np.delete(available_nodes_full, delete_idx, 0) + # Priority city to intersection connections + if current_node < num_cities and len(available_intersections) > 0: + available_nodes = available_intersections + delete_idx = np.where(available_cities == current_node) + available_cities = np.delete(available_cities, delete_idx, 0) + + # Priority intersection to city connections + elif current_node >= num_cities and len(available_cities) > 0: + available_nodes = available_cities + delete_idx = np.where(available_intersections == current_node) + available_intersections = np.delete(available_intersections, delete_idx, 0) + + # If no options possible connect to whatever node is still available + else: + available_nodes = available_nodes_full + + # Sort available neighbors according to their distance. + node_dist = [] + for av_node in available_nodes: + node_dist.append(distance_on_rail(node_positions[current_node], node_positions[av_node])) + available_nodes = available_nodes[np.argsort(node_dist)] + + # Set number of neighboring nodes + if len(available_nodes) >= allowed_connections: + connected_neighb_idx = available_nodes[:allowed_connections] + else: + connected_neighb_idx = available_nodes + + # Less connections for subsequent nodes + if first_node: + allowed_connections -= 1 + first_node = False + + # Connect to the neighboring nodes + for neighb in connected_neighb_idx: + if neighb not in node_stack: + node_stack.append(neighb) + connect_nodes(rail_trans, rail_array, node_positions[current_node], node_positions[neighb]) + node_stack.pop(0) + + # Place train stations close to the node + # We currently place them uniformly distirbuted among all cities + if num_cities > 1: + train_stations = [[] for i in range(num_cities)] + built_num_trainstation = 0 + spot_found = True + for station in range(num_trainstations): + trainstation_node = int(station / num_trainstations * num_cities) + + station_x = np.clip(node_positions[trainstation_node][0] + np.random.randint(-node_radius, node_radius), + 0, + height - 1) + station_y = np.clip(node_positions[trainstation_node][1] + np.random.randint(-node_radius, node_radius), + 0, + width - 1) + tries = 0 + while (station_x, station_y) in train_stations \ + or (station_x, station_y) == node_positions[trainstation_node] \ + or rail_array[(station_x, station_y)] != 0: # noqa: E125 + + station_x = np.clip( + node_positions[trainstation_node][0] + np.random.randint(-node_radius, node_radius), + 0, + height - 1) + station_y = np.clip( + node_positions[trainstation_node][1] + np.random.randint(-node_radius, node_radius), + 0, + width - 1) + tries += 1 + if tries > 100: + warnings.warn("Could not set trainstations, please change initial parameters!!!!") + spot_found = False + break + if spot_found: + train_stations[trainstation_node].append((station_x, station_y)) + + # Connect train station to the correct node + connection = connect_from_nodes(rail_trans, rail_array, node_positions[trainstation_node], + (station_x, station_y)) + # Check if connection was made + if len(connection) == 0: + train_stations[trainstation_node].pop(-1) + else: + built_num_trainstation += 1 + + # Adjust the number of agents if you could not build enough trainstations + + if num_agents > built_num_trainstation: + num_agents = built_num_trainstation + warnings.warn("sparse_rail_generator: num_agents > nr_start_goal, changing num_agents") + + # Place passing lanes at intersections + # We currently place them uniformly distirbuted among all cities + if enhance_intersection: + + for intersection in range(num_intersections): + intersect_x_1 = np.clip(intersection_positions[intersection][0] + np.random.randint(1, 3), + 1, + height - 2) + intersect_y_1 = np.clip(intersection_positions[intersection][1] + np.random.randint(-3, 3), + 2, + width - 2) + intersect_x_2 = np.clip( + intersection_positions[intersection][0] + np.random.randint(-3, -1), + 1, + height - 2) + intersect_y_2 = np.clip( + intersection_positions[intersection][1] + np.random.randint(-3, 3), + 1, + width - 2) + + # Connect train station to the correct node + connect_nodes(rail_trans, rail_array, (intersect_x_1, intersect_y_1), + (intersect_x_2, intersect_y_2)) + connect_nodes(rail_trans, rail_array, intersection_positions[intersection], + (intersect_x_1, intersect_y_1)) + connect_nodes(rail_trans, rail_array, intersection_positions[intersection], + (intersect_x_2, intersect_y_2)) + grid_map.fix_transitions((intersect_x_1, intersect_y_1)) + grid_map.fix_transitions((intersect_x_2, intersect_y_2)) + + # Fix all nodes with illegal transition maps + for current_node in node_positions: + grid_map.fix_transitions(current_node) + + # Generate start and target node directory for all agents. + # Assure that start and target are not in the same node + agent_start_targets_nodes = [] + + # Slot availability in node + node_available_start = [] + node_available_target = [] + for node_idx in range(num_cities): + node_available_start.append(len(train_stations[node_idx])) + node_available_target.append(len(train_stations[node_idx])) + + # Assign agents to slots + for agent_idx in range(num_agents): + avail_start_nodes = [idx for idx, val in enumerate(node_available_start) if val > 0] + avail_target_nodes = [idx for idx, val in enumerate(node_available_target) if val > 0] + start_node = np.random.choice(avail_start_nodes) + target_node = np.random.choice(avail_target_nodes) + tries = 0 + found_agent_pair = True + while target_node == start_node: + target_node = np.random.choice(avail_target_nodes) + tries += 1 + # Test again with new start node if no pair is found (This code needs to be improved) + if (tries + 1) % 10 == 0: + start_node = np.random.choice(avail_start_nodes) + if tries > 100: + warnings.warn("Could not set trainstations, removing agent!") + found_agent_pair = False + break + if found_agent_pair: + node_available_start[start_node] -= 1 + node_available_target[target_node] -= 1 + agent_start_targets_nodes.append((start_node, target_node)) + else: + num_agents -= 1 - return return_rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position) + return grid_map, {'agents_hints': { + 'num_agents': num_agents, + 'agent_start_targets_nodes': agent_start_targets_nodes, + 'train_stations': train_stations + }} return generator diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py new file mode 100644 index 0000000000000000000000000000000000000000..4843e0040d80b79de54e8ed57674a37884ef6809 --- /dev/null +++ b/flatland/envs/schedule_generators.py @@ -0,0 +1,235 @@ +"""Schedule generators (railway undertaking, "EVU").""" +import warnings +from typing import Tuple, List, Callable, Mapping, Optional, Any + +import msgpack +import numpy as np + +from flatland.core.grid.grid4_utils import get_new_position +from flatland.core.transition_map import GridTransitionMap +from flatland.envs.agent_utils import EnvAgentStatic + +AgentPosition = Tuple[int, int] +ScheduleGeneratorProduct = Tuple[List[AgentPosition], List[AgentPosition], List[AgentPosition], List[float]] +ScheduleGenerator = Callable[[GridTransitionMap, int, Optional[Any]], ScheduleGeneratorProduct] + + +def speed_initialization_helper(nb_agents: int, speed_ratio_map: Mapping[float, float] = None) -> List[float]: + """ + Parameters + ------- + nb_agents : int + The number of agents to generate a speed for + speed_ratio_map : Mapping[float,float] + A map of speeds mappint to their ratio of appearance. The ratios must sum up to 1. + + Returns + ------- + List[float] + A list of size nb_agents of speeds with the corresponding probabilistic ratios. + """ + if speed_ratio_map is None: + return [1.0] * nb_agents + + nb_classes = len(speed_ratio_map.keys()) + speed_ratio_map_as_list: List[Tuple[float, float]] = list(speed_ratio_map.items()) + speed_ratios = list(map(lambda t: t[1], speed_ratio_map_as_list)) + speeds = list(map(lambda t: t[0], speed_ratio_map_as_list)) + return list(map(lambda index: speeds[index], np.random.choice(nb_classes, nb_agents, p=speed_ratios))) + + +def complex_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> ScheduleGenerator: + def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None): + start_goal = hints['start_goal'] + start_dir = hints['start_dir'] + agents_position = [sg[0] for sg in start_goal[:num_agents]] + agents_target = [sg[1] for sg in start_goal[:num_agents]] + agents_direction = start_dir[:num_agents] + + if speed_ratio_map: + speeds = speed_initialization_helper(num_agents, speed_ratio_map) + else: + speeds = [1.0] * len(agents_position) + + return agents_position, agents_direction, agents_target, speeds + + return generator + + +def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> ScheduleGenerator: + def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None): + train_stations = hints['train_stations'] + agent_start_targets_nodes = hints['agent_start_targets_nodes'] + num_agents = hints['num_agents'] + # Place agents and targets within available train stations + agents_position = [] + agents_target = [] + agents_direction = [] + for agent_idx in range(num_agents): + # Set target for agent + current_target_node = agent_start_targets_nodes[agent_idx][1] + target_station_idx = np.random.randint(len(train_stations[current_target_node])) + target = train_stations[current_target_node][target_station_idx] + tries = 0 + while (target[0], target[1]) in agents_target: + target_station_idx = np.random.randint(len(train_stations[current_target_node])) + target = train_stations[current_target_node][target_station_idx] + tries += 1 + if tries > 100: + warnings.warn("Could not set target position, removing an agent") + break + agents_target.append((target[0], target[1])) + + # Set start for agent + current_start_node = agent_start_targets_nodes[agent_idx][0] + start_station_idx = np.random.randint(len(train_stations[current_start_node])) + start = train_stations[current_start_node][start_station_idx] + tries = 0 + while (start[0], start[1]) in agents_position: + tries += 1 + if tries > 100: + warnings.warn("Could not set start position, please change initial parameters!!!!") + break + start_station_idx = np.random.randint(len(train_stations[current_start_node])) + start = train_stations[current_start_node][start_station_idx] + + agents_position.append((start[0], start[1])) + + # Orient the agent correctly + for orientation in range(4): + transitions = rail.get_transitions(start[0], start[1], orientation) + if any(transitions) > 0: + agents_direction.append(orientation) + continue + + if speed_ratio_map: + speeds = speed_initialization_helper(num_agents, speed_ratio_map) + else: + speeds = [1.0] * len(agents_position) + + return agents_position, agents_direction, agents_target, speeds + + return generator + + +def random_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> ScheduleGenerator: + """ + Given a `rail' GridTransitionMap, return a random placement of agents (initial position, direction and target). + + Parameters + ------- + rail : GridTransitionMap + The railway to place agents on. + num_agents : int + The number of agents to generate a speed for + speed_ratio_map : Mapping[float,float] + A map of speeds mappint to their ratio of appearance. The ratios must sum up to 1. + Returns + ------- + Tuple[List[Tuple[int,int]], List[Tuple[int,int]], List[Tuple[int,int]], List[float]] + initial positions, directions, targets speeds + """ + + def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None) -> ScheduleGeneratorProduct: + + valid_positions = [] + for r in range(rail.height): + for c in range(rail.width): + if rail.get_full_transitions(r, c) > 0: + valid_positions.append((r, c)) + if len(valid_positions) == 0: + return [], [], [], [] + + if len(valid_positions) < num_agents: + warnings.warn("schedule_generators: len(valid_positions) < num_agents") + return [], [], [], [] + + agents_position_idx = [i for i in np.random.choice(len(valid_positions), num_agents, replace=False)] + agents_position = [valid_positions[agents_position_idx[i]] for i in range(num_agents)] + agents_target_idx = [i for i in np.random.choice(len(valid_positions), num_agents, replace=False)] + agents_target = [valid_positions[agents_target_idx[i]] for i in range(num_agents)] + update_agents = np.zeros(num_agents) + + re_generate = True + cnt = 0 + while re_generate: + cnt += 1 + if cnt > 1: + print("re_generate cnt={}".format(cnt)) + if cnt > 1000: + raise Exception("After 1000 re_generates still not success, giving up.") + # update position + for i in range(num_agents): + if update_agents[i] == 1: + x = np.setdiff1d(np.arange(len(valid_positions)), agents_position_idx) + agents_position_idx[i] = np.random.choice(x) + agents_position[i] = valid_positions[agents_position_idx[i]] + x = np.setdiff1d(np.arange(len(valid_positions)), agents_target_idx) + agents_target_idx[i] = np.random.choice(x) + agents_target[i] = valid_positions[agents_target_idx[i]] + update_agents = np.zeros(num_agents) + + # agents_direction must be a direction for which a solution is + # guaranteed. + agents_direction = [0] * num_agents + re_generate = False + for i in range(num_agents): + valid_movements = [] + for direction in range(4): + position = agents_position[i] + moves = rail.get_transitions(position[0], position[1], direction) + for move_index in range(4): + if moves[move_index]: + valid_movements.append((direction, move_index)) + + valid_starting_directions = [] + for m in valid_movements: + new_position = get_new_position(agents_position[i], m[1]) + if m[0] not in valid_starting_directions and rail.check_path_exists(new_position, m[1], + agents_target[i]): + valid_starting_directions.append(m[0]) + + if len(valid_starting_directions) == 0: + update_agents[i] = 1 + warnings.warn("reset position for agent[{}]: {} -> {}".format(i, agents_position[i], agents_target[i])) + re_generate = True + break + else: + agents_direction[i] = valid_starting_directions[ + np.random.choice(len(valid_starting_directions), 1)[0]] + + agents_speed = speed_initialization_helper(num_agents, speed_ratio_map) + return agents_position, agents_direction, agents_target, agents_speed + + return generator + + +def schedule_from_file(filename) -> ScheduleGenerator: + """ + Utility to load pickle file + + Parameters + ------- + input_file : Pickle file generated by env.save() or editor + + Returns + ------- + Tuple[List[Tuple[int,int]], List[Tuple[int,int]], List[Tuple[int,int]], List[float]] + initial positions, directions, targets speeds + """ + + def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None) -> ScheduleGeneratorProduct: + with open(filename, "rb") as file_in: + load_data = file_in.read() + data = msgpack.unpackb(load_data, use_list=False) + + # agents are always reset as not moving + agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data[b"agents_static"]] + # setup with loaded data + agents_position = [a.position for a in agents_static] + agents_direction = [a.direction for a in agents_static] + agents_target = [a.target for a in agents_static] + + return agents_position, agents_direction, agents_target, [1.0] * len(agents_position) + + return generator diff --git a/flatland/evaluators/client.py b/flatland/evaluators/client.py index a4968c0c8e827c060e0e3f7de0cf28cc0658089b..f2dac1c8705a651e2a7be026b4bd82a961efbbdd 100644 --- a/flatland/evaluators/client.py +++ b/flatland/evaluators/client.py @@ -1,18 +1,21 @@ -import redis +import hashlib import json +import logging import os -import numpy as np +import random +import time + import msgpack import msgpack_numpy as m -import hashlib -import random -from flatland.evaluators import messages -from flatland.envs.rail_env import RailEnv -from flatland.envs.generators import rail_from_file +import numpy as np +import redis + from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv -import time -import logging +from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators import rail_from_file +from flatland.evaluators import messages + logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) m.patch() @@ -22,8 +25,8 @@ def are_dicts_equal(d1, d2): """ return True if all keys and values are the same """ return all(k in d2 and d1[k] == d2[k] for k in d1) \ - and all(k in d1 and d1[k] == d2[k] - for k in d2) + and all(k in d1 and d1[k] == d2[k] + for k in d2) class FlatlandRemoteClient(object): @@ -41,39 +44,40 @@ class FlatlandRemoteClient(object): where `service_id` is either provided as an `env` variable or is instantiated to "flatland_rl_redis_service_id" """ - def __init__(self, - remote_host='127.0.0.1', - remote_port=6379, - remote_db=0, - remote_password=None, - test_envs_root=None, - verbose=False): + + def __init__(self, + remote_host='127.0.0.1', + remote_port=6379, + remote_db=0, + remote_password=None, + test_envs_root=None, + verbose=False): self.remote_host = remote_host self.remote_port = remote_port self.remote_db = remote_db self.remote_password = remote_password self.redis_pool = redis.ConnectionPool( - host=remote_host, - port=remote_port, - db=remote_db, - password=remote_password) + host=remote_host, + port=remote_port, + db=remote_db, + password=remote_password) self.namespace = "flatland-rl" self.service_id = os.getenv( - 'FLATLAND_RL_SERVICE_ID', - 'FLATLAND_RL_SERVICE_ID' - ) + 'FLATLAND_RL_SERVICE_ID', + 'FLATLAND_RL_SERVICE_ID' + ) self.command_channel = "{}::{}::commands".format( - self.namespace, - self.service_id - ) + self.namespace, + self.service_id + ) if test_envs_root: self.test_envs_root = test_envs_root else: self.test_envs_root = os.getenv( - 'AICROWD_TESTS_FOLDER', - '/tmp/flatland_envs' - ) + 'AICROWD_TESTS_FOLDER', + '/tmp/flatland_envs' + ) self.verbose = verbose @@ -85,12 +89,12 @@ class FlatlandRemoteClient(object): def _generate_response_channel(self): random_hash = hashlib.md5( - "{}".format( - random.randint(0, 10**10) - ).encode('utf-8')).hexdigest() + "{}".format( + random.randint(0, 10 ** 10) + ).encode('utf-8')).hexdigest() response_channel = "{}::{}::response::{}".format(self.namespace, - self.service_id, - random_hash) + self.service_id, + random_hash) return response_channel def _blocking_request(self, _request): @@ -124,9 +128,9 @@ class FlatlandRemoteClient(object): if self.verbose: print("Response : ", _response) _response = msgpack.unpackb( - _response, - object_hook=m.decode, - encoding="utf8") + _response, + object_hook=m.decode, + encoding="utf8") if _response['type'] == messages.FLATLAND_RL.ERROR: raise Exception(str(_response["payload"])) else: @@ -181,7 +185,7 @@ class FlatlandRemoteClient(object): "Did you remember to set the AICROWD_TESTS_FOLDER environment variable " "to point to the location of the Tests folder ? \n" "We are currently looking at `{}` for the tests".format(self.test_envs_root) - ) + ) print("Current env path : ", test_env_file_path) self.env = RailEnv( width=1, @@ -207,7 +211,7 @@ class FlatlandRemoteClient(object): _request['payload']['action'] = action _response = self._blocking_request(_request) _payload = _response['payload'] - + # remote_observation = _payload['observation'] remote_reward = _payload['reward'] remote_done = _payload['done'] @@ -216,14 +220,14 @@ class FlatlandRemoteClient(object): # Replicate the action in the local env local_observation, local_reward, local_done, local_info = \ self.env.step(action) - + print(local_reward) if not are_dicts_equal(remote_reward, local_reward): raise Exception("local and remote `reward` are diverging") print(remote_reward, local_reward) if not are_dicts_equal(remote_done, local_done): raise Exception("local and remote `done` are diverging") - + # Return local_observation instead of remote_observation # as the remote_observation is build using a dummy observation # builder @@ -250,21 +254,23 @@ class FlatlandRemoteClient(object): if __name__ == "__main__": remote_client = FlatlandRemoteClient() + def my_controller(obs, _env): _action = {} for _idx, _ in enumerate(_env.agents): _action[_idx] = np.random.randint(0, 5) return _action - + + my_observation_builder = TreeObsForRailEnv(max_depth=3, - predictor=ShortestPathPredictorForRailEnv()) + predictor=ShortestPathPredictorForRailEnv()) episode = 0 obs = True - while obs: + while obs: obs = remote_client.env_create( - obs_builder_object=my_observation_builder - ) + obs_builder_object=my_observation_builder + ) if not obs: """ The remote env returns False as the first obs @@ -285,7 +291,5 @@ if __name__ == "__main__": print("Reward : ", sum(list(all_rewards.values()))) break - print("Evaluation Complete...") + print("Evaluation Complete...") print(remote_client.submit()) - - diff --git a/flatland/evaluators/service.py b/flatland/evaluators/service.py index 3ad0a97598c8beb66fc164eb45b670c87f3c96f9..8967b52d9d6ee70a7eb8af257ef6b4e25b531314 100644 --- a/flatland/evaluators/service.py +++ b/flatland/evaluators/service.py @@ -1,24 +1,26 @@ #!/usr/bin/env python from __future__ import print_function -import redis -from flatland.envs.generators import rail_from_file -from flatland.envs.rail_env import RailEnv -from flatland.core.env_observation_builder import DummyObservationBuilder -from flatland.evaluators import messages -from flatland.evaluators import aicrowd_helpers -from flatland.utils.rendertools import RenderTool -import numpy as np -import msgpack -import msgpack_numpy as m -import os + import glob +import os +import random import shutil import time import traceback + import crowdai_api +import msgpack +import msgpack_numpy as m +import numpy as np +import redis import timeout_decorator -import random +from flatland.core.env_observation_builder import DummyObservationBuilder +from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators import rail_from_file +from flatland.evaluators import aicrowd_helpers +from flatland.evaluators import messages +from flatland.utils.rendertools import RenderTool use_signals_in_timeout = True if os.name == 'nt': @@ -35,7 +37,7 @@ m.patch() ######################################################## # CONSTANTS ######################################################## -PER_STEP_TIMEOUT = 10*60 # 5 minutes +PER_STEP_TIMEOUT = 10 * 60 # 5 minutes class FlatlandRemoteEvaluationService: @@ -59,17 +61,18 @@ class FlatlandRemoteEvaluationService: unpacked with `msgpack` (a patched version of msgpack which also supports numpy arrays). """ + def __init__(self, - test_env_folder="/tmp", - flatland_rl_service_id='FLATLAND_RL_SERVICE_ID', - remote_host='127.0.0.1', - remote_port=6379, - remote_db=0, - remote_password=None, - visualize=False, - video_generation_envs=[], - report=None, - verbose=False): + test_env_folder="/tmp", + flatland_rl_service_id='FLATLAND_RL_SERVICE_ID', + remote_host='127.0.0.1', + remote_port=6379, + remote_db=0, + remote_password=None, + visualize=False, + video_generation_envs=[], + report=None, + verbose=False): # Test Env folder Paths self.test_env_folder = test_env_folder @@ -83,15 +86,15 @@ class FlatlandRemoteEvaluationService: # Logging and Reporting related vars self.verbose = verbose self.report = report - + # Communication Protocol Related vars self.namespace = "flatland-rl" self.service_id = flatland_rl_service_id self.command_channel = "{}::{}::commands".format( - self.namespace, - self.service_id - ) - + self.namespace, + self.service_id + ) + # Message Broker related vars self.remote_host = remote_host self.remote_port = remote_port @@ -114,7 +117,7 @@ class FlatlandRemoteEvaluationService: "normalized_reward": 0.0 } } - + # RailEnv specific variables self.env = False self.env_renderer = False @@ -156,7 +159,7 @@ class FlatlandRemoteEvaluationService:   ├── .......   ├── ....... └── Level_99.pkl - """ + """ env_paths = sorted(glob.glob( os.path.join( self.test_env_folder, @@ -179,16 +182,16 @@ class FlatlandRemoteEvaluationService: """ if self.verbose or self.report: print("Attempting to connect to redis server at {}:{}/{}".format( - self.remote_host, - self.remote_port, - self.remote_db)) + self.remote_host, + self.remote_port, + self.remote_db)) self.redis_pool = redis.ConnectionPool( - host=self.remote_host, - port=self.remote_port, - db=self.remote_db, - password=self.remote_password - ) + host=self.remote_host, + port=self.remote_port, + db=self.remote_db, + password=self.remote_password + ) def get_redis_connection(self): """ @@ -200,13 +203,13 @@ class FlatlandRemoteEvaluationService: redis_conn.ping() except Exception as e: raise Exception( - "Unable to connect to redis server at {}:{} ." - "Are you sure there is a redis-server running at the " - "specified location ?".format( - self.remote_host, - self.remote_port - ) - ) + "Unable to connect to redis server at {}:{} ." + "Are you sure there is a redis-server running at the " + "specified location ?".format( + self.remote_host, + self.remote_port + ) + ) return redis_conn def _error_template(self, payload): @@ -220,8 +223,8 @@ class FlatlandRemoteEvaluationService: return _response @timeout_decorator.timeout( - PER_STEP_TIMEOUT, - use_signals=use_signals_in_timeout) # timeout for each command + PER_STEP_TIMEOUT, + use_signals=use_signals_in_timeout) # timeout for each command def _get_next_command(self, _redis): """ A low level wrapper for obtaining the next command from a @@ -231,7 +234,7 @@ class FlatlandRemoteEvaluationService: """ command = _redis.brpop(self.command_channel)[1] return command - + def get_next_command(self): """ A helper function to obtain the next command, which transparently @@ -246,18 +249,18 @@ class FlatlandRemoteEvaluationService: print("Command Service: ", command) except timeout_decorator.timeout_decorator.TimeoutError: raise Exception( - "Timeout in step {} of simulation {}".format( - self.current_step, - self.simulation_count - )) + "Timeout in step {} of simulation {}".format( + self.current_step, + self.simulation_count + )) command = msgpack.unpackb( - command, - object_hook=m.decode, - encoding="utf8" - ) + command, + object_hook=m.decode, + encoding="utf8" + ) if self.verbose: print("Received Request : ", command) - + return command def send_response(self, _command_response, command, suppress_logs=False): @@ -266,15 +269,15 @@ class FlatlandRemoteEvaluationService: if self.verbose and not suppress_logs: print("Responding with : ", _command_response) - + _redis.rpush( - command_response_channel, + command_response_channel, msgpack.packb( - _command_response, - default=m.encode, + _command_response, + default=m.encode, use_bin_type=True) ) - + def handle_ping(self, command): """ Handles PING command from the client. @@ -313,9 +316,9 @@ class FlatlandRemoteEvaluationService: ) if self.visualize: if self.env_renderer: - del self.env_renderer + del self.env_renderer self.env_renderer = RenderTool(self.env, gl="PILSVG", ) - + # Set max episode steps allowed self.env._max_episode_steps = \ int(1.5 * (self.env.width + self.env.height)) @@ -323,7 +326,7 @@ class FlatlandRemoteEvaluationService: if self.begin_simulation: # If begin simulation has already been initialized # atleast once - self.simulation_times.append(time.time()-self.begin_simulation) + self.simulation_times.append(time.time() - self.begin_simulation) self.begin_simulation = time.time() self.simulation_rewards.append(0) @@ -348,15 +351,15 @@ class FlatlandRemoteEvaluationService: _command_response['type'] = messages.FLATLAND_RL.ENV_CREATE_RESPONSE _command_response['payload'] = {} _command_response['payload']['observation'] = False - _command_response['payload']['env_file_path'] = False + _command_response['payload']['env_file_path'] = False self.send_response(_command_response, command) ##################################################################### # Update evaluation state ##################################################################### progress = np.clip( - self.simulation_count * 1.0 / len(self.env_file_paths), - 0, 1) + self.simulation_count * 1.0 / len(self.env_file_paths), + 0, 1) mean_reward = round(np.mean(self.simulation_rewards), 2) mean_normalized_reward = round(np.mean(self.simulation_rewards_normalized), 2) mean_percentage_complete = round(np.mean(self.simulation_percentage_complete), 3) @@ -399,9 +402,9 @@ class FlatlandRemoteEvaluationService: """ self.simulation_rewards_normalized[-1] += \ cumulative_reward / ( - self.env._max_episode_steps + - self.env.get_num_agents() - ) + self.env._max_episode_steps + + self.env.get_num_agents() + ) if done["__all__"]: # Compute percentage complete @@ -412,14 +415,14 @@ class FlatlandRemoteEvaluationService: complete += 1 percentage_complete = complete * 1.0 / self.env.get_num_agents() self.simulation_percentage_complete[-1] = percentage_complete - + # Record Frame if self.visualize: self.env_renderer.render_env( - show=False, - show_observations=False, - show_predictions=False - ) + show=False, + show_observations=False, + show_predictions=False + ) """ Only save the frames for environments which are separately provided in video_generation_indices param @@ -427,10 +430,10 @@ class FlatlandRemoteEvaluationService: current_env_path = self.env_file_paths[self.simulation_count] if current_env_path in self.video_generation_envs: self.env_renderer.gl.save_image( - os.path.join( - self.vizualization_folder_name, - "flatland_frame_{:04d}.png".format(self.record_frame_step) - )) + os.path.join( + self.vizualization_folder_name, + "flatland_frame_{:04d}.png".format(self.record_frame_step) + )) self.record_frame_step += 1 # Build and send response @@ -453,7 +456,7 @@ class FlatlandRemoteEvaluationService: _payload = command['payload'] # Register simulation time of the last episode - self.simulation_times.append(time.time()-self.begin_simulation) + self.simulation_times.append(time.time() - self.begin_simulation) if len(self.simulation_rewards) != len(self.env_file_paths): raise Exception( @@ -461,7 +464,7 @@ class FlatlandRemoteEvaluationService: to operate on all the test environments. """ ) - + mean_reward = round(np.mean(self.simulation_rewards), 2) mean_normalized_reward = round(np.mean(self.simulation_rewards_normalized), 2) mean_percentage_complete = round(np.mean(self.simulation_percentage_complete), 3) @@ -473,7 +476,7 @@ class FlatlandRemoteEvaluationService: # install it by : # # conda install -c conda-forge x264 ffmpeg - + print("Generating Video from thumbnails...") video_output_path, video_thumb_output_path = \ aicrowd_helpers.generate_movie_from_frames( @@ -518,14 +521,14 @@ class FlatlandRemoteEvaluationService: self.evaluation_state["score"]["score_secondary"] = mean_reward self.evaluation_state["meta"]["normalized_reward"] = mean_normalized_reward self.handle_aicrowd_success_event(self.evaluation_state) - print("#"*100) + print("#" * 100) print("EVALUATION COMPLETE !!") - print("#"*100) + print("#" * 100) print("# Mean Reward : {}".format(mean_reward)) print("# Mean Normalized Reward : {}".format(mean_normalized_reward)) print("# Mean Percentage Complete : {}".format(mean_percentage_complete)) - print("#"*100) - print("#"*100) + print("#" * 100) + print("#" * 100) def report_error(self, error_message, command_response_channel): """ @@ -536,16 +539,16 @@ class FlatlandRemoteEvaluationService: _command_response['type'] = messages.FLATLAND_RL.ERROR _command_response['payload'] = error_message _redis.rpush( - command_response_channel, + command_response_channel, msgpack.packb( - _command_response, - default=m.encode, + _command_response, + default=m.encode, use_bin_type=True) - ) + ) self.evaluation_state["state"] = "ERROR" self.evaluation_state["error"] = error_message self.handle_aicrowd_error_event(self.evaluation_state) - + def handle_aicrowd_info_event(self, payload): self.oracle_events.register_event( event_type=self.oracle_events.CROWDAI_EVENT_INFO, @@ -577,17 +580,17 @@ class FlatlandRemoteEvaluationService: print("Self.Reward : ", self.reward) print("Current Simulation : ", self.simulation_count) if self.env_file_paths and \ - self.simulation_count < len(self.env_file_paths): + self.simulation_count < len(self.env_file_paths): print("Current Env Path : ", - self.env_file_paths[self.simulation_count]) + self.env_file_paths[self.simulation_count]) - try: + try: if command['type'] == messages.FLATLAND_RL.PING: """ INITIAL HANDSHAKE : Respond with PONG """ self.handle_ping(command) - + elif command['type'] == messages.FLATLAND_RL.ENV_CREATE: """ ENV_CREATE @@ -612,8 +615,8 @@ class FlatlandRemoteEvaluationService: self.handle_env_submit(command) else: _error = self._error_template( - "UNKNOWN_REQUEST:{}".format( - str(command))) + "UNKNOWN_REQUEST:{}".format( + str(command))) if self.verbose: print("Responding with : ", _error) self.report_error( @@ -631,10 +634,11 @@ class FlatlandRemoteEvaluationService: if __name__ == "__main__": import argparse + parser = argparse.ArgumentParser(description='Submit the result to AIcrowd') - parser.add_argument('--service_id', - dest='service_id', - default='FLATLAND_RL_SERVICE_ID', + parser.add_argument('--service_id', + dest='service_id', + default='FLATLAND_RL_SERVICE_ID', required=False) parser.add_argument('--test_folder', dest='test_folder', @@ -642,16 +646,16 @@ if __name__ == "__main__": help="Folder containing the files for the test envs", required=False) args = parser.parse_args() - + test_folder = args.test_folder grader = FlatlandRemoteEvaluationService( - test_env_folder=test_folder, - flatland_rl_service_id=args.service_id, - verbose=True, - visualize=True, - video_generation_envs=["Test_0/Level_1.pkl"] - ) + test_env_folder=test_folder, + flatland_rl_service_id=args.service_id, + verbose=True, + visualize=True, + video_generation_envs=["Test_0/Level_1.pkl"] + ) result = grader.run() if result['type'] == messages.FLATLAND_RL.ENV_SUBMIT_RESPONSE: cumulative_results = result['payload'] diff --git a/flatland/utils/editor.py b/flatland/utils/editor.py index 69be59ae2a957f6a2aaa948d9830472d7824516a..af1aad222919b00b716dd9da0f3be9534d54e411 100644 --- a/flatland/utils/editor.py +++ b/flatland/utils/editor.py @@ -11,9 +11,9 @@ from numpy import array import flatland.utils.rendertools as rt from flatland.core.grid.grid4_utils import mirror from flatland.envs.agent_utils import EnvAgent, EnvAgentStatic -from flatland.envs.generators import complex_rail_generator, empty_rail_generator from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.rail_env import RailEnv, random_rail_generator +from flatland.envs.rail_generators import complex_rail_generator, empty_rail_generator class EditorMVC(object): diff --git a/flatland/utils/graphics_pil.py b/flatland/utils/graphics_pil.py index f86f301cdda4676855ad7a52623822f3053d6ea1..92a0f84f35fa942b03236c6add6e722475a2d842 100644 --- a/flatland/utils/graphics_pil.py +++ b/flatland/utils/graphics_pil.py @@ -4,7 +4,7 @@ import time import tkinter as tk import numpy as np -from PIL import Image, ImageDraw, ImageTk # , ImageFont +from PIL import Image, ImageDraw, ImageTk, ImageFont from numpy import array from pkg_resources import resource_string as resource_bytes @@ -41,7 +41,7 @@ class PILGL(GraphicsLayer): SELECTED_AGENT_LAYER = 4 SELECTED_TARGET_LAYER = 5 - def __init__(self, width, height, jupyter=False): + def __init__(self, width, height, jupyter=False, screen_width=800, screen_height=600): self.yxBase = (0, 0) self.linewidth = 4 self.n_agent_colors = 1 # overridden in loadAgent @@ -52,13 +52,13 @@ class PILGL(GraphicsLayer): self.background_grid = np.zeros(shape=(self.width, self.height)) if jupyter is False: - # NOTE: Currently removed the dependency on - # screeninfo. We have to find an alternate + # NOTE: Currently removed the dependency on + # screeninfo. We have to find an alternate # way to compute the screen width and height - # In the meantime, we are harcoding the 800x600 + # In the meantime, we are harcoding the 800x600 # assumption - self.screen_width = 800 - self.screen_height = 600 + self.screen_width = screen_width + self.screen_height = screen_height w = (self.screen_width - self.width - 10) / (self.width + 1 + self.linewidth) h = (self.screen_height - self.height - 10) / (self.height + 1 + self.linewidth) self.nPixCell = int(max(1, np.ceil(min(w, h)))) @@ -90,6 +90,8 @@ class PILGL(GraphicsLayer): self.old_background_image = (None, None, None) self.create_layers() + self.font = ImageFont.load_default() + def build_background_map(self, dTargets): x = self.old_background_image rebuild = False @@ -114,7 +116,7 @@ class PILGL(GraphicsLayer): for rc in dTargets: r = rc[1] c = rc[0] - d = int(np.floor(np.sqrt((x - r) ** 2 + (y - c) ** 2))) + d = int(np.floor(np.sqrt((x - r) ** 2 + (y - c) ** 2)) / 0.5) distance = min(d, distance) self.background_grid[x][y] = distance @@ -167,8 +169,14 @@ class PILGL(GraphicsLayer): # quit but not destroy! self.__class__.window.quit() - def text(self, *args, **kwargs): - pass + def text(self, xPx, yPx, strText, layer=RAIL_LAYER): + xyPixLeftTop = (xPx, yPx) + self.draws[layer].text(xyPixLeftTop, strText, font=self.font, fill=(0, 0, 0, 255)) + + def text_rowcol(self, rcTopLeft, strText, layer=AGENT_LAYER): + print("Text:", "rc:", rcTopLeft, "text:", strText, "layer:", layer) + xyPixLeftTop = tuple((array(rcTopLeft) * self.nPixCell)[[1, 0]]) + self.text(*xyPixLeftTop, strText, layer) def prettify(self, *args, **kwargs): pass @@ -263,9 +271,9 @@ class PILGL(GraphicsLayer): class PILSVG(PILGL): - def __init__(self, width, height, jupyter=False): + def __init__(self, width, height, jupyter=False, screen_width=800, screen_height=600): oSuper = super() - oSuper.__init__(width, height, jupyter) + oSuper.__init__(width, height, jupyter, screen_width, screen_height) self.lwAgents = [] self.agents_prev = [] @@ -444,7 +452,7 @@ class PILSVG(PILGL): for transition, file in file_directory.items(): - # Translate the ascii transition description in the format "NE WS" to the + # Translate the ascii transition description in the format "NE WS" to the # binary list of transitions as per RailEnv - NESW (in) x NESW (out) transition_16_bit = ["0"] * 16 for sTran in transition.split(" "): @@ -492,13 +500,17 @@ class PILSVG(PILGL): False)[0] self.draw_image_row_col(colored_rail, (row, col), layer=PILGL.PREDICTION_PATH_LAYER) - def set_rail_at(self, row, col, binary_trans, target=None, is_selected=False, rail_grid=None): + def set_rail_at(self, row, col, binary_trans, target=None, is_selected=False, rail_grid=None, + show_debug=True): + if binary_trans in self.pil_rail: pil_track = self.pil_rail[binary_trans] if target is not None: target_img = self.station_colors[target % len(self.station_colors)] target_img = Image.alpha_composite(pil_track, target_img) self.draw_image_row_col(target_img, (row, col), layer=PILGL.TARGET_LAYER) + if show_debug: + self.text_rowcol((row + 0.8, col + 0.0), strText=str(target), layer=PILGL.TARGET_LAYER) if binary_trans == 0: if self.background_grid[col][row] <= 4: @@ -579,7 +591,7 @@ class PILSVG(PILGL): for color_idx, pil_zug_3 in enumerate(pils): self.pil_zug[(in_direction_2, out_direction_2, color_idx)] = pils[color_idx] - def set_agent_at(self, agent_idx, row, col, in_direction, out_direction, is_selected): + def set_agent_at(self, agent_idx, row, col, in_direction, out_direction, is_selected, show_debug=False): delta_dir = (out_direction - in_direction) % 4 color_idx = agent_idx % self.n_agent_colors # when flipping direction at a dead end, use the "out_direction" direction. @@ -593,6 +605,10 @@ class PILSVG(PILGL): self.clear_layer(PILGL.SELECTED_AGENT_LAYER, 0) self.draw_image_row_col(bg_svg, (row, col), layer=PILGL.SELECTED_AGENT_LAYER) + if show_debug: + print("Call text:") + self.text_rowcol((row + 0.2, col + 0.2,), str(agent_idx)) + def set_cell_occupied(self, agent_idx, row, col): occupied_im = self.cell_occupied[agent_idx % len(self.cell_occupied)] self.draw_image_row_col(occupied_im, (row, col), 1) diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py index 8974126ad69e35edbd621ba634727120720c9506..802b361b623cdaea08271f5748ac86194056bdf2 100644 --- a/flatland/utils/rendertools.py +++ b/flatland/utils/rendertools.py @@ -39,7 +39,10 @@ class RenderTool(object): theta = np.linspace(0, np.pi / 2, 5) arc = array([np.cos(theta), np.sin(theta)]).T # from [1,0] to [0,1] - def __init__(self, env, gl="PILSVG", jupyter=False, agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND): + def __init__(self, env, gl="PILSVG", jupyter=False, + agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND, + show_debug=False, screen_width=800, screen_height=600): + self.env = env self.frame_nr = 0 self.start_time = time.time() @@ -48,14 +51,15 @@ class RenderTool(object): self.agent_render_variant = agent_render_variant if gl == "PIL": - self.gl = PILGL(env.width, env.height, jupyter) + self.gl = PILGL(env.width, env.height, jupyter, screen_width=screen_width, screen_height=screen_height) elif gl == "PILSVG": - self.gl = PILSVG(env.width, env.height, jupyter) + self.gl = PILSVG(env.width, env.height, jupyter, screen_width=screen_width, screen_height=screen_height) else: print("[", gl, "] not found, switch to PILSVG") - self.gl = PILSVG(env.width, env.height, jupyter) + self.gl = PILSVG(env.width, env.height, jupyter, screen_width=screen_width, screen_height=screen_height) self.new_rail = True + self.show_debug = show_debug self.update_background() def reset(self): @@ -282,7 +286,7 @@ class RenderTool(object): if len(observation_dict) < 1: warnings.warn( "Predictor did not provide any predicted cells to render. \ - Observaiton builder needs to populate: env.dev_obs_dict") + Observation builder needs to populate: env.dev_obs_dict") else: for agent in agent_handles: color = self.gl.get_agent_color(agent) @@ -525,7 +529,7 @@ class RenderTool(object): is_selected = False self.gl.set_rail_at(r, c, transitions, target=target, is_selected=is_selected, - rail_grid=env.rail.grid) + rail_grid=env.rail.grid, show_debug=self.show_debug) self.gl.build_background_map(targets) @@ -550,7 +554,8 @@ class RenderTool(object): # set_agent_at uses the agent index for the color if self.agent_render_variant == AgentRenderVariant.ONE_STEP_BEHIND_AND_BOX: self.gl.set_cell_occupied(agent_idx, *(agent.position)) - self.gl.set_agent_at(agent_idx, *position, old_direction, direction, selected_agent == agent_idx) + self.gl.set_agent_at(agent_idx, *position, old_direction, direction, + selected_agent == agent_idx, show_debug=self.show_debug) else: position = agent.position direction = agent.direction @@ -562,7 +567,7 @@ class RenderTool(object): # set_agent_at uses the agent index for the color self.gl.set_agent_at(agent_idx, *position, agent.direction, direction, - selected_agent == agent_idx) + selected_agent == agent_idx, show_debug=self.show_debug) # set_agent_at uses the agent index for the color if self.agent_render_variant == AgentRenderVariant.AGENT_SHOWS_OPTIONS_AND_BOX: diff --git a/flatland/utils/simple_rail.py b/flatland/utils/simple_rail.py index 28978ca3e5b958a51c578f6be5e8c87b77baaa97..67bd93dd35c8f53ef3cdef23dbae0f0d785b9a64 100644 --- a/flatland/utils/simple_rail.py +++ b/flatland/utils/simple_rail.py @@ -2,11 +2,124 @@ from typing import Tuple import numpy as np -from flatland.core.grid.grid4 import Grid4Transitions +from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.core.transition_map import GridTransitionMap def make_simple_rail() -> Tuple[GridTransitionMap, np.array]: + # We instantiate a very simple rail network on a 7x10 grid: + # Note that that cells have invalid RailEnvTransitions! + # | + # | + # | + # _ _ _ _\ _ _ _ _ _ _ + # / + # | + # | + # | + transitions = RailEnvTransitions() + cells = transitions.transition_list + empty = cells[0] + dead_end_from_south = cells[7] + dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90) + dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180) + dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270) + vertical_straight = cells[1] + horizontal_straight = transitions.rotate_transition(vertical_straight, 90) + simple_switch_north_left = cells[2] + simple_switch_north_right = cells[10] + simple_switch_east_west_north = transitions.rotate_transition(simple_switch_north_right, 270) + simple_switch_east_west_south = transitions.rotate_transition(simple_switch_north_left, 270) + rail_map = np.array( + [[empty] * 3 + [dead_end_from_south] + [empty] * 6] + + [[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 + + [[dead_end_from_east] + [horizontal_straight] * 2 + + [simple_switch_east_west_north] + + [horizontal_straight] * 2 + [simple_switch_east_west_south] + + [horizontal_straight] * 2 + [dead_end_from_west]] + + [[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 + + [[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16) + rail = GridTransitionMap(width=rail_map.shape[1], + height=rail_map.shape[0], transitions=transitions) + rail.grid = rail_map + return rail, rail_map + + +def make_simple_rail2() -> Tuple[GridTransitionMap, np.array]: + # We instantiate a very simple rail network on a 7x10 grid: + # | + # | + # | + # _ _ _ _\ _ _ _ _ _ _ + # \ + # | + # | + # | + transitions = RailEnvTransitions() + cells = transitions.transition_list + empty = cells[0] + dead_end_from_south = cells[7] + dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90) + dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180) + dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270) + vertical_straight = cells[1] + horizontal_straight = transitions.rotate_transition(vertical_straight, 90) + simple_switch_north_right = cells[10] + simple_switch_east_west_north = transitions.rotate_transition(simple_switch_north_right, 270) + simple_switch_west_east_south = transitions.rotate_transition(simple_switch_north_right, 90) + rail_map = np.array( + [[empty] * 3 + [dead_end_from_south] + [empty] * 6] + + [[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 + + [[dead_end_from_east] + [horizontal_straight] * 2 + + [simple_switch_east_west_north] + + [horizontal_straight] * 2 + [simple_switch_west_east_south] + + [horizontal_straight] * 2 + [dead_end_from_west]] + + [[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 + + [[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16) + rail = GridTransitionMap(width=rail_map.shape[1], + height=rail_map.shape[0], transitions=transitions) + rail.grid = rail_map + return rail, rail_map + +def make_simple_rail_unconnected() -> Tuple[GridTransitionMap, np.array]: + # We instantiate a very simple rail network on a 7x10 grid: + # Note that that cells have invalid RailEnvTransitions! + # | + # | + # | + # _ _ _ _ _ _ _ _ _ _ + # / + # | + # | + # | + transitions = RailEnvTransitions() + cells = transitions.transition_list + empty = cells[0] + dead_end_from_south = cells[7] + dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90) + dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180) + dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270) + vertical_straight = cells[1] + horizontal_straight = transitions.rotate_transition(vertical_straight, 90) + simple_switch_north_left = cells[2] + simple_switch_north_right = cells[10] + simple_switch_east_west_north = transitions.rotate_transition(simple_switch_north_right, 270) + simple_switch_east_west_south = transitions.rotate_transition(simple_switch_north_left, 270) + rail_map = np.array( + [[empty] * 3 + [dead_end_from_south] + [empty] * 6] + + [[empty] * 3 + [vertical_straight] + [empty] * 6] + + [[empty] * 3 + [dead_end_from_north] + [empty] * 6] + + [[dead_end_from_east] + [horizontal_straight] * 5 + [simple_switch_east_west_south] + + [horizontal_straight] * 2 + [dead_end_from_west]] + + [[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 + + [[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16) + rail = GridTransitionMap(width=rail_map.shape[1], + height=rail_map.shape[0], transitions=transitions) + rail.grid = rail_map + return rail, rail_map + + +def make_invalid_simple_rail() -> Tuple[GridTransitionMap, np.array]: # We instantiate a very simple rail network on a 7x10 grid: # | # | @@ -16,15 +129,9 @@ def make_simple_rail() -> Tuple[GridTransitionMap, np.array]: # | # | # | - cells = [int('0000000000000000', 2), # empty cell - Case 0 - int('1000000000100000', 2), # Case 1 - straight - int('1001001000100000', 2), # Case 2 - simple switch - int('1000010000100001', 2), # Case 3 - diamond drossing - int('1001011000100001', 2), # Case 4 - single slip switch - int('1100110000110011', 2), # Case 5 - double slip switch - int('0101001000000010', 2), # Case 6 - symmetrical switch - int('0010000000000000', 2)] # Case 7 - dead end - transitions = Grid4Transitions([]) + transitions = RailEnvTransitions() + cells = transitions.transition_list + empty = cells[0] dead_end_from_south = cells[7] dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90) diff --git a/notebooks/simple_example1_env_from_tuple.ipynb b/notebooks/simple_example1_env_from_tuple.ipynb index 3fd55bc8fafabdd57eb43e012c5d98b37c73d496..0fcfe26325e5778cbbfdf66591346972edb2406d 100644 --- a/notebooks/simple_example1_env_from_tuple.ipynb +++ b/notebooks/simple_example1_env_from_tuple.ipynb @@ -14,7 +14,7 @@ "metadata": {}, "outputs": [], "source": [ - "from flatland.envs.generators import rail_from_manual_specifications_generator\n", + "from flatland.envs.rail_generators import rail_from_manual_specifications_generator\n", "from flatland.envs.observations import TreeObsForRailEnv\n", "from flatland.envs.rail_env import RailEnv\n", "from flatland.utils.rendertools import RenderTool\n", diff --git a/notebooks/simple_example2_generate_random_rail.ipynb b/notebooks/simple_example2_generate_random_rail.ipynb index b9d4a96c02d4cb63392654025c6857c8b6764d1a..19b854ee15d8dd1e19361f58a552eba617b19b67 100644 --- a/notebooks/simple_example2_generate_random_rail.ipynb +++ b/notebooks/simple_example2_generate_random_rail.ipynb @@ -15,7 +15,7 @@ "source": [ "import random\n", "import numpy as np\n", - "from flatland.envs.generators import random_rail_generator\n", + "from flatland.envs.rail_generators import random_rail_generator\n", "from flatland.envs.observations import TreeObsForRailEnv\n", "from flatland.envs.rail_env import RailEnv\n", "from flatland.utils.rendertools import RenderTool\n", diff --git a/notebooks/simple_example_3_manual_control.ipynb b/notebooks/simple_example_3_manual_control.ipynb index 50f228055b320c15e0411c2b086254a1f4d4ceef..cb2b377765f375e8d77d8374fdcc3bb67ce06444 100644 --- a/notebooks/simple_example_3_manual_control.ipynb +++ b/notebooks/simple_example_3_manual_control.ipynb @@ -40,7 +40,7 @@ "import random\n", "import numpy as np\n", "import time\n", - "from flatland.envs.generators import random_rail_generator\n", + "from flatland.envs.rail_generators import random_rail_generator\n", "from flatland.envs.observations import TreeObsForRailEnv\n", "from flatland.envs.rail_env import RailEnv\n", "from flatland.utils.rendertools import RenderTool" diff --git a/tests/test_distance_map.py b/tests/test_distance_map.py index 12e0c092a37a475ab6e7dde21c665778e06f5e59..e5e89f76428bb881d0f72aa60aada97ab02167a5 100644 --- a/tests/test_distance_map.py +++ b/tests/test_distance_map.py @@ -1,25 +1,19 @@ import numpy as np -from flatland.core.grid.grid4 import Grid4Transitions +from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.core.transition_map import GridTransitionMap -from flatland.envs.generators import rail_from_grid_transition_map from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators import rail_from_grid_transition_map +from flatland.envs.schedule_generators import random_schedule_generator def test_walker(): # _ _ _ - cells = [int('0000000000000000', 2), # empty cell - Case 0 - int('1000000000100000', 2), # Case 1 - straight - int('1001001000100000', 2), # Case 2 - simple switch - int('1000010000100001', 2), # Case 3 - diamond drossing - int('1001011000100001', 2), # Case 4 - single slip switch - int('1100110000110011', 2), # Case 5 - double slip switch - int('0101001000000010', 2), # Case 6 - symmetrical switch - int('0010000000000000', 2)] # Case 7 - dead end - transitions = Grid4Transitions([]) + transitions = RailEnvTransitions() + cells = transitions.transition_list dead_end_from_south = cells[7] dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90) dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270) @@ -34,6 +28,7 @@ def test_walker(): env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv(max_depth=10)), diff --git a/tests/test_flatland_core_transition_map.py b/tests/test_flatland_core_transition_map.py index a414231619cfa924c2d33776f9f140cf88280517..8812c847e61d81f6614f37d26489b4c17ea7fd14 100644 --- a/tests/test_flatland_core_transition_map.py +++ b/tests/test_flatland_core_transition_map.py @@ -1,6 +1,13 @@ from flatland.core.grid.grid4 import Grid4Transitions, Grid4TransitionsEnum from flatland.core.grid.grid8 import Grid8Transitions, Grid8TransitionsEnum from flatland.core.transition_map import GridTransitionMap +from flatland.envs.observations import TreeObsForRailEnv +from flatland.envs.predictions import ShortestPathPredictorForRailEnv +from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators import rail_from_grid_transition_map +from flatland.envs.schedule_generators import random_schedule_generator +from flatland.utils.rendertools import RenderTool +from flatland.utils.simple_rail import make_simple_rail, make_simple_rail_unconnected def test_grid4_get_transitions(): @@ -43,4 +50,111 @@ def test_grid8_set_transitions(): grid8_map.set_transition((0, 0, Grid8TransitionsEnum.NORTH), Grid8TransitionsEnum.NORTH, 0) assert grid8_map.get_transitions(0, 0, Grid8TransitionsEnum.NORTH) == (0, 0, 0, 0, 0, 0, 0, 0) -# TODO GridTransitionMap + +def check_path(env, rail, position, direction, target, expected, rendering=False): + agent = env.agents_static[0] + agent.position = position # south dead-end + agent.direction = direction # north + agent.target = target # east dead-end + agent.moving = True + # reset to set agents from agents_static + # env.reset(False, False) + if rendering: + renderer = RenderTool(env, gl="PILSVG") + renderer.render_env(show=True, show_observations=False) + input("Continue?") + assert rail.check_path_exists(agent.position, agent.direction, agent.target) == expected + + +def test_path_exists(rendering=False): + rail, rail_map = make_simple_rail() + env = RailEnv(width=rail_map.shape[1], + height=rail_map.shape[0], + rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), + number_of_agents=1, + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), + ) + + # reset to initialize agents_static + env.reset() + + check_path( + env, + rail, + (5, 6), # north of south dead-end + 0, # north + (3, 9), # east dead-end + True + ) + + check_path( + env, + rail, + (6, 6), # south dead-end + 2, # south + (3, 9), # east dead-end + True + ) + + check_path( + env, + rail, + (3, 0), # east dead-end + 3, # west + (0, 3), # north dead-end + True + ) + check_path( + env, + rail, + (5, 6), # east dead-end + 0, # west + (1, 3), # north dead-end + True) + + check_path( + env, + rail, + (1,3), # east dead-end + 2, # south + (3,3), # north dead-end + True + ) + + check_path( + env, + rail, + (1,3), # east dead-end + 0, # north + (3,3), # north dead-end + True + ) + + +def test_path_not_exists(rendering=False): + rail, rail_map = make_simple_rail_unconnected() + env = RailEnv(width=rail_map.shape[1], + height=rail_map.shape[0], + rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), + number_of_agents=1, + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), + ) + + # reset to initialize agents_static + env.reset() + + check_path( + env, + rail, + (5, 6), # south dead-end + 0, # north + (0, 3), # north dead-end + False + ) + + if rendering: + renderer = RenderTool(env, gl="PILSVG") + renderer.render_env(show=True, show_observations=False) + input("Continue?") diff --git a/tests/test_flatland_envs_observations.py b/tests/test_flatland_envs_observations.py index 574705c49501415f37149bd4b3d870665bf06e60..c96e8db00fe721f42667aed4833d034a47f19156 100644 --- a/tests/test_flatland_envs_observations.py +++ b/tests/test_flatland_envs_observations.py @@ -5,10 +5,11 @@ import numpy as np from flatland.core.grid.grid4 import Grid4TransitionsEnum from flatland.envs.agent_utils import EnvAgent -from flatland.envs.generators import rail_from_grid_transition_map from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv, RailEnvActions +from flatland.envs.rail_generators import rail_from_grid_transition_map +from flatland.envs.schedule_generators import random_schedule_generator from flatland.utils.rendertools import RenderTool from flatland.utils.simple_rail import make_simple_rail @@ -21,6 +22,7 @@ def test_global_obs(): env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), number_of_agents=1, obs_builder_object=GlobalObsForRailEnv()) @@ -90,6 +92,7 @@ def test_reward_function_conflict(rendering=False): env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), number_of_agents=2, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), ) @@ -168,6 +171,7 @@ def test_reward_function_waiting(rendering=False): env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), number_of_agents=2, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), ) diff --git a/tests/test_flatland_envs_predictions.py b/tests/test_flatland_envs_predictions.py index 7acd58ed0337745f645db6dcc24a70ecb0b64305..09f7e5e67a15c55b5070ac8679e43ecc9a14b9da 100644 --- a/tests/test_flatland_envs_predictions.py +++ b/tests/test_flatland_envs_predictions.py @@ -5,22 +5,24 @@ import pprint import numpy as np from flatland.core.grid.grid4 import Grid4TransitionsEnum -from flatland.envs.generators import rail_from_grid_transition_map from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators import rail_from_grid_transition_map +from flatland.envs.schedule_generators import random_schedule_generator from flatland.utils.rendertools import RenderTool -from flatland.utils.simple_rail import make_simple_rail +from flatland.utils.simple_rail import make_simple_rail, make_simple_rail2, make_invalid_simple_rail """Test predictions for `flatland` package.""" def test_dummy_predictor(rendering=False): - rail, rail_map = make_simple_rail() + rail, rail_map = make_simple_rail2() env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv(max_depth=10)), ) @@ -89,7 +91,7 @@ def test_dummy_predictor(rendering=False): expected_actions = np.array([[0.], [2.], [2.], - [1.], + [2.], [2.], [2.], [2.], @@ -111,6 +113,7 @@ def test_shortest_path_predictor(rendering=False): env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), ) @@ -226,10 +229,11 @@ def test_shortest_path_predictor(rendering=False): def test_shortest_path_predictor_conflicts(rendering=False): - rail, rail_map = make_simple_rail() + rail, rail_map = make_invalid_simple_rail() env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), number_of_agents=2, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), ) diff --git a/tests/test_flatland_envs_rail_env.py b/tests/test_flatland_envs_rail_env.py index 71dc87ceddde986be763491d28dd2b70673632f4..d5dc3ac7af4be6ebd8c5cbeaf705bb710d36d138 100644 --- a/tests/test_flatland_envs_rail_env.py +++ b/tests/test_flatland_envs_rail_env.py @@ -2,15 +2,15 @@ # -*- coding: utf-8 -*- import numpy as np -from flatland.core.grid.grid4 import Grid4Transitions from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.core.transition_map import GridTransitionMap from flatland.envs.agent_utils import EnvAgent from flatland.envs.agent_utils import EnvAgentStatic -from flatland.envs.generators import complex_rail_generator -from flatland.envs.generators import rail_from_grid_transition_map from flatland.envs.observations import GlobalObsForRailEnv from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators import complex_rail_generator +from flatland.envs.rail_generators import rail_from_grid_transition_map +from flatland.envs.schedule_generators import random_schedule_generator, complex_schedule_generator """Tests for `flatland` package.""" @@ -27,6 +27,7 @@ def test_load_env(): def test_save_load(): env = RailEnv(width=10, height=10, rail_generator=complex_rail_generator(nr_start_goal=2, nr_extra=5, min_dist=6, seed=0), + schedule_generator=complex_schedule_generator(), number_of_agents=2) env.reset() agent_1_pos = env.agents_static[0].position @@ -49,15 +50,6 @@ def test_save_load(): def test_rail_environment_single_agent(): - cells = [int('0000000000000000', 2), # empty cell - Case 0 - int('1000000000100000', 2), # Case 1 - straight - int('1001001000100000', 2), # Case 2 - simple switch - int('1000010000100001', 2), # Case 3 - diamond drossing - int('1001011000100001', 2), # Case 4 - single slip switch - int('1100110000110011', 2), # Case 5 - double slip switch - int('0101001000000010', 2), # Case 6 - symmetrical switch - int('0010000000000000', 2)] # Case 7 - dead end - # We instantiate the following map on a 3x3 grid # _ _ # / \/ \ @@ -65,6 +57,7 @@ def test_rail_environment_single_agent(): # \_/\_/ transitions = RailEnvTransitions() + cells = transitions.transition_list vertical_line = cells[1] south_symmetrical_switch = cells[6] north_symmetrical_switch = transitions.rotate_transition(south_symmetrical_switch, 180) @@ -86,6 +79,7 @@ def test_rail_environment_single_agent(): rail_env = RailEnv(width=3, height=3, rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), number_of_agents=1, obs_builder_object=GlobalObsForRailEnv()) @@ -139,7 +133,7 @@ test_rail_environment_single_agent() def test_dead_end(): - transitions = Grid4Transitions([]) + transitions = RailEnvTransitions() straight_vertical = int('1000000000100000', 2) # Case 1 - straight straight_horizontal = transitions.rotate_transition(straight_vertical, @@ -165,6 +159,7 @@ def test_dead_end(): rail_env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), number_of_agents=1, obs_builder_object=GlobalObsForRailEnv()) @@ -209,6 +204,7 @@ def test_dead_end(): rail_env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), number_of_agents=1, obs_builder_object=GlobalObsForRailEnv()) diff --git a/tests/test_flatland_envs_sparse_rail_generator.py b/tests/test_flatland_envs_sparse_rail_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..db7cac61f4cf3bec4a330694c1864ef7d82bd076 --- /dev/null +++ b/tests/test_flatland_envs_sparse_rail_generator.py @@ -0,0 +1,26 @@ +from flatland.envs.observations import GlobalObsForRailEnv +from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators import sparse_rail_generator +from flatland.envs.schedule_generators import sparse_schedule_generator +from flatland.utils.rendertools import RenderTool + + +def test_sparse_rail_generator(): + env = RailEnv(width=50, + height=50, + rail_generator=sparse_rail_generator(num_cities=10, # Number of cities in map + num_intersections=10, # Number of interesections in map + num_trainstations=50, # Number of possible start/targets on map + min_node_dist=6, # Minimal distance of nodes + node_radius=3, # Proximity of stations to city center + num_neighb=3, # Number of connections to other cities + seed=5, # Random seed + realistic_mode=False # Ordered distribution of nodes + ), + schedule_generator=sparse_schedule_generator(), + number_of_agents=10, + obs_builder_object=GlobalObsForRailEnv()) + # reset to initialize agents_static + env_renderer = RenderTool(env, gl="PILSVG", ) + env_renderer.render_env(show=True, show_observations=True, show_predictions=False) + env_renderer.gl.save_image("./sparse_generator_false.png") diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index 67dcd25c0769e542fd9a03502c2a8c1b29333b2b..eaf782df3255ecfc6ebaa7078935f485497ed359 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -1,8 +1,9 @@ import numpy as np -from flatland.envs.generators import complex_rail_generator from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators import complex_rail_generator +from flatland.envs.schedule_generators import complex_schedule_generator class SingleAgentNavigationObs(TreeObsForRailEnv): @@ -62,6 +63,7 @@ def test_malfunction_process(): height=20, rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999, seed=0), + schedule_generator=complex_schedule_generator(), number_of_agents=2, obs_builder_object=SingleAgentNavigationObs(), stochastic_data=stochastic_data) diff --git a/tests/test_flatland_utils_rendertools.py b/tests/test_flatland_utils_rendertools.py index ac5d7f4132b5c224206af3602b6c1341fe026d8b..8248c675995fc5c906e82d8650a5b619e7b038f2 100644 --- a/tests/test_flatland_utils_rendertools.py +++ b/tests/test_flatland_utils_rendertools.py @@ -11,9 +11,9 @@ from importlib_resources import path import flatland.utils.rendertools as rt import images.test -from flatland.envs.generators import empty_rail_generator from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators import empty_rail_generator def checkFrozenImage(oRT, sFileImage, resave=False): diff --git a/tests/test_multi_speed.py b/tests/test_multi_speed.py index 8b5468716fa12d83a0546a9b6ff34f2488beace5..8de36c81e4a13c0b7e7e5e556ad79234503ad31a 100644 --- a/tests/test_multi_speed.py +++ b/tests/test_multi_speed.py @@ -1,10 +1,12 @@ import numpy as np -from flatland.envs.generators import complex_rail_generator from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators import complex_rail_generator +from flatland.envs.schedule_generators import complex_schedule_generator np.random.seed(1) + # Use the complex_rail_generator to generate feasible network configurations with corresponding tasks # Training on simple small tasks is the best way to get familiar with the environment # @@ -46,6 +48,7 @@ def test_multi_speed_init(): height=50, rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0), + schedule_generator=complex_schedule_generator(), number_of_agents=5) # Initialize the agent with the parameters corresponding to the environment and observation_builder agent = RandomAgent(218, 4) @@ -66,7 +69,7 @@ def test_multi_speed_init(): # Run episode for step in range(100): - # Chose an action for each agent in the environment + # Choose an action for each agent in the environment for a in range(env.get_num_agents()): action = agent.act(0) action_dict.update({a: action}) @@ -75,12 +78,11 @@ def test_multi_speed_init(): assert old_pos[a] == env.agents[a].position # Environment step which returns the observations for all agents, their corresponding - # reward and whether their are done + # reward and whether they are done _, _, _, _ = env.step(action_dict) # Update old position whenever an agent was allowed to move for i_agent in range(env.get_num_agents()): if (step + 1) % (i_agent + 1) == 0: - print(step, i_agent, env.agents[a].position) - + print(step, i_agent, env.agents[i_agent].position) old_pos[i_agent] = env.agents[i_agent].position diff --git a/tests/test_speed_classes.py b/tests/test_speed_classes.py new file mode 100644 index 0000000000000000000000000000000000000000..ff5ee56a308ce19559d079b716bde90ad65baf11 --- /dev/null +++ b/tests/test_speed_classes.py @@ -0,0 +1,36 @@ +"""Test speed initialization by a map of speeds and their corresponding ratios.""" +import numpy as np + +from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators import complex_rail_generator +from flatland.envs.schedule_generators import speed_initialization_helper, complex_schedule_generator + + +def test_speed_initialization_helper(): + np.random.seed(1) + speed_ratio_map = {1: 0.3, 2: 0.4, 3: 0.3} + actual_speeds = speed_initialization_helper(10, speed_ratio_map) + + # seed makes speed_initialization_helper deterministic -> check generated speeds. + assert actual_speeds == [2, 3, 1, 2, 1, 1, 1, 2, 2, 2] + + +def test_rail_env_speed_intializer(): + speed_ratio_map = {1: 0.3, 2: 0.4, 3: 0.1, 5: 0.2} + + env = RailEnv(width=50, + height=50, + rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, + seed=0), + schedule_generator=complex_schedule_generator(), + number_of_agents=10) + env.reset() + actual_speeds = list(map(lambda agent: agent.speed_data['speed'], env.agents)) + + expected_speed_set = set(speed_ratio_map.keys()) + + # check that the number of speeds generated is correct + assert len(actual_speeds) == env.get_num_agents() + + # check that only the speeds defined are generated + assert all({(actual_speed in expected_speed_set) for actual_speed in actual_speeds}) diff --git a/tests/tests_generators.py b/tests/tests_generators.py index f97b071e6b33c099efa5af36766e159e57716443..8b0480c887a53ade155c28aa6199db3d32f19603 100644 --- a/tests/tests_generators.py +++ b/tests/tests_generators.py @@ -3,11 +3,13 @@ import numpy as np -from flatland.envs.generators import rail_from_grid_transition_map, rail_from_file, complex_rail_generator, \ - random_rail_generator, empty_rail_generator from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators import rail_from_grid_transition_map, rail_from_file, complex_rail_generator, \ + random_rail_generator, empty_rail_generator +from flatland.envs.schedule_generators import random_schedule_generator, complex_schedule_generator, \ + schedule_from_file from flatland.utils.simple_rail import make_simple_rail @@ -58,7 +60,8 @@ def test_complex_rail_generator(): env = RailEnv(width=x_dim, height=y_dim, number_of_agents=n_agents, - rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist) + rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist), + schedule_generator=complex_schedule_generator() ) assert env.get_num_agents() == 2 assert env.rail.grid.shape == (y_dim, x_dim) @@ -69,7 +72,8 @@ def test_complex_rail_generator(): env = RailEnv(width=x_dim, height=y_dim, number_of_agents=n_agents, - rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist) + rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist), + schedule_generator=complex_schedule_generator() ) assert env.get_num_agents() == 0 assert env.rail.grid.shape == (y_dim, x_dim) @@ -82,7 +86,8 @@ def test_complex_rail_generator(): env = RailEnv(width=x_dim, height=y_dim, number_of_agents=n_agents, - rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist) + rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist), + schedule_generator=complex_schedule_generator() ) assert env.get_num_agents() == n_agents assert env.rail.grid.shape == (y_dim, x_dim) @@ -94,6 +99,7 @@ def test_rail_from_grid_transition_map(): env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), number_of_agents=n_agents ) nr_rail_elements = np.count_nonzero(env.rail.grid) @@ -118,6 +124,7 @@ def tests_rail_from_file(): env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), number_of_agents=3, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), ) @@ -130,6 +137,7 @@ def tests_rail_from_file(): env = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name), + schedule_generator=schedule_from_file(file_name), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), ) @@ -151,6 +159,7 @@ def tests_rail_from_file(): env2 = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), number_of_agents=3, obs_builder_object=GlobalObsForRailEnv(), ) @@ -164,6 +173,7 @@ def tests_rail_from_file(): env2 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name_2), + schedule_generator=schedule_from_file(file_name_2), number_of_agents=1, obs_builder_object=GlobalObsForRailEnv(), ) @@ -180,6 +190,7 @@ def tests_rail_from_file(): env3 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name), + schedule_generator=schedule_from_file(file_name), number_of_agents=1, obs_builder_object=GlobalObsForRailEnv(), ) @@ -197,6 +208,7 @@ def tests_rail_from_file(): env4 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name_2), + schedule_generator=schedule_from_file(file_name_2), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2), )