diff --git a/flatland/core/env.py b/flatland/core/env.py index 1244aada86233f82dea60f349942df19e92404c0..4a147d067d4510268a431b0e32ea291799ae70a0 100644 --- a/flatland/core/env.py +++ b/flatland/core/env.py @@ -3,10 +3,14 @@ The env module defines the base Environment class. The base Environment class is adapted from rllib.env.MultiAgentEnv (https://github.com/ray-project/ray). """ +import random + +from .transitions import RailEnvTransitions class Environment: - """Base interface for multi-agent environments in Flatland. + """ + Base interface for multi-agent environments in Flatland. Agents are identified by agent ids (handles). Examples: @@ -89,3 +93,295 @@ class Environment: function. """ raise NotImplementedError() + + +class RailEnv: + """ + RailEnv environment class. + + RailEnv is an environment inspired by a (simplified version of) a rail + network, in which agents (trains) have to navigate to their target + locations in the shortest time possible, while at the same time cooperating + to avoid bottlenecks. + + The valid actions in the environment are: + 0: do nothing + 1: turn left and move to the next cell + 2: move to the next cell in front of the agent + 3: turn righ tand move to the next cell + + Moving forward in a dead-end cell makes the agent turn 180 degrees and step + to the cell it came from. + + The actions of the agents are executed in order of their handle to prevent + deadlocks and to allow them to learn relative priorities. + + TODO: WRITE ABOUT THE REWARD FUNCTION, and possibly allow for alpha and + beta to be passed as parameters to __init__(). + """ + + def __init__(self, rail, number_of_agents=1): + """ + Environment init. + + Parameters + ------- + rail : numpy.ndarray of type numpy.uint16 + The transition matrix that defines the environment. + number_of_agents : int + Number of agents to spawn on the map. + """ + + self.rail = rail + self.width = len(self.rail[0]) + self.height = len(self.rail) + + self.number_of_agents = number_of_agents + + self.actions = [0]*self.number_of_agents + self.rewards = [0]*self.number_of_agents + self.done = False + + self.agents_handles = list(range(self.number_of_agents)) + + self.t_utils = RailEnvTransitions() + # TODO : bad hack for pylint 80 characters per line; shortened function + self.gtfotd = self.t_utils.get_transition_from_orientation_to_direction + + def get_agent_handles(self): + return self.agents_handles + + def reset(self): + self.dones = {"__all__": False} + for handle in self.agents_handles: + self.dones[handle] = False + + re_generate = True + while re_generate: + valid_positions = [] + for r in range(self.height): + for c in range(self.width): + if self.rail[r][c] > 0: + valid_positions.append((r, c)) + + self.agents_position = random.sample(valid_positions, + self.number_of_agents) + self.agents_target = random.sample(valid_positions, + self.number_of_agents) + + # agents_direction must be a direction for which a solution is + # guaranteed. + self.agents_direction = [0]*self.number_of_agents + re_generate = False + for i in range(self.number_of_agents): + valid_movements = [] + for direction in range(4): + position = self.agents_position[i] + moves = self.t_utils.get_transitions_from_orientation( + self.rail[position[0]][position[1]], direction) + for move_index in range(4): + if moves[move_index]: + valid_movements.append((direction, move_index)) + + valid_starting_directions = [] + for m in valid_movements: + new_position = self._new_position(self.agents_position[i], + m[1]) + if m[0] not in valid_starting_directions and \ + self._path_exists(new_position, m[0], + self.agents_target[i]): + valid_starting_directions.append(m[0]) + + if len(valid_starting_directions) == 0: + re_generate = True + else: + self.agents_direction[i] = random.sample( + valid_starting_directions, 1)[0] + + obs_dict = {} + for handle in self.agents_handles: + obs_dict[handle] = self._get_observation_for_agent(handle) + return obs_dict + + def step(self, action_dict): + alpha = 1.0 + beta = 1.0 + + invalid_action_penalty = -2 + step_penalty = -1 * alpha + global_reward = 1 * beta + + # Reset the step rewards + rewards_dict = {} + for handle in self.agents_handles: + rewards_dict[handle] = 0 + + if self.dones["__all__"]: + obs_dict = {} + for handle in self.agents_handles: + obs_dict[handle] = self._get_observation_for_agent(handle) + return obs_dict, rewards_dict, self.dones, {} + + for i in range(len(self.agents_handles)): + handle = self.agents_handles[i] + + if handle not in action_dict: + continue + + action = action_dict[handle] + + if action < 0 or action > 3: + print('ERROR: illegal action=', action, + 'for agent with handle=', handle) + return + + if action > 0: + pos = self.agents_position[i] + direction = self.agents_direction[i] + + movement = direction + if action == 1: + movement = direction - 1 + elif action == 3: + movement = direction + 1 + + if movement < 0: + movement += 4 + if movement >= 4: + movement -= 4 + + if action == 2: + # compute number of possible transitions in the current + # cell + nbits = 0 + tmp = self.rail[pos[0]][pos[1]] + while tmp > 0: + nbits += (tmp & 1) + tmp = tmp >> 1 + if nbits == 1: + # dead-end; assuming the rail network is consistent, + # this should match the direction the agent has come + # from. But it's better to check in any case. + reverse_direction = 0 + if direction == 0: + reverse_direction = 2 + elif direction == 1: + reverse_direction = 3 + elif direction == 2: + reverse_direction = 0 + elif direction == 3: + reverse_direction = 1 + + valid_transition = self.gtfotd( + self.rail[pos[0]][pos[1]], + reverse_direction, + reverse_direction) + if valid_transition: + direction = reverse_direction + movement = direction + + new_position = self._new_position(pos, movement) + + # Is it a legal move? 1) transition allows the movement in the + # cell, 2) the new cell is not empty (case 0), 3) the cell is + # free, i.e., no agent is currently in that cell + if self.rail[new_position[0]][new_position[1]] > 0: + new_cell_isValid = True + else: + new_cell_isValid = False + + transition_isValid = self.gtfotd( + self.rail[pos[0]][pos[1]], + direction, + movement) + + cell_isFree = True + for j in range(self.number_of_agents): + if self.agents_position[j] == new_position: + cell_isFree = False + break + + if new_cell_isValid and transition_isValid and cell_isFree: + # move and change direction to face the movement that was + # performed + self.agents_position[i] = new_position + self.agents_direction[i] = movement + else: + # the action was not valid, add penalty + rewards_dict[handle] += invalid_action_penalty + + # if agent is not in target position, add step penalty + if self.agents_position[i][0] == self.agents_target[i][0] and \ + self.agents_position[i][1] == self.agents_target[i][1]: + self.dones[handle] = True + else: + rewards_dict[handle] += step_penalty + + # Check for end of episode + add global reward to all rewards! + num_agents_in_target_position = 0 + for i in range(self.number_of_agents): + if self.agents_position[i][0] == self.agents_target[i][0] and \ + self.agents_position[i][1] == self.agents_target[i][1]: + num_agents_in_target_position += 1 + + if num_agents_in_target_position == self.number_of_agents: + self.dones["__all__"] = True + rewards_dict = [r+global_reward for r in rewards_dict] + + # Reset the step actions (in case some agent doesn't 'register_action' + # on the next step) + self.actions = [0]*self.number_of_agents + + obs_dict = {} + for handle in self.agents_handles: + obs_dict[handle] = self._get_observation_for_agent(handle) + + return obs_dict, rewards_dict, self.dones, {} + + def _new_position(self, position, movement): + if movement == 0: # NORTH + return (position[0]-1, position[1]) + elif movement == 1: # EAST + return (position[0], position[1] + 1) + elif movement == 2: # SOUTH + return (position[0]+1, position[1]) + elif movement == 3: # WEST + return (position[0], position[1] - 1) + + def _path_exists(self, start, direction, end): + # BFS - Check if a path exists between the 2 nodes + + visited = set() + stack = [(start, direction)] + while stack: + node = stack.pop() + if node[0][0] == end[0] and node[0][1] == end[1]: + return 1 + if node not in visited: + visited.add(node) + moves = self.t_utils.get_transitions_from_orientation( + self.rail[node[0][0]][node[0][1]], node[1]) + for move_index in range(4): + if moves[move_index]: + stack.append((self._new_position(node[0], move_index), + move_index)) + + # If cell is a dead-end, append previous node with reversed + # orientation! + nbits = 0 + tmp = self.rail[node[0][0]][node[0][1]] + while tmp > 0: + nbits += (tmp & 1) + tmp = tmp >> 1 + if nbits == 1: + stack.append((node[0], (node[1] + 2) % 4)) + + return 0 + + def _get_observation_for_agent(self, handle): + # TODO: + return None + + def render(self): + # TODO: + pass diff --git a/flatland/utils/rail_env_generator.py b/flatland/utils/rail_env_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..df180e70b8285c7ae82023636f1dbac283aa7163 --- /dev/null +++ b/flatland/utils/rail_env_generator.py @@ -0,0 +1,305 @@ +""" +The rail_env_generator module defines provides utilities to generate env +bitmaps for the RailEnv environment. +""" +import numpy as np +import random + +from flatland.core.transitions import RailEnvTransitions + + +def generate_rail_from_manual_specifications(rail_spec): + """ + Utility to convert a rail given by manual specification as a map of tuples + (cell_type, rotation), to a transition map with the correct 16-bit + transitions specifications. + + Parameters + ------- + rail_spec : list of list of tuples + List (rows) of lists (columns) of tuples, each specifying a cell for + the RailEnv environment as (cell_type, rotation), with rotation being + clock-wise and in [0, 90, 180, 270]. + + Returns + ------- + numpy.ndarray of type numpy.uint16 + The matrix with the correct 16-bit bitmaps for each cell. + """ + t_utils = RailEnvTransitions() + + height = len(rail_spec) + width = len(rail_spec[0]) + rail = np.zeros((height, width), dtype=np.uint16) + + for r in range(height): + for c in range(width): + cell = rail_spec[r][c] + if cell[0] < 0 or cell[0] >= len(t_utils.transitions): + print("ERROR - invalid cell type=", cell[0]) + return [] + rail[r, c] = t_utils.rotate_transition( + t_utils.transitions[cell[0]], cell[1]) + + return rail + + +def generate_random_rail(width, height): + """ + Dummy random level generator: + - fill in cells at random in [width-2, height-2] + - keep filling cells in among the unfilled ones, such that all transitions + are legit; if no cell can be filled in without violating some + transitions, pick one among those that can satisfy most transitions + (1,2,3 or 4), and delete (+mark to be re-filled) the cells that were + incompatible. + - keep trying for a total number of insertions + (e.g., (W-2)*(H-2)*MAX_REPETITIONS ); if no solution is found, empty the + board and try again from scratch. + - finally pad the border of the map with dead-ends to avoid border issues. + + Dead-ends are not allowed inside the grid, only at the border; however, if + no cell type can be inserted in a given cell (because of the neighboring + transitions), deadends are allowed if they solve the problem. This was + found to turn most un-genereatable levels into valid ones. + + Parameters + ------- + width : int + The width (number of cells) of the grid to generate. + height : int + The height (number of cells) of the grid to generate. + + Returns + ------- + numpy.ndarray of type numpy.uint16 + The matrix with the correct 16-bit bitmaps for each cell. + """ + + t_utils = RailEnvTransitions() + + transitions_templates_ = [] + for i in range(len(t_utils.transitions)-1): # don't include dead-ends + all_transitions = 0 + for dir_ in range(4): + trans = t_utils.get_transitions_from_orientation( + t_utils.transitions[i], dir_) + all_transitions |= (trans[0] << 3) | \ + (trans[1] << 2) | \ + (trans[2] << 1) | \ + (trans[3]) + + template = [int(x) for x in bin(all_transitions)[2:]] + template = [0]*(4-len(template)) + template + + # add all rotations + for rot in [0, 90, 180, 270]: + transitions_templates_.append((template, + t_utils.rotate_transition( + t_utils.transitions[i], + rot))) + template = [template[-1]]+template[:-1] + + def get_matching_templates(template): + ret = [] + for i in range(len(transitions_templates_)): + is_match = True + for j in range(4): + if template[j] >= 0 and \ + template[j] != transitions_templates_[i][0][j]: + is_match = False + break + if is_match: + ret.append(transitions_templates_[i][1]) + return ret + + MAX_INSERTIONS = (width-2) * (height-2) * 10 + MAX_ATTEMPTS_FROM_SCRATCH = 10 + + attempt_number = 0 + while attempt_number < MAX_ATTEMPTS_FROM_SCRATCH: + cells_to_fill = [] + rail = [] + for r in range(height): + rail.append([None]*width) + if r > 0 and r < height-1: + cells_to_fill = cells_to_fill \ + + [(r, c) for c in range(1, width-1)] + + num_insertions = 0 + while num_insertions < MAX_INSERTIONS and len(cells_to_fill) > 0: + cell = random.sample(cells_to_fill, 1)[0] + cells_to_fill.remove(cell) + row = cell[0] + col = cell[1] + + # look at its neighbors and see what are the possible transitions + # that can be chosen from, if any. + valid_template = [-1, -1, -1, -1] + + for el in [(0, 2, (-1, 0)), + (1, 3, (0, 1)), + (2, 0, (1, 0)), + (3, 1, (0, -1))]: # N, E, S, W + neigh_trans = rail[row+el[2][0]][col+el[2][1]] + if neigh_trans is not None: + # select transition coming from facing direction el[1] and + # moving to direction el[1] + max_bit = 0 + for k in range(4): + max_bit |= \ + t_utils.get_transition_from_orientation_to_direction( + neigh_trans, k, el[1]) + + if max_bit: + valid_template[el[0]] = 1 + else: + valid_template[el[0]] = 0 + + possible_cell_transitions = get_matching_templates(valid_template) + + if len(possible_cell_transitions) == 0: # NO VALID TRANSITIONS + # no cell can be filled in without violating some transitions + # can a dead-end solve the problem? + if valid_template.count(1) == 1: + for k in range(4): + if valid_template[k] == 1: + rot = 0 + if k == 0: + rot = 180 + elif k == 1: + rot = 270 + elif k == 2: + rot = 0 + elif k == 3: + rot = 90 + + rail[row][col] = t_utils.rotate_transition( + int('0000000000100000', 2), rot) + num_insertions += 1 + + break + + else: + # can I get valid transitions by removing a single + # neighboring cell? + bestk = -1 + besttrans = [] + for k in range(4): + tmp_template = valid_template[:] + tmp_template[k] = -1 + possible_cell_transitions = get_matching_templates( + tmp_template) + if len(possible_cell_transitions) > len(besttrans): + besttrans = possible_cell_transitions + bestk = k + + if bestk >= 0: + # Replace the corresponding cell with None, append it + # to cells to fill, fill in a transition in the current + # cell. + replace_row = row - 1 + replace_col = col + if bestk == 1: + replace_row = row + replace_col = col + 1 + elif bestk == 2: + replace_row = row + 1 + replace_col = col + elif bestk == 3: + replace_row = row + replace_col = col - 1 + + cells_to_fill.append((replace_row, replace_col)) + rail[replace_row][replace_col] = None + + rail[row][col] = random.sample( + besttrans, 1)[0] + num_insertions += 1 + + else: + print('WARNING: still nothing!') + rail[row][col] = int('0000000000000000', 2) + num_insertions += 1 + pass + + else: + rail[row][col] = random.sample( + possible_cell_transitions, 1)[0] + num_insertions += 1 + + if num_insertions == MAX_INSERTIONS: + # Failed to generate a valid level; try again for a number of times + attempt_number += 1 + else: + break + + if attempt_number == MAX_ATTEMPTS_FROM_SCRATCH: + print('ERROR: failed to generate level') + + # Finally pad the border of the map with dead-ends to avoid border issues; + # at most 1 transition in the neigh cell + for r in range(height): + # Check for transitions coming from [r][1] to WEST + max_bit = 0 + neigh_trans = rail[r][1] + if neigh_trans is not None: + for k in range(4): + neigh_trans_from_direction = (neigh_trans >> ((3-k) * 4)) \ + & (2**4-1) + max_bit = max_bit | (neigh_trans_from_direction & 1) + if max_bit: + rail[r][0] = t_utils.rotate_transition( + int('0000000000100000', 2), 270) + else: + rail[r][0] = int('0000000000000000', 2) + + # Check for transitions coming from [r][-2] to EAST + max_bit = 0 + neigh_trans = rail[r][-2] + if neigh_trans is not None: + for k in range(4): + neigh_trans_from_direction = (neigh_trans >> ((3-k) * 4)) \ + & (2**4-1) + max_bit = max_bit | (neigh_trans_from_direction & (1 << 2)) + if max_bit: + rail[r][-1] = t_utils.rotate_transition(int('0000000000100000', 2), + 90) + else: + rail[r][-1] = int('0000000000000000', 2) + + for c in range(width): + # Check for transitions coming from [1][c] to NORTH + max_bit = 0 + neigh_trans = rail[1][c] + if neigh_trans is not None: + for k in range(4): + neigh_trans_from_direction = (neigh_trans >> ((3-k) * 4)) \ + & (2**4-1) + max_bit = max_bit | (neigh_trans_from_direction & (1 << 3)) + if max_bit: + rail[0][c] = int('0000000000100000', 2) + else: + rail[0][c] = int('0000000000000000', 2) + + # Check for transitions coming from [-2][c] to SOUTH + max_bit = 0 + neigh_trans = rail[-2][c] + if neigh_trans is not None: + for k in range(4): + neigh_trans_from_direction = (neigh_trans >> ((3-k) * 4)) \ + & (2**4-1) + max_bit = max_bit | (neigh_trans_from_direction & (1 << 1)) + if max_bit: + rail[-1][c] = t_utils.rotate_transition( + int('0000000000100000', 2), 180) + else: + rail[-1][c] = int('0000000000000000', 2) + + # For display only, wrong levels + for r in range(height): + for c in range(width): + if rail[r][c] is None: + rail[r][c] = int('0000000000000000', 2) + + return np.asarray(rail, dtype=np.uint16)