diff --git a/examples/temporary_example.py b/examples/temporary_example.py index 6bed439cb21f611353763da36a890739c77e5866..28a94db3754fc21783447cf15057022476189a63 100644 --- a/examples/temporary_example.py +++ b/examples/temporary_example.py @@ -54,8 +54,6 @@ env.agents_direction[0] = 1 # TODO: watch out: if these variables are overridden, the obs_builder object has to be reset, too! env.obs_builder.reset() """ - - """ # INFINITE-LOOP TEST specs = [[(0, 0), (0, 0), (0, 0), (0, 0), (7, 0), (0, 0)], @@ -84,6 +82,7 @@ env.obs_builder.reset() env = RailEnv(width=7, height=7, rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability), + # rail_generator=complex_rail_generator(nr_start_goal=2), number_of_agents=2) # Print the distance map of each cell to the target of the first agent diff --git a/flatland/envs/env_utils.py b/flatland/envs/env_utils.py index 382bdd7d9c483aa491f9c11b06b62b13b45490a5..ac29408051efd4bc85f8120b02f7dfd59b14dee9 100644 --- a/flatland/envs/env_utils.py +++ b/flatland/envs/env_utils.py @@ -5,7 +5,7 @@ Definition of the RailEnv environment and related level-generation functions. Generator functions are functions that take width, height and num_resets as arguments and return a GridTransitionMap object. """ -# import numpy as np +import numpy as np # from flatland.core.env import Environment # from flatland.core.env_observation_builder import TreeObsForRailEnv @@ -271,3 +271,93 @@ def connect_rail(rail_trans, rail_array, start, end): def distance_on_rail(pos1, pos2): return abs(pos1[0] - pos2[0]) + abs(pos1[1] - pos2[1]) + + +def get_new_position(position, movement): + if movement == 0: # NORTH + return (position[0] - 1, position[1]) + elif movement == 1: # EAST + return (position[0], position[1] + 1) + elif movement == 2: # SOUTH + return (position[0] + 1, position[1]) + elif movement == 3: # WEST + return (position[0], position[1] - 1) + + +def get_rnd_agents_pos_tgt_dir_on_rail(rail, num_agents): + """ + Given a `rail' GridTransitionMap, return a random placement of agents (initial position, direction and target). + + TODO: add extensive documentation, as users may need this function to simplify their custom level generators. + """ + + def _path_exists(rail, start, direction, end): + # BFS - Check if a path exists between the 2 nodes + + visited = set() + stack = [(start, direction)] + while stack: + node = stack.pop() + if node[0][0] == end[0] and node[0][1] == end[1]: + return 1 + if node not in visited: + visited.add(node) + moves = rail.get_transitions((node[0][0], node[0][1], node[1])) + for move_index in range(4): + if moves[move_index]: + stack.append((get_new_position(node[0], move_index), + move_index)) + + # If cell is a dead-end, append previous node with reversed + # orientation! + nbits = 0 + tmp = rail.get_transitions((node[0][0], node[0][1])) + while tmp > 0: + nbits += (tmp & 1) + tmp = tmp >> 1 + if nbits == 1: + stack.append((node[0], (node[1] + 2) % 4)) + + return 0 + + valid_positions = [] + for r in range(rail.height): + for c in range(rail.width): + if rail.get_transitions((r, c)) > 0: + valid_positions.append((r, c)) + + re_generate = True + while re_generate: + agents_position = [ + valid_positions[i] for i in + np.random.choice(len(valid_positions), num_agents)] + agents_target = [ + valid_positions[i] for i in + np.random.choice(len(valid_positions), num_agents)] + + # agents_direction must be a direction for which a solution is + # guaranteed. + agents_direction = [0] * num_agents + re_generate = False + for i in range(num_agents): + valid_movements = [] + for direction in range(4): + position = agents_position[i] + moves = rail.get_transitions((position[0], position[1], direction)) + for move_index in range(4): + if moves[move_index]: + valid_movements.append((direction, move_index)) + + valid_starting_directions = [] + for m in valid_movements: + new_position = get_new_position(agents_position[i], m[1]) + if m[0] not in valid_starting_directions and \ + _path_exists(rail, new_position, m[0], agents_target[i]): + valid_starting_directions.append(m[0]) + + if len(valid_starting_directions) == 0: + re_generate = True + else: + agents_direction[i] = valid_starting_directions[np.random.choice(len(valid_starting_directions), 1)[0]] + + return agents_position, agents_direction, agents_target diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py index 021c63f31a958790c04e17f49c7a1da57777cb9c..2c26076dd527c9e315366471274b20f021502881 100644 --- a/flatland/envs/generators.py +++ b/flatland/envs/generators.py @@ -5,7 +5,7 @@ import numpy as np from flatland.core.transitions import Grid8Transitions, RailEnvTransitions from flatland.core.transition_map import GridTransitionMap -from flatland.envs.env_utils import distance_on_rail, connect_rail +from flatland.envs.env_utils import distance_on_rail, connect_rail, get_rnd_agents_pos_tgt_dir_on_rail def complex_rail_generator(nr_start_goal=1, min_dist=2, max_dist=99999, seed=0): @@ -23,7 +23,7 @@ def complex_rail_generator(nr_start_goal=1, min_dist=2, max_dist=99999, seed=0): The matrix with the correct 16-bit bitmaps for each cell. """ - def generator(width, height, num_resets=0): + def generator(width, height, agents_handles, num_resets=0): rail_trans = RailEnvTransitions() rail_array = np.zeros(shape=(width, height), dtype=np.uint16) @@ -106,8 +106,12 @@ def complex_rail_generator(nr_start_goal=1, min_dist=2, max_dist=99999, seed=0): return_rail = GridTransitionMap(width=width, height=height, transitions=rail_trans) return_rail.grid = rail_array - # TODO: return start_goal - return return_rail + + # TODO: return agents_position, agents_direction and agents_target! + # NOTE: the initial direction must be such that the target can be reached. + # See env_utils.get_rnd_agents_pos_tgt_dir_on_rail() for hints, if required. + + return return_rail, [], [], [] return generator @@ -132,7 +136,7 @@ def rail_from_manual_specifications_generator(rail_spec): the matrix of correct 16-bit bitmaps for each cell. """ - def generator(width, height, num_resets=0): + def generator(width, height, agents_handles, num_resets=0): t_utils = RailEnvTransitions() height = len(rail_spec) @@ -147,7 +151,11 @@ def rail_from_manual_specifications_generator(rail_spec): return [] rail.set_transitions((r, c), t_utils.rotate_transition(t_utils.transitions[cell[0]], cell[1])) - return rail + agents_position, agents_direction, agents_target = get_rnd_agents_pos_tgt_dir_on_rail( + rail, + len(agents_handles)) + + return rail, agents_position, agents_direction, agents_target return generator @@ -168,8 +176,12 @@ def rail_from_GridTransitionMap_generator(rail_map): Generator function that always returns the given `rail_map' object. """ - def generator(width, height, num_resets=0): - return rail_map + def generator(width, height, agents_handles, num_resets=0): + agents_position, agents_direction, agents_target = get_rnd_agents_pos_tgt_dir_on_rail( + rail_map, + len(agents_handles)) + + return rail_map, agents_position, agents_direction, agents_target return generator @@ -189,7 +201,7 @@ def rail_from_list_of_saved_GridTransitionMap_generator(list_of_filenames): Generator function that always returns the given `rail_map' object. """ - def generator(width, height, num_resets=0): + def generator(width, height, agents_handles, num_resets=0): t_utils = RailEnvTransitions() rail_map = GridTransitionMap(width=width, height=height, transitions=t_utils) rail_map.load_transition_map(list_of_filenames[num_resets % len(list_of_filenames)], override_gridsize=False) @@ -197,7 +209,11 @@ def rail_from_list_of_saved_GridTransitionMap_generator(list_of_filenames): if rail_map.grid.dtype == np.uint64: rail_map.transitions = Grid8Transitions() - return rail_map + agents_position, agents_direction, agents_target = get_rnd_agents_pos_tgt_dir_on_rail( + rail_map, + len(agents_handles)) + + return rail_map, agents_position, agents_direction, agents_target return generator @@ -243,7 +259,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8): The matrix with the correct 16-bit bitmaps for each cell. """ - def generator(width, height, num_resets=0): + def generator(width, height, agents_handles, num_resets=0): t_utils = RailEnvTransitions() transition_probability = cell_type_relative_proportion @@ -472,6 +488,11 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8): return_rail = GridTransitionMap(width=width, height=height, transitions=t_utils) return_rail.grid = tmp_rail - return return_rail + + agents_position, agents_direction, agents_target = get_rnd_agents_pos_tgt_dir_on_rail( + return_rail, + len(agents_handles)) + + return return_rail, agents_position, agents_direction, agents_target return generator diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index cbb11c2db70dece00199bfae6bb051504f64a9f2..98abf81f469e1ea329db39e86c3cfe0a7756df28 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -9,6 +9,7 @@ import numpy as np from flatland.core.env import Environment from flatland.core.env_observation_builder import TreeObsForRailEnv from flatland.envs.generators import random_rail_generator +from flatland.envs.env_utils import get_new_position # from flatland.core.transitions import Grid8Transitions, RailEnvTransitions # from flatland.core.transition_map import GridTransitionMap @@ -74,9 +75,10 @@ class RailEnv(Environment): Parameters ------- rail_generator : function - The rail_generator function is a function that takes the width and - height of a rail map along with the number of times the env has - been reset, and returns a GridTransitionMap object. + The rail_generator function is a function that takes the width, + height and agents handles of a rail environment, along with the number of times + the env has been reset, and returns a GridTransitionMap object and a list of + starting positions, targets, and initial orientations for agent handle. Implemented functions are: random_rail_generator : generate a random rail of given size rail_from_GridTransitionMap_generator(rail_map) : generate a rail from @@ -130,107 +132,16 @@ class RailEnv(Environment): def get_agent_handles(self): return self.agents_handles - def fill_valid_positions(self): - ''' Populate the valid_positions list for the current TransitionMap. - ''' - self.valid_positions = valid_positions = [] - for r in range(self.height): - for c in range(self.width): - if self.rail.get_transitions((r, c)) > 0: - valid_positions.append((r, c)) - - def check_agent_lists(self): - ''' Check that the agent_handles, position and direction lists are all of length - number_of_agents. - (Suggest this is replaced with a single list of Agent objects :) - ''' - for lAgents, name in zip( - [self.agents_handles, self.agents_position, self.agents_direction], - ["handles", "positions", "directions"]): - assert self.number_of_agents == len(lAgents), "Inconsistent agent list:" + name - - def check_agent_locdirpath(self, iAgent): - ''' Check that agent iAgent has a valid location and direction, - with a path to its target. - (Not currently used?) - ''' - valid_movements = [] - for direction in range(4): - position = self.agents_position[iAgent] - moves = self.rail.get_transitions((position[0], position[1], direction)) - for move_index in range(4): - if moves[move_index]: - valid_movements.append((direction, move_index)) - - valid_starting_directions = [] - for m in valid_movements: - new_position = self._new_position(self.agents_position[iAgent], m[1]) - if m[0] not in valid_starting_directions and \ - self._path_exists(new_position, m[0], self.agents_target[iAgent]): - valid_starting_directions.append(m[0]) - - if len(valid_starting_directions) == 0: - return False - else: - return True - - def pick_agent_direction(self, rcPos, rcTarget): - """ Pick and return a valid direction index (0..3) for an agent starting at - row,col rcPos with target rcTarget. - Return None if no path exists. - Picks random direction if more than one exists (uniformly). + def reset(self, regen_rail=True, replace_agents=True): """ - valid_movements = [] - for direction in range(4): - moves = self.rail.get_transitions((*rcPos, direction)) - for move_index in range(4): - if moves[move_index]: - valid_movements.append((direction, move_index)) - # print("pos", rcPos, "targ", rcTarget, "valid movements", valid_movements) - - valid_starting_directions = [] - for m in valid_movements: - new_position = self._new_position(rcPos, m[1]) - if m[0] not in valid_starting_directions and self._path_exists(new_position, m[0], rcTarget): - valid_starting_directions.append(m[0]) - - if len(valid_starting_directions) == 0: - return None - else: - return valid_starting_directions[np.random.choice(len(valid_starting_directions), 1)[0]] - - def add_agent(self, rcPos=None, rcTarget=None, iDir=None): - """ Add a new agent at position rcPos with target rcTarget and - initial direction index iDir. - Should also store this initial position etc as environment "meta-data" - but this does not yet exist. + TODO: replace_agents is ignored at the moment; agents will always be replaced. """ - self.check_agent_lists() - - if rcPos is None: - rcPos = np.random.choice(len(self.valid_positions)) - - iAgent = self.number_of_agents - - if iDir is None: - iDir = self.pick_agent_direction(rcPos, rcTarget) - if iDir is None: - print("Error picking agent direction at pos:", rcPos) - return None - - self.agents_position.append(tuple(rcPos)) # ensure it's a tuple not a list - self.agents_handles.append(max(self.agents_handles + [-1]) + 1) # max(handles) + 1, starting at 0 - self.agents_direction.append(iDir) - self.agents_target.append(rcPos) # set the target to the origin initially - self.number_of_agents += 1 - self.check_agent_lists() - return iAgent - - def reset(self, regen_rail=True, replace_agents=True): if regen_rail or self.rail is None: - # TODO: Import not only rail information but also start and goal positions - self.rail = self.rail_generator(self.width, self.height, self.num_resets) - self.fill_valid_positions() + self.rail, self.agents_position, self.agents_direction, self.agents_target = self.rail_generator( + self.width, + self.height, + self.agents_handles, + self.num_resets) self.num_resets += 1 @@ -238,37 +149,6 @@ class RailEnv(Environment): for handle in self.agents_handles: self.dones[handle] = False - # Use a TreeObsForRailEnv to compute distance maps to each agent's target, to sample initial - # agent's orientations that allow a valid solution. - # TODO: Possibility ot fill valid positions from list of goals and start - self.fill_valid_positions() - - if replace_agents: - re_generate = True - while re_generate: - - # self.agents_position = random.sample(valid_positions, - # self.number_of_agents) - self.agents_position = [ - self.valid_positions[i] for i in - np.random.choice(len(self.valid_positions), self.number_of_agents)] - self.agents_target = [ - self.valid_positions[i] for i in - np.random.choice(len(self.valid_positions), self.number_of_agents)] - - # agents_direction must be a direction for which a solution is - # guaranteed. - self.agents_direction = [0] * self.number_of_agents - re_generate = False - - for i in range(self.number_of_agents): - direction = self.pick_agent_direction(self.agents_position[i], self.agents_target[i]) - if direction is None: - re_generate = True - break - else: - self.agents_direction[i] = direction - # Reset the state of the observation builder with the new environment self.obs_builder.reset() @@ -342,7 +222,7 @@ class RailEnv(Environment): movement = np.argmax(possible_transitions) transition_isValid = True - new_position = self._new_position(pos, movement) + new_position = get_new_position(pos, movement) # Is it a legal move? 1) transition allows the movement in the # cell, 2) the new cell is not empty (case 0), 3) the cell is # free, i.e., no agent is currently in that cell @@ -401,45 +281,6 @@ class RailEnv(Environment): self.actions = [0] * self.number_of_agents return self._get_observations(), self.rewards_dict, self.dones, {} - def _new_position(self, position, movement): - if movement == 0: # NORTH - return (position[0] - 1, position[1]) - elif movement == 1: # EAST - return (position[0], position[1] + 1) - elif movement == 2: # SOUTH - return (position[0] + 1, position[1]) - elif movement == 3: # WEST - return (position[0], position[1] - 1) - - def _path_exists(self, start, direction, end): - # BFS - Check if a path exists between the 2 nodes - - visited = set() - stack = [(start, direction)] - while stack: - node = stack.pop() - if node[0][0] == end[0] and node[0][1] == end[1]: - return 1 - if node not in visited: - visited.add(node) - moves = self.rail.get_transitions((node[0][0], node[0][1], node[1])) - for move_index in range(4): - if moves[move_index]: - stack.append((self._new_position(node[0], move_index), - move_index)) - - # If cell is a dead-end, append previous node with reversed - # orientation! - nbits = 0 - tmp = self.rail.get_transitions((node[0][0], node[0][1])) - while tmp > 0: - nbits += (tmp & 1) - tmp = tmp >> 1 - if nbits == 1: - stack.append((node[0], (node[1] + 2) % 4)) - - return 0 - def _get_observations(self): self.obs_dict = {} for handle in self.agents_handles: