diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py new file mode 100644 index 0000000000000000000000000000000000000000..916e50b20b10a02c43c5b1da8bc0728930b8c535 --- /dev/null +++ b/examples/flatland_2_0_example.py @@ -0,0 +1,119 @@ +import numpy as np + +from flatland.envs.generators import sparse_rail_generator +from flatland.envs.observations import TreeObsForRailEnv +from flatland.envs.predictions import ShortestPathPredictorForRailEnv +from flatland.envs.rail_env import RailEnv +from flatland.utils.rendertools import RenderTool + +np.random.seed(1) + +# Use the new sparse_rail_generator to generate feasible network configurations with corresponding tasks +# Training on simple small tasks is the best way to get familiar with the environment + +# Use a the malfunction generator to break agents from time to time +stochastic_data = {'prop_malfunction': 0.5, # Percentage of defective agents + 'malfunction_rate': 30, # Rate of malfunction occurence + 'min_duration': 3, # Minimal duration of malfunction + 'max_duration': 10 # Max duration of malfunction + } + +TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()) +env = RailEnv(width=20, + height=20, + rail_generator=sparse_rail_generator(num_cities=2, # Number of cities in map (where train stations are) + num_intersections=1, # Number of intersections (no start / target) + num_trainstations=15, # Number of possible start/targets on map + min_node_dist=3, # Minimal distance of nodes + node_radius=3, # Proximity of stations to city center + num_neighb=2, # Number of connections to other cities/intersections + seed=15, # Random seed + realistic_mode=True, + enhance_intersection=True + ), + number_of_agents=5, + stochastic_data=stochastic_data, # Malfunction data generator + obs_builder_object=TreeObservation) + +env_renderer = RenderTool(env, gl="PILSVG", ) + + +# Import your own Agent or use RLlib to train agents on Flatland +# As an example we use a random agent instead +class RandomAgent: + + def __init__(self, state_size, action_size): + self.state_size = state_size + self.action_size = action_size + + def act(self, state): + """ + :param state: input is the observation of the agent + :return: returns an action + """ + return np.random.choice(np.arange(self.action_size)) + + def step(self, memories): + """ + Step function to improve agent by adjusting policy given the observations + + :param memories: SARS Tuple to be + :return: + """ + return + + def save(self, filename): + # Store the current policy + return + + def load(self, filename): + # Load a policy + return + + +# Initialize the agent with the parameters corresponding to the environment and observation_builder +# Set action space to 4 to remove stop action +agent = RandomAgent(218, 4) + + +# Empty dictionary for all agent action +action_dict = dict() + +print("Start episode...") +# Reset environment and get initial observations for all agents +obs = env.reset() +# Update/Set agent's speed +for idx in range(env.get_num_agents()): + speed = 1.0 / ((idx % 5) + 1.0) + env.agents[idx].speed_data["speed"] = speed + +# Reset the rendering sytem +env_renderer.reset() + +# Here you can also further enhance the provided observation by means of normalization +# See training navigation example in the baseline repository + +score = 0 +# Run episode +frame_step = 0 +for step in range(500): + # Chose an action for each agent in the environment + for a in range(env.get_num_agents()): + action = agent.act(obs[a]) + action_dict.update({a: action}) + + # Environment step which returns the observations for all agents, their corresponding + # reward and whether their are done + next_obs, all_rewards, done, _ = env.step(action_dict) + env_renderer.render_env(show=True, show_observations=False, show_predictions=False) + frame_step += 1 + # Update replay buffer and train agent + for a in range(env.get_num_agents()): + agent.step((obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a])) + score += all_rewards[a] + + obs = next_obs.copy() + if done['__all__']: + break + +print('Episode: Steps {}\t Score = {}'.format(step, score)) diff --git a/flatland/core/transition_map.py b/flatland/core/transition_map.py index 5e0f6cd72e8ca22c80e3576798e5214f8b036558..7a673bcf9ba46c574db0983e3c52257e4a07358e 100644 --- a/flatland/core/transition_map.py +++ b/flatland/core/transition_map.py @@ -350,4 +350,124 @@ class GridTransitionMap(TransitionMap): return True + def fix_neighbours(self, rcPos, check_this_cell=False): + """ + Check validity of cell at rcPos = tuple(row, column) + Checks that: + - surrounding cells have inbound transitions for all the + outbound transitions of this cell. + + These are NOT checked - see transition.is_valid: + - all transitions have the mirror transitions (N->E <=> W->S) + - Reverse transitions (N -> S) only exist for a dead-end + - a cell contains either no dead-ends or exactly one + + Returns: True (valid) or False (invalid) + """ + cell_transition = self.grid[tuple(rcPos)] + + if check_this_cell: + if not self.transitions.is_valid(cell_transition): + return False + + gDir2dRC = self.transitions.gDir2dRC # [[-1,0] = N, [0,1]=E, etc] + grcPos = array(rcPos) + grcMax = self.grid.shape + + binTrans = self.get_full_transitions(*rcPos) # 16bit integer - all trans in/out + lnBinTrans = array([binTrans >> 8, binTrans & 0xff], dtype=np.uint8) # 2 x uint8 + g2binTrans = np.unpackbits(lnBinTrans).reshape(4, 4) # 4x4 x uint8 binary(0,1) + gDirOut = g2binTrans.any(axis=0) # outbound directions as boolean array (4) + giDirOut = np.argwhere(gDirOut)[:, 0] # valid outbound directions as array of int + + # loop over available outbound directions (indices) for rcPos + for iDirOut in giDirOut: + gdRC = gDir2dRC[iDirOut] # row,col increment + gPos2 = grcPos + gdRC # next cell in that direction + + # Check the adjacent cell is within bounds + # if not, then this transition is invalid! + if np.any(gPos2 < 0): + return False + if np.any(gPos2 >= grcMax): + return False + + # Get the transitions out of gPos2, using iDirOut as the inbound direction + # if there are no available transitions, ie (0,0,0,0), then rcPos is invalid + t4Trans2 = self.get_transitions(*gPos2, iDirOut) + if any(t4Trans2): + continue + else: + self.set_transition((gPos2[0], gPos2[1], iDirOut), mirror(iDirOut), 1) + return False + + return True + + def fix_transitions(self, rcPos): + """ + Fixes broken transitions + """ + gDir2dRC = self.transitions.gDir2dRC # [[-1,0] = N, [0,1]=E, etc] + grcPos = array(rcPos) + grcMax = self.grid.shape + + # loop over available outbound directions (indices) for rcPos + self.set_transitions(rcPos, 0) + + incoming_connections = np.zeros(4) + for iDirOut in np.arange(4): + gdRC = gDir2dRC[iDirOut] # row,col increment + gPos2 = grcPos + gdRC # next cell in that direction + + # Check the adjacent cell is within bounds + # if not, then ignore it for the count of incoming connections + if np.any(gPos2 < 0): + continue + if np.any(gPos2 >= grcMax): + continue + + # Get the transitions out of gPos2, using iDirOut as the inbound direction + # if there are no available transitions, ie (0,0,0,0), then rcPos is invalid + connected = 0 + for orientation in range(4): + connected += self.get_transition((gPos2[0], gPos2[1], orientation), mirror(iDirOut)) + if connected > 0: + incoming_connections[iDirOut] = 1 + + number_of_incoming = np.sum(incoming_connections) + # Only one incoming direction --> Straight line + if number_of_incoming == 1: + for direction in range(4): + if incoming_connections[direction] > 0: + self.set_transition((rcPos[0], rcPos[1], mirror(direction)), direction, 1) + # Connect all incoming connections + if number_of_incoming == 2: + connect_directions = np.argwhere(incoming_connections > 0) + self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[0])), connect_directions[1], 1) + self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[1])), connect_directions[0], 1) + + # Find feasible connection fro three entries + if number_of_incoming == 3: + hole = np.argwhere(incoming_connections < 1)[0][0] + connect_directions = [(hole + 1) % 4, (hole + 2) % 4, (hole + 3) % 4] + self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[0])), connect_directions[1], 1) + self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[0])), connect_directions[2], 1) + self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[1])), connect_directions[0], 1) + self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[2])), connect_directions[0], 1) + # Make a cross + if number_of_incoming == 4: + connect_directions = np.arange(4) + self.set_transition((rcPos[0], rcPos[1], connect_directions[0]), connect_directions[0], 1) + self.set_transition((rcPos[0], rcPos[1], connect_directions[0]), connect_directions[1], 1) + self.set_transition((rcPos[0], rcPos[1], connect_directions[1]), connect_directions[0], 1) + self.set_transition((rcPos[0], rcPos[1], connect_directions[1]), connect_directions[1], 1) + self.set_transition((rcPos[0], rcPos[1], connect_directions[2]), connect_directions[2], 1) + self.set_transition((rcPos[0], rcPos[1], connect_directions[2]), connect_directions[3], 1) + self.set_transition((rcPos[0], rcPos[1], connect_directions[3]), connect_directions[2], 1) + self.set_transition((rcPos[0], rcPos[1], connect_directions[3]), connect_directions[3], 1) + return True + + +def mirror(dir): + return (dir + 2) % 4 # TODO: improvement override __getitem__ and __setitem__ (cell contents, not transitions?) diff --git a/flatland/envs/grid4_generators_utils.py b/flatland/envs/grid4_generators_utils.py index 8bbd0df73a9cb05f4b2794832101bff5442f5d99..d6046d6b8867988d40df30041241f17b685bc83a 100644 --- a/flatland/envs/grid4_generators_utils.py +++ b/flatland/envs/grid4_generators_utils.py @@ -53,3 +53,221 @@ def connect_rail(rail_trans, rail_array, start, end): current_dir = new_dir return path + + +def connect_nodes(rail_trans, rail_array, start, end): + """ + Creates a new path [start,end] in rail_array, based on rail_trans. + """ + # in the worst case we will need to do a A* search, so we might as well set that up + path = a_star(rail_trans, rail_array, start, end) + if len(path) < 2: + return [] + current_dir = get_direction(path[0], path[1]) + end_pos = path[-1] + for index in range(len(path) - 1): + current_pos = path[index] + new_pos = path[index + 1] + new_dir = get_direction(current_pos, new_pos) + + new_trans = rail_array[current_pos] + if index == 0: + if new_trans == 0: + # end-point + # don't set any transition at node yet + new_trans = 0 + else: + # into existing rail + new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) + else: + # set the forward path + new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) + # set the backwards path + new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1) + rail_array[current_pos] = new_trans + + if new_pos == end_pos: + # setup end pos setup + new_trans_e = rail_array[end_pos] + if new_trans_e == 0: + # end-point + # don't set any transition at node yet + + new_trans_e = 0 + else: + # into existing rail + new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1) + rail_array[end_pos] = new_trans_e + + current_dir = new_dir + return path + + +def connect_from_nodes(rail_trans, rail_array, start, end): + """ + Creates a new path [start,end] in rail_array, based on rail_trans. + """ + # in the worst case we will need to do a A* search, so we might as well set that up + path = a_star(rail_trans, rail_array, start, end) + if len(path) < 2: + return [] + current_dir = get_direction(path[0], path[1]) + end_pos = path[-1] + for index in range(len(path) - 1): + current_pos = path[index] + new_pos = path[index + 1] + new_dir = get_direction(current_pos, new_pos) + + new_trans = rail_array[current_pos] + if index == 0: + if new_trans == 0: + # end-point + # need to flip direction because of how end points are defined + new_trans = 0 + else: + # into existing rail + new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) + else: + # set the forward path + new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) + # set the backwards path + new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1) + rail_array[current_pos] = new_trans + + if new_pos == end_pos: + # setup end pos setup + new_trans_e = rail_array[end_pos] + if new_trans_e == 0: + # end-point + new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, mirror(new_dir), 1) + else: + # into existing rail + new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1) + rail_array[end_pos] = new_trans_e + + current_dir = new_dir + return path + + +def connect_to_nodes(rail_trans, rail_array, start, end): + """ + Creates a new path [start,end] in rail_array, based on rail_trans. + """ + # in the worst case we will need to do a A* search, so we might as well set that up + path = a_star(rail_trans, rail_array, start, end) + if len(path) < 2: + return [] + current_dir = get_direction(path[0], path[1]) + end_pos = path[-1] + for index in range(len(path) - 1): + current_pos = path[index] + new_pos = path[index + 1] + new_dir = get_direction(current_pos, new_pos) + + new_trans = rail_array[current_pos] + if index == 0: + if new_trans == 0: + # end-point + # need to flip direction because of how end points are defined + new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1) + else: + # into existing rail + new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) + else: + # set the forward path + new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) + # set the backwards path + new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1) + rail_array[current_pos] = new_trans + + if new_pos == end_pos: + # setup end pos setup + new_trans_e = rail_array[end_pos] + if new_trans_e == 0: + # end-point + new_trans_e = 0 + else: + # into existing rail + new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1) + rail_array[end_pos] = new_trans_e + + current_dir = new_dir + return path + + +def get_rnd_agents_pos_tgt_dir_on_rail(rail, num_agents): + """ + Given a `rail' GridTransitionMap, return a random placement of agents (initial position, direction and target). + + TODO: add extensive documentation, as users may need this function to simplify their custom level generators. + """ + + def _path_exists(rail, start, direction, end): + # BFS - Check if a path exists between the 2 nodes + + visited = set() + stack = [(start, direction)] + while stack: + node = stack.pop() + if node[0][0] == end[0] and node[0][1] == end[1]: + return 1 + if node not in visited: + visited.add(node) + moves = rail.get_transitions(node[0][0], node[0][1], node[1]) + for move_index in range(4): + if moves[move_index]: + stack.append((get_new_position(node[0], move_index), + move_index)) + + # If cell is a dead-end, append previous node with reversed + # orientation! + nbits = 0 + tmp = rail.get_full_transitions(node[0][0], node[0][1]) + while tmp > 0: + nbits += (tmp & 1) + tmp = tmp >> 1 + if nbits == 1: + stack.append((node[0], (node[1] + 2) % 4)) + + return 0 + + valid_positions = [] + for r in range(rail.height): + for c in range(rail.width): + if rail.get_full_transitions(r, c) > 0: + valid_positions.append((r, c)) + + re_generate = True + while re_generate: + agents_position = [ + valid_positions[i] for i in + np.random.choice(len(valid_positions), num_agents)] + agents_target = [ + valid_positions[i] for i in + np.random.choice(len(valid_positions), num_agents)] + + # agents_direction must be a direction for which a solution is + # guaranteed. + agents_direction = [0] * num_agents + re_generate = False + for i in range(num_agents): + valid_movements = [] + for direction in range(4): + position = agents_position[i] + moves = rail.get_transitions(position[0], position[1], direction) + for move_index in range(4): + if moves[move_index]: + valid_movements.append((direction, move_index)) + + valid_starting_directions = [] + for m in valid_movements: + new_position = get_new_position(agents_position[i], m[1]) + if m[0] not in valid_starting_directions and _path_exists(rail, new_position, m[0], agents_target[i]): + valid_starting_directions.append(m[0]) + + if len(valid_starting_directions) == 0: + re_generate = True + else: + agents_direction[i] = valid_starting_directions[np.random.choice(len(valid_starting_directions), 1)[0]] + + return agents_position, agents_direction, agents_target diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 322bbf096dd1e88e7edab808624c861c6cb408d0..df0ee1f7f91c80fe160120eb527b40196c15d870 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -2,7 +2,7 @@ Definition of the RailEnv environment. """ # TODO: _ this is a global method --> utils or remove later - +import warnings from enum import IntEnum import msgpack @@ -11,6 +11,7 @@ import numpy as np from flatland.core.env import Environment from flatland.core.grid.grid4_utils import get_new_position +from flatland.core.transition_map import GridTransitionMap from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.rail_generators import random_rail_generator, RailGenerator @@ -134,7 +135,8 @@ class RailEnv(Environment): self.rail_generator: RailGenerator = rail_generator self.agent_generator: ScheduleGenerator = agent_generator - self.rail = None + self.rail_generator = rail_generator + self.rail: GridTransitionMap = None self.width = width self.height = height @@ -223,6 +225,12 @@ class RailEnv(Environment): if regen_rail or self.rail is None: self.rail = rail self.height, self.width = self.rail.grid.shape + for r in range(self.height): + for c in range(self.width): + rcPos = (r, c) + check = self.rail.cell_neighbours_valid(rcPos, True) + if not check: + warnings.warn("Invalid grid at {} -> {}".format(rcPos, check)) if replace_agents: agents_hints = None diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index 5e97f1588b38d5e054607d1ba07d7973c212a2f4..d338301c07a0edd82e4821a282af0fc18948f040 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -1,4 +1,5 @@ """Rail generators (infrastructure manager, "Infrastrukturbetreiber").""" +import warnings from typing import Callable, Tuple, Any, Optional import msgpack @@ -8,7 +9,7 @@ from flatland.core.grid.grid4_utils import get_direction, mirror from flatland.core.grid.grid_utils import distance_on_rail from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.core.transition_map import GridTransitionMap -from flatland.envs.grid4_generators_utils import connect_rail +from flatland.envs.grid4_generators_utils import connect_rail, connect_nodes, connect_from_nodes RailGeneratorProduct = Tuple[GridTransitionMap, Optional[Any]] RailGenerator = Callable[[int, int, int, int], RailGeneratorProduct] @@ -523,3 +524,310 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11) -> RailGener return return_rail, None return generator + + +def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2, min_node_dist=20, node_radius=2, + num_neighb=3, realistic_mode=False, enhance_intersection=False, seed=0): + """ + This is a level generator which generates complex sparse rail configurations + + :param num_cities: Number of city node (can hold trainstations) + :param num_intersections: Number of intersection that city nodes can connect to + :param num_trainstations: Total number of trainstations in env + :param min_node_dist: Minimal distance between nodes + :param node_radius: Proximity of trainstations to center of city node + :param num_neighb: Number of neighbouring nodes each node connects to + :param realistic_mode: True -> NOdes evenly distirbuted in env, False-> Random distribution of nodes + :param enhance_intersection: True -> Extra rail elements added at intersections + :param seed: Random Seed + :return: + ------- + numpy.ndarray of type numpy.uint16 + The matrix with the correct 16-bit bitmaps for each cell. + """ + + def generator(width, height, num_agents, num_resets=0): + + if num_agents > num_trainstations: + num_agents = num_trainstations + warnings.warn("sparse_rail_generator: num_agents > nr_start_goal, changing num_agents") + + rail_trans = RailEnvTransitions() + grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans) + rail_array = grid_map.grid + rail_array.fill(0) + np.random.seed(seed + num_resets) + + # Generate a set of nodes for the sparse network + # Try to connect cities to nodes first + node_positions = [] + city_positions = [] + intersection_positions = [] + + # Evenly distribute cities and intersections + if realistic_mode: + tot_num_node = num_intersections + num_cities + nodes_ratio = height / width + nodes_per_row = int(np.ceil(np.sqrt(tot_num_node * nodes_ratio))) + nodes_per_col = int(np.ceil(tot_num_node / nodes_per_row)) + x_positions = np.linspace(node_radius, height - node_radius, nodes_per_row, dtype=int) + y_positions = np.linspace(node_radius, width - node_radius, nodes_per_col, dtype=int) + + for node_idx in range(num_cities + num_intersections): + to_close = True + tries = 0 + if not realistic_mode: + while to_close: + x_tmp = node_radius + np.random.randint(height - node_radius) + y_tmp = node_radius + np.random.randint(width - node_radius) + to_close = False + + # Check distance to cities + for node_pos in city_positions: + if distance_on_rail((x_tmp, y_tmp), node_pos) < min_node_dist: + to_close = True + + # CHeck distance to intersections + for node_pos in intersection_positions: + if distance_on_rail((x_tmp, y_tmp), node_pos) < min_node_dist: + to_close = True + + if not to_close: + node_positions.append((x_tmp, y_tmp)) + if node_idx < num_cities: + city_positions.append((x_tmp, y_tmp)) + else: + intersection_positions.append((x_tmp, y_tmp)) + tries += 1 + if tries > 100: + warnings.warn("Could not set nodes, please change initial parameters!!!!") + break + else: + x_tmp = x_positions[node_idx % nodes_per_row] + y_tmp = y_positions[node_idx // nodes_per_row] + if len(city_positions) < num_cities and (node_idx % (tot_num_node // num_cities)) == 0: + city_positions.append((x_tmp, y_tmp)) + else: + intersection_positions.append((x_tmp, y_tmp)) + + node_positions = city_positions + intersection_positions + + # Chose node connection + # Set up list of available nodes to connect to + available_nodes_full = np.arange(num_cities + num_intersections) + available_cities = np.arange(num_cities) + available_intersections = np.arange(num_cities, num_cities + num_intersections) + + # Start at some node + current_node = np.random.randint(len(available_nodes_full)) + node_stack = [current_node] + allowed_connections = num_neighb + first_node = True + while len(node_stack) > 0: + current_node = node_stack[0] + delete_idx = np.where(available_nodes_full == current_node) + available_nodes_full = np.delete(available_nodes_full, delete_idx, 0) + # Priority city to intersection connections + if current_node < num_cities and len(available_intersections) > 0: + available_nodes = available_intersections + delete_idx = np.where(available_cities == current_node) + available_cities = np.delete(available_cities, delete_idx, 0) + + # Priority intersection to city connections + elif current_node >= num_cities and len(available_cities) > 0: + available_nodes = available_cities + delete_idx = np.where(available_intersections == current_node) + available_intersections = np.delete(available_intersections, delete_idx, 0) + + # If no options possible connect to whatever node is still available + else: + available_nodes = available_nodes_full + + # Sort available neighbors according to their distance. + node_dist = [] + for av_node in available_nodes: + node_dist.append(distance_on_rail(node_positions[current_node], node_positions[av_node])) + available_nodes = available_nodes[np.argsort(node_dist)] + + # Set number of neighboring nodes + if len(available_nodes) >= allowed_connections: + connected_neighb_idx = available_nodes[:allowed_connections] + else: + connected_neighb_idx = available_nodes + + # Less connections for subsequent nodes + if first_node: + allowed_connections -= 1 + first_node = False + + # Connect to the neighboring nodes + for neighb in connected_neighb_idx: + if neighb not in node_stack: + node_stack.append(neighb) + connect_nodes(rail_trans, rail_array, node_positions[current_node], node_positions[neighb]) + node_stack.pop(0) + + # Place train stations close to the node + # We currently place them uniformly distirbuted among all cities + if num_cities > 1: + train_stations = [[] for i in range(num_cities)] + built_num_trainstation = 0 + spot_found = True + for station in range(num_trainstations): + trainstation_node = int(station / num_trainstations * num_cities) + + station_x = np.clip(node_positions[trainstation_node][0] + np.random.randint(-node_radius, node_radius), + 0, + height - 1) + station_y = np.clip(node_positions[trainstation_node][1] + np.random.randint(-node_radius, node_radius), + 0, + width - 1) + tries = 0 + while (station_x, station_y) in train_stations or (station_x, station_y) == node_positions[ + trainstation_node] or rail_array[(station_x, station_y)] != 0: + station_x = np.clip( + node_positions[trainstation_node][0] + np.random.randint(-node_radius, node_radius), + 0, + height - 1) + station_y = np.clip( + node_positions[trainstation_node][1] + np.random.randint(-node_radius, node_radius), + 0, + width - 1) + tries += 1 + if tries > 100: + warnings.warn("Could not set trainstations, please change initial parameters!!!!") + spot_found = False + break + if spot_found: + train_stations[trainstation_node].append((station_x, station_y)) + + # Connect train station to the correct node + connection = connect_from_nodes(rail_trans, rail_array, node_positions[trainstation_node], + (station_x, station_y)) + # Check if connection was made + if len(connection) == 0: + train_stations[trainstation_node].pop(-1) + else: + built_num_trainstation += 1 + + # Adjust the number of agents if you could not build enough trainstations + + if num_agents > built_num_trainstation: + num_agents = built_num_trainstation + warnings.warn("sparse_rail_generator: num_agents > nr_start_goal, changing num_agents") + + # Place passing lanes at intersections + # We currently place them uniformly distirbuted among all cities + if enhance_intersection: + + for intersection in range(num_intersections): + intersect_x_1 = np.clip(intersection_positions[intersection][0] + np.random.randint(1, 3), + 1, + height - 2) + intersect_y_1 = np.clip(intersection_positions[intersection][1] + np.random.randint(-3, 3), + 2, + width - 2) + intersect_x_2 = np.clip( + intersection_positions[intersection][0] + np.random.randint(-3, -1), + 1, + height - 2) + intersect_y_2 = np.clip( + intersection_positions[intersection][1] + np.random.randint(-3, 3), + 1, + width - 2) + + # Connect train station to the correct node + connect_nodes(rail_trans, rail_array, (intersect_x_1, intersect_y_1), + (intersect_x_2, intersect_y_2)) + connect_nodes(rail_trans, rail_array, intersection_positions[intersection], + (intersect_x_1, intersect_y_1)) + connect_nodes(rail_trans, rail_array, intersection_positions[intersection], + (intersect_x_2, intersect_y_2)) + grid_map.fix_transitions((intersect_x_1, intersect_y_1)) + grid_map.fix_transitions((intersect_x_2, intersect_y_2)) + + # Fix all nodes with illegal transition maps + for current_node in node_positions: + grid_map.fix_transitions(current_node) + + # Generate start and target node directory for all agents. + # Assure that start and target are not in the same node + agent_start_targets_nodes = [] + + # Slot availability in node + node_available_start = [] + node_available_target = [] + for node_idx in range(num_cities): + node_available_start.append(len(train_stations[node_idx])) + node_available_target.append(len(train_stations[node_idx])) + + # Assign agents to slots + for agent_idx in range(num_agents): + avail_start_nodes = [idx for idx, val in enumerate(node_available_start) if val > 0] + avail_target_nodes = [idx for idx, val in enumerate(node_available_target) if val > 0] + start_node = np.random.choice(avail_start_nodes) + target_node = np.random.choice(avail_target_nodes) + tries = 0 + found_agent_pair = True + while target_node == start_node: + target_node = np.random.choice(avail_target_nodes) + tries += 1 + # Test again with new start node if no pair is found (This code needs to be improved) + if (tries + 1) % 10 == 0: + start_node = np.random.choice(avail_start_nodes) + if tries > 100: + warnings.warn("Could not set trainstations, removing agent!") + found_agent_pair = False + break + if found_agent_pair: + node_available_start[start_node] -= 1 + node_available_target[target_node] -= 1 + agent_start_targets_nodes.append((start_node, target_node)) + else: + num_agents -= 1 + + # Place agents and targets within available train stations + agents_position = [] + agents_target = [] + agents_direction = [] + + for agent_idx in range(num_agents): + # Set target for agent + current_target_node = agent_start_targets_nodes[agent_idx][1] + target_station_idx = np.random.randint(len(train_stations[current_target_node])) + target = train_stations[current_target_node][target_station_idx] + tries = 0 + while (target[0], target[1]) in agents_target: + target_station_idx = np.random.randint(len(train_stations[current_target_node])) + target = train_stations[current_target_node][target_station_idx] + tries += 1 + if tries > 100: + warnings.warn("Could not set target position, removing an agent") + break + agents_target.append((target[0], target[1])) + + # Set start for agent + current_start_node = agent_start_targets_nodes[agent_idx][0] + start_station_idx = np.random.randint(len(train_stations[current_start_node])) + start = train_stations[current_start_node][start_station_idx] + tries = 0 + while (start[0], start[1]) in agents_position: + tries += 1 + if tries > 100: + warnings.warn("Could not set start position, please change initial parameters!!!!") + break + start_station_idx = np.random.randint(len(train_stations[current_start_node])) + start = train_stations[current_start_node][start_station_idx] + + agents_position.append((start[0], start[1])) + + # Orient the agent correctly + for orientation in range(4): + transitions = grid_map.get_transitions(start[0], start[1], orientation) + if any(transitions) > 0: + agents_direction.append(orientation) + continue + + return grid_map, agents_position, agents_direction, agents_target, [1.0] * len(agents_position) + + return generator diff --git a/flatland/utils/graphics_pil.py b/flatland/utils/graphics_pil.py index 42cbfeaa87b95b47bb9ca122a131bc8ea86eeb64..6a0a9282614c0319338454f5b8ae97531b12e432 100644 --- a/flatland/utils/graphics_pil.py +++ b/flatland/utils/graphics_pil.py @@ -41,7 +41,7 @@ class PILGL(GraphicsLayer): SELECTED_AGENT_LAYER = 4 SELECTED_TARGET_LAYER = 5 - def __init__(self, width, height, jupyter=False): + def __init__(self, width, height, jupyter=False, screen_width=800, screen_height=600): self.yxBase = (0, 0) self.linewidth = 4 self.n_agent_colors = 1 # overridden in loadAgent @@ -52,13 +52,13 @@ class PILGL(GraphicsLayer): self.background_grid = np.zeros(shape=(self.width, self.height)) if jupyter is False: - # NOTE: Currently removed the dependency on - # screeninfo. We have to find an alternate + # NOTE: Currently removed the dependency on + # screeninfo. We have to find an alternate # way to compute the screen width and height - # In the meantime, we are harcoding the 800x600 + # In the meantime, we are harcoding the 800x600 # assumption - self.screen_width = 800 - self.screen_height = 600 + self.screen_width = screen_width + self.screen_height = screen_height w = (self.screen_width - self.width - 10) / (self.width + 1 + self.linewidth) h = (self.screen_height - self.height - 10) / (self.height + 1 + self.linewidth) self.nPixCell = int(max(1, np.ceil(min(w, h)))) @@ -116,7 +116,7 @@ class PILGL(GraphicsLayer): for rc in dTargets: r = rc[1] c = rc[0] - d = int(np.floor(np.sqrt((x - r) ** 2 + (y - c) ** 2))) + d = int(np.floor(np.sqrt((x - r) ** 2 + (y - c) ** 2)) / 0.5) distance = min(d, distance) self.background_grid[x][y] = distance @@ -271,9 +271,9 @@ class PILGL(GraphicsLayer): class PILSVG(PILGL): - def __init__(self, width, height, jupyter=False): + def __init__(self, width, height, jupyter=False, screen_width=800, screen_height=600): oSuper = super() - oSuper.__init__(width, height, jupyter) + oSuper.__init__(width, height, jupyter, screen_width, screen_height) self.lwAgents = [] self.agents_prev = [] @@ -452,7 +452,7 @@ class PILSVG(PILGL): for transition, file in file_directory.items(): - # Translate the ascii transition description in the format "NE WS" to the + # Translate the ascii transition description in the format "NE WS" to the # binary list of transitions as per RailEnv - NESW (in) x NESW (out) transition_16_bit = ["0"] * 16 for sTran in transition.split(" "): diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py index 99d301667097a57a7fbdb77cd185d54838eca26a..802b361b623cdaea08271f5748ac86194056bdf2 100644 --- a/flatland/utils/rendertools.py +++ b/flatland/utils/rendertools.py @@ -40,8 +40,9 @@ class RenderTool(object): arc = array([np.cos(theta), np.sin(theta)]).T # from [1,0] to [0,1] def __init__(self, env, gl="PILSVG", jupyter=False, - agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND, - show_debug=True): + agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND, + show_debug=False, screen_width=800, screen_height=600): + self.env = env self.frame_nr = 0 self.start_time = time.time() @@ -50,12 +51,12 @@ class RenderTool(object): self.agent_render_variant = agent_render_variant if gl == "PIL": - self.gl = PILGL(env.width, env.height, jupyter) + self.gl = PILGL(env.width, env.height, jupyter, screen_width=screen_width, screen_height=screen_height) elif gl == "PILSVG": - self.gl = PILSVG(env.width, env.height, jupyter) + self.gl = PILSVG(env.width, env.height, jupyter, screen_width=screen_width, screen_height=screen_height) else: print("[", gl, "] not found, switch to PILSVG") - self.gl = PILSVG(env.width, env.height, jupyter) + self.gl = PILSVG(env.width, env.height, jupyter, screen_width=screen_width, screen_height=screen_height) self.new_rail = True self.show_debug = show_debug @@ -554,7 +555,7 @@ class RenderTool(object): if self.agent_render_variant == AgentRenderVariant.ONE_STEP_BEHIND_AND_BOX: self.gl.set_cell_occupied(agent_idx, *(agent.position)) self.gl.set_agent_at(agent_idx, *position, old_direction, direction, - selected_agent == agent_idx, show_debug=self.show_debug) + selected_agent == agent_idx, show_debug=self.show_debug) else: position = agent.position direction = agent.direction diff --git a/flatland/utils/simple_rail.py b/flatland/utils/simple_rail.py index 28978ca3e5b958a51c578f6be5e8c87b77baaa97..c5fe4860783f242f21c97c55a9119d8918454a96 100644 --- a/flatland/utils/simple_rail.py +++ b/flatland/utils/simple_rail.py @@ -2,11 +2,87 @@ from typing import Tuple import numpy as np -from flatland.core.grid.grid4 import Grid4Transitions +from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.core.transition_map import GridTransitionMap def make_simple_rail() -> Tuple[GridTransitionMap, np.array]: + # We instantiate a very simple rail network on a 7x10 grid: + # Note that that cells have invalid RailEnvTransitions! + # | + # | + # | + # _ _ _ _\ _ _ _ _ _ _ + # / + # | + # | + # | + transitions = RailEnvTransitions() + cells = transitions.transition_list + empty = cells[0] + dead_end_from_south = cells[7] + dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90) + dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180) + dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270) + vertical_straight = cells[1] + horizontal_straight = transitions.rotate_transition(vertical_straight, 90) + simple_switch_north_left = cells[2] + simple_switch_north_right = cells[10] + simple_switch_east_west_north = transitions.rotate_transition(simple_switch_north_right, 270) + simple_switch_east_west_south = transitions.rotate_transition(simple_switch_north_left, 270) + rail_map = np.array( + [[empty] * 3 + [dead_end_from_south] + [empty] * 6] + + [[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 + + [[dead_end_from_east] + [horizontal_straight] * 2 + + [simple_switch_east_west_north] + + [horizontal_straight] * 2 + [simple_switch_east_west_south] + + [horizontal_straight] * 2 + [dead_end_from_west]] + + [[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 + + [[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16) + rail = GridTransitionMap(width=rail_map.shape[1], + height=rail_map.shape[0], transitions=transitions) + rail.grid = rail_map + return rail, rail_map + + +def make_simple_rail2() -> Tuple[GridTransitionMap, np.array]: + # We instantiate a very simple rail network on a 7x10 grid: + # | + # | + # | + # _ _ _ _\ _ _ _ _ _ _ + # \ + # | + # | + # | + transitions = RailEnvTransitions() + cells = transitions.transition_list + empty = cells[0] + dead_end_from_south = cells[7] + dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90) + dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180) + dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270) + vertical_straight = cells[1] + horizontal_straight = transitions.rotate_transition(vertical_straight, 90) + simple_switch_north_right = cells[10] + simple_switch_east_west_north = transitions.rotate_transition(simple_switch_north_right, 270) + simple_switch_west_east_south = transitions.rotate_transition(simple_switch_north_right, 90) + rail_map = np.array( + [[empty] * 3 + [dead_end_from_south] + [empty] * 6] + + [[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 + + [[dead_end_from_east] + [horizontal_straight] * 2 + + [simple_switch_east_west_north] + + [horizontal_straight] * 2 + [simple_switch_west_east_south] + + [horizontal_straight] * 2 + [dead_end_from_west]] + + [[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 + + [[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16) + rail = GridTransitionMap(width=rail_map.shape[1], + height=rail_map.shape[0], transitions=transitions) + rail.grid = rail_map + return rail, rail_map + + +def make_invalid_simple_rail() -> Tuple[GridTransitionMap, np.array]: # We instantiate a very simple rail network on a 7x10 grid: # | # | @@ -16,15 +92,9 @@ def make_simple_rail() -> Tuple[GridTransitionMap, np.array]: # | # | # | - cells = [int('0000000000000000', 2), # empty cell - Case 0 - int('1000000000100000', 2), # Case 1 - straight - int('1001001000100000', 2), # Case 2 - simple switch - int('1000010000100001', 2), # Case 3 - diamond drossing - int('1001011000100001', 2), # Case 4 - single slip switch - int('1100110000110011', 2), # Case 5 - double slip switch - int('0101001000000010', 2), # Case 6 - symmetrical switch - int('0010000000000000', 2)] # Case 7 - dead end - transitions = Grid4Transitions([]) + transitions = RailEnvTransitions() + cells = transitions.transition_list + empty = cells[0] dead_end_from_south = cells[7] dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90) diff --git a/tests/test_distance_map.py b/tests/test_distance_map.py index 566505b7fa5c8389fd31e90ec565a5728c10a2b0..583830f79146372bfaa8a7d0739688d88781d38b 100644 --- a/tests/test_distance_map.py +++ b/tests/test_distance_map.py @@ -1,6 +1,6 @@ import numpy as np -from flatland.core.grid.grid4 import Grid4Transitions +from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.core.transition_map import GridTransitionMap from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv @@ -12,15 +12,8 @@ from flatland.envs.schedule_generators import get_rnd_agents_pos_tgt_dir_on_rail def test_walker(): # _ _ _ - cells = [int('0000000000000000', 2), # empty cell - Case 0 - int('1000000000100000', 2), # Case 1 - straight - int('1001001000100000', 2), # Case 2 - simple switch - int('1000010000100001', 2), # Case 3 - diamond drossing - int('1001011000100001', 2), # Case 4 - single slip switch - int('1100110000110011', 2), # Case 5 - double slip switch - int('0101001000000010', 2), # Case 6 - symmetrical switch - int('0010000000000000', 2)] # Case 7 - dead end - transitions = Grid4Transitions([]) + transitions = RailEnvTransitions() + cells = transitions.transition_list dead_end_from_south = cells[7] dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90) dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270) diff --git a/tests/test_flatland_envs_predictions.py b/tests/test_flatland_envs_predictions.py index b58be79bc9db0cee5c2a523c965ec6cd96f18e52..65894b505c000098bd5dc798e9ac0cfd1aed09e1 100644 --- a/tests/test_flatland_envs_predictions.py +++ b/tests/test_flatland_envs_predictions.py @@ -11,13 +11,13 @@ from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import rail_from_grid_transition_map from flatland.envs.schedule_generators import get_rnd_agents_pos_tgt_dir_on_rail from flatland.utils.rendertools import RenderTool -from flatland.utils.simple_rail import make_simple_rail +from flatland.utils.simple_rail import make_simple_rail, make_simple_rail2, make_invalid_simple_rail """Test predictions for `flatland` package.""" def test_dummy_predictor(rendering=False): - rail, rail_map = make_simple_rail() + rail, rail_map = make_simple_rail2() env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], @@ -91,7 +91,7 @@ def test_dummy_predictor(rendering=False): expected_actions = np.array([[0.], [2.], [2.], - [1.], + [2.], [2.], [2.], [2.], @@ -229,7 +229,7 @@ def test_shortest_path_predictor(rendering=False): def test_shortest_path_predictor_conflicts(rendering=False): - rail, rail_map = make_simple_rail() + rail, rail_map = make_invalid_simple_rail() env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), diff --git a/tests/test_flatland_envs_rail_env.py b/tests/test_flatland_envs_rail_env.py index a812e01f5ed8e2f78be0411f0a29667c01981876..656059ac243dfe0cd5386648603688ec0de4546b 100644 --- a/tests/test_flatland_envs_rail_env.py +++ b/tests/test_flatland_envs_rail_env.py @@ -2,7 +2,6 @@ # -*- coding: utf-8 -*- import numpy as np -from flatland.core.grid.grid4 import Grid4Transitions from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.core.transition_map import GridTransitionMap from flatland.envs.agent_utils import EnvAgent @@ -51,15 +50,6 @@ def test_save_load(): def test_rail_environment_single_agent(): - cells = [int('0000000000000000', 2), # empty cell - Case 0 - int('1000000000100000', 2), # Case 1 - straight - int('1001001000100000', 2), # Case 2 - simple switch - int('1000010000100001', 2), # Case 3 - diamond drossing - int('1001011000100001', 2), # Case 4 - single slip switch - int('1100110000110011', 2), # Case 5 - double slip switch - int('0101001000000010', 2), # Case 6 - symmetrical switch - int('0010000000000000', 2)] # Case 7 - dead end - # We instantiate the following map on a 3x3 grid # _ _ # / \/ \ @@ -67,6 +57,7 @@ def test_rail_environment_single_agent(): # \_/\_/ transitions = RailEnvTransitions() + cells = transitions.transition_list vertical_line = cells[1] south_symmetrical_switch = cells[6] north_symmetrical_switch = transitions.rotate_transition(south_symmetrical_switch, 180) @@ -142,7 +133,7 @@ test_rail_environment_single_agent() def test_dead_end(): - transitions = Grid4Transitions([]) + transitions = RailEnvTransitions() straight_vertical = int('1000000000100000', 2) # Case 1 - straight straight_horizontal = transitions.rotate_transition(straight_vertical, diff --git a/tests/test_flatland_envs_sparse_rail_generator.py b/tests/test_flatland_envs_sparse_rail_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..5f0cc81e660f885f2d5d6ee2357d85da7d7c903c --- /dev/null +++ b/tests/test_flatland_envs_sparse_rail_generator.py @@ -0,0 +1,26 @@ +from flatland.envs.observations import GlobalObsForRailEnv +from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators import sparse_rail_generator +from flatland.envs.schedule_generators import sparse_rail_generator_agents_placer +from flatland.utils.rendertools import RenderTool + + +def test_sparse_rail_generator(): + env = RailEnv(width=50, + height=50, + rail_generator=sparse_rail_generator(num_cities=10, # Number of cities in map + num_intersections=10, # Number of interesections in map + num_trainstations=50, # Number of possible start/targets on map + min_node_dist=6, # Minimal distance of nodes + node_radius=3, # Proximity of stations to city center + num_neighb=3, # Number of connections to other cities + seed=5, # Random seed + realistic_mode=False # Ordered distribution of nodes + ), + agent_generator=sparse_rail_generator_agents_placer(), + number_of_agents=10, + obs_builder_object=GlobalObsForRailEnv()) + # reset to initialize agents_static + env_renderer = RenderTool(env, gl="PILSVG", ) + env_renderer.render_env(show=True, show_observations=True, show_predictions=False) + env_renderer.gl.save_image("./sparse_generator_false.png")