From 15679dc6895d2dbd4cdf6b7f23e25ac41a87304d Mon Sep 17 00:00:00 2001
From: spiglerg <>
Date: Thu, 4 Apr 2019 13:56:50 +0200
Subject: [PATCH] added very preliminary RailEnv in env, and rail_env_generator
 utils in utils

 flatland/core/                 | 298 +++++++++++++++++++++++++-
 flatland/utils/ | 305 +++++++++++++++++++++++++++
 2 files changed, 602 insertions(+), 1 deletion(-)
 create mode 100644 flatland/utils/

diff --git a/flatland/core/ b/flatland/core/
index 1244aada..4a147d06 100644
--- a/flatland/core/
+++ b/flatland/core/
@@ -3,10 +3,14 @@ The env module defines the base Environment class.
 The base Environment class is adapted from rllib.env.MultiAgentEnv
+import random
+from .transitions import RailEnvTransitions
 class Environment:
-    """Base interface for multi-agent environments in Flatland.
+    """
+    Base interface for multi-agent environments in Flatland.
     Agents are identified by agent ids (handles).
@@ -89,3 +93,295 @@ class Environment:
         raise NotImplementedError()
+class RailEnv:
+    """
+    RailEnv environment class.
+    RailEnv is an environment inspired by (a simplified version of) a rail
+    network, in which agents (trains) have to navigate to their target
+    locations in the shortest time possible, while at the same time cooperating
+    to avoid bottlenecks.
+    The valid actions in the environment are:
+        0: do nothing
+        1: turn left and move to the next cell
+        2: move to the next cell in front of the agent
+        3: turn right and move to the next cell
+    Moving forward in a dead-end cell makes the agent turn 180 degrees and step
+    to the cell it came from.
+    The actions of the agents are executed in order of their handle to prevent
+    deadlocks and to allow them to learn relative priorities.
+    TODO: WRITE ABOUT THE REWARD FUNCTION, and possibly allow for alpha and
+    beta to be passed as parameters to __init__().
+    """
+    def __init__(self, rail, number_of_agents=1):
+        """
+        Environment init.
+        Parameters
+        -------
+        rail : numpy.ndarray of type numpy.uint16
+            The transition matrix that defines the environment.
+        number_of_agents : int
+            Number of agents to spawn on the map.
+        """
+        self.rail = rail
+        self.width = len(self.rail[0])
+        self.height = len(self.rail)
+        self.number_of_agents = number_of_agents
+        self.actions = [0]*self.number_of_agents
+        self.rewards = [0]*self.number_of_agents
+        self.done = False
+        self.agents_handles = list(range(self.number_of_agents))
+        self.t_utils = RailEnvTransitions()
+        # TODO : bad hack for pylint 80 characters per line; shortened function
+        self.gtfotd = self.t_utils.get_transition_from_orientation_to_direction
+    def get_agent_handles(self):
+        return self.agents_handles
+    def reset(self):
+        self.dones = {"__all__": False}
+        for handle in self.agents_handles:
+            self.dones[handle] = False
+        re_generate = True
+        while re_generate:
+            valid_positions = []
+            for r in range(self.height):
+                for c in range(self.width):
+                    if self.rail[r][c] > 0:
+                        valid_positions.append((r, c))
+            self.agents_position = random.sample(valid_positions,
+                                                 self.number_of_agents)
+            self.agents_target = random.sample(valid_positions,
+                                               self.number_of_agents)
+            # agents_direction must be a direction for which a solution is
+            # guaranteed.
+            self.agents_direction = [0]*self.number_of_agents
+            re_generate = False
+            for i in range(self.number_of_agents):
+                valid_movements = []
+                for direction in range(4):
+                    position = self.agents_position[i]
+                    moves = self.t_utils.get_transitions_from_orientation(
+                             self.rail[position[0]][position[1]], direction)
+                    for move_index in range(4):
+                        if moves[move_index]:
+                            valid_movements.append((direction, move_index))
+                valid_starting_directions = []
+                for m in valid_movements:
+                    new_position = self._new_position(self.agents_position[i],
+                                                      m[1])
+                    if m[0] not in valid_starting_directions and \
+                       self._path_exists(new_position, m[0],
+                                         self.agents_target[i]):
+                        valid_starting_directions.append(m[0])
+                if len(valid_starting_directions) == 0:
+                    re_generate = True
+                else:
+                    self.agents_direction[i] = random.sample(
+                                               valid_starting_directions, 1)[0]
+        obs_dict = {}
+        for handle in self.agents_handles:
+            obs_dict[handle] = self._get_observation_for_agent(handle)
+        return obs_dict
+    def step(self, action_dict):
+        alpha = 1.0
+        beta = 1.0
+        invalid_action_penalty = -2
+        step_penalty = -1 * alpha
+        global_reward = 1 * beta
+        # Reset the step rewards
+        rewards_dict = {}
+        for handle in self.agents_handles:
+            rewards_dict[handle] = 0
+        if self.dones["__all__"]:
+            obs_dict = {}
+            for handle in self.agents_handles:
+                obs_dict[handle] = self._get_observation_for_agent(handle)
+            return obs_dict, rewards_dict, self.dones, {}
+        for i in range(len(self.agents_handles)):
+            handle = self.agents_handles[i]
+            if handle not in action_dict:
+                continue
+            action = action_dict[handle]
+            if action < 0 or action > 3:
+                print('ERROR: illegal action=', action,
+                      'for agent with handle=', handle)
+                return
+            if action > 0:
+                pos = self.agents_position[i]
+                direction = self.agents_direction[i]
+                movement = direction
+                if action == 1:
+                    movement = direction - 1
+                elif action == 3:
+                    movement = direction + 1
+                if movement < 0:
+                    movement += 4
+                if movement >= 4:
+                    movement -= 4
+                if action == 2:
+                    # compute number of possible transitions in the current
+                    # cell
+                    nbits = 0
+                    tmp = self.rail[pos[0]][pos[1]]
+                    while tmp > 0:
+                        nbits += (tmp & 1)
+                        tmp = tmp >> 1
+                    if nbits == 1:
+                        # dead-end;  assuming the rail network is consistent,
+                        # this should match the direction the agent has come
+                        # from. But it's better to check in any case.
+                        reverse_direction = 0
+                        if direction == 0:
+                            reverse_direction = 2
+                        elif direction == 1:
+                            reverse_direction = 3
+                        elif direction == 2:
+                            reverse_direction = 0
+                        elif direction == 3:
+                            reverse_direction = 1
+                        valid_transition = self.gtfotd(
+                                            self.rail[pos[0]][pos[1]],
+                                            reverse_direction,
+                                            reverse_direction)
+                        if valid_transition:
+                            direction = reverse_direction
+                            movement = direction
+                new_position = self._new_position(pos, movement)
+                # Is it a legal move?  1) transition allows the movement in the
+                # cell,  2) the new cell is not empty (case 0),  3) the cell is
+                # free, i.e., no agent is currently in that cell
+                if self.rail[new_position[0]][new_position[1]] > 0:
+                    new_cell_isValid = True
+                else:
+                    new_cell_isValid = False
+                transition_isValid = self.gtfotd(
+                     self.rail[pos[0]][pos[1]],
+                     direction,
+                     movement)
+                cell_isFree = True
+                for j in range(self.number_of_agents):
+                    if self.agents_position[j] == new_position:
+                        cell_isFree = False
+                        break
+                if new_cell_isValid and transition_isValid and cell_isFree:
+                    # move and change direction to face the movement that was
+                    # performed
+                    self.agents_position[i] = new_position
+                    self.agents_direction[i] = movement
+                else:
+                    # the action was not valid, add penalty
+                    rewards_dict[handle] += invalid_action_penalty
+            # if agent is not in target position, add step penalty
+            if self.agents_position[i][0] == self.agents_target[i][0] and \
+               self.agents_position[i][1] == self.agents_target[i][1]:
+                self.dones[handle] = True
+            else:
+                rewards_dict[handle] += step_penalty
+        # Check for end of episode + add global reward to all rewards!
+        num_agents_in_target_position = 0
+        for i in range(self.number_of_agents):
+            if self.agents_position[i][0] == self.agents_target[i][0] and \
+               self.agents_position[i][1] == self.agents_target[i][1]:
+                num_agents_in_target_position += 1
+        if num_agents_in_target_position == self.number_of_agents:
+            self.dones["__all__"] = True
+            rewards_dict = [r+global_reward for r in rewards_dict]
+        # Reset the step actions (in case some agent doesn't 'register_action'
+        # on the next step)
+        self.actions = [0]*self.number_of_agents
+        obs_dict = {}
+        for handle in self.agents_handles:
+            obs_dict[handle] = self._get_observation_for_agent(handle)
+        return obs_dict, rewards_dict, self.dones, {}
+    def _new_position(self, position, movement):
+        if movement == 0:    # NORTH
+            return (position[0]-1, position[1])
+        elif movement == 1:  # EAST
+            return (position[0], position[1] + 1)
+        elif movement == 2:  # SOUTH
+            return (position[0]+1, position[1])
+        elif movement == 3:  # WEST
+            return (position[0], position[1] - 1)
+    def _path_exists(self, start, direction, end):
+        # DFS (LIFO stack) - check if a path exists between the 2 nodes
+        visited = set()
+        stack = [(start, direction)]
+        while stack:
+            node = stack.pop()
+            if node[0][0] == end[0] and node[0][1] == end[1]:
+                return 1
+            if node not in visited:
+                visited.add(node)
+                moves = self.t_utils.get_transitions_from_orientation(
+                         self.rail[node[0][0]][node[0][1]], node[1])
+                for move_index in range(4):
+                    if moves[move_index]:
+                        stack.append((self._new_position(node[0], move_index),
+                                      move_index))
+                # If cell is a dead-end, re-append the same cell with reversed
+                # orientation!
+                nbits = 0
+                tmp = self.rail[node[0][0]][node[0][1]]
+                while tmp > 0:
+                    nbits += (tmp & 1)
+                    tmp = tmp >> 1
+                if nbits == 1:
+                    stack.append((node[0], (node[1] + 2) % 4))
+        return 0
+    def _get_observation_for_agent(self, handle):
+        # TODO:
+        return None
+    def render(self):
+        # TODO:
+        pass
diff --git a/flatland/utils/ b/flatland/utils/
new file mode 100644
index 00000000..df180e70
--- /dev/null
+++ b/flatland/utils/
@@ -0,0 +1,305 @@
+The rail_env_generator module provides utilities to generate env
+bitmaps for the RailEnv environment.
+import numpy as np
+import random
+from flatland.core.transitions import RailEnvTransitions
+def generate_rail_from_manual_specifications(rail_spec):
+    """
+    Utility to convert a rail given by manual specification as a map of tuples
+    (cell_type, rotation), to a transition map with the correct 16-bit
+    transitions specifications.
+    Parameters
+    -------
+    rail_spec : list of list of tuples
+        List (rows) of lists (columns) of tuples, each specifying a cell for
+        the RailEnv environment as (cell_type, rotation), with rotation being
+        clock-wise and in [0, 90, 180, 270].
+    Returns
+    -------
+    numpy.ndarray of type numpy.uint16
+        The matrix with the correct 16-bit bitmaps for each cell.
+    """
+    t_utils = RailEnvTransitions()
+    height = len(rail_spec)
+    width = len(rail_spec[0])
+    rail = np.zeros((height, width), dtype=np.uint16)
+    for r in range(height):
+        for c in range(width):
+            cell = rail_spec[r][c]
+            if cell[0] < 0 or cell[0] >= len(t_utils.transitions):
+                print("ERROR - invalid cell type=", cell[0])
+                return []
+            rail[r, c] = t_utils.rotate_transition(
+                          t_utils.transitions[cell[0]], cell[1])
+    return rail
+def generate_random_rail(width, height):
+    """
+    Dummy random level generator:
+    - fill in cells at random in [width-2, height-2]
+    - keep filling cells in among the unfilled ones, such that all transitions
+      are legit;  if no cell can be filled in without violating some
+      transitions, pick one among those that can satisfy most transitions
+      (1,2,3 or 4), and delete (+mark to be re-filled) the cells that were
+      incompatible.
+    - keep trying for a total number of insertions
+      (e.g., (W-2)*(H-2)*MAX_REPETITIONS ); if no solution is found, empty the
+      board and try again from scratch.
+    - finally pad the border of the map with dead-ends to avoid border issues.
+    Dead-ends are not allowed inside the grid, only at the border; however, if
+    no cell type can be inserted in a given cell (because of the neighboring
+    transitions), deadends are allowed if they solve the problem. This was
+    found to turn most un-generatable levels into valid ones.
+    Parameters
+    -------
+    width : int
+        The width (number of cells) of the grid to generate.
+    height : int
+        The height (number of cells) of the grid to generate.
+    Returns
+    -------
+    numpy.ndarray of type numpy.uint16
+        The matrix with the correct 16-bit bitmaps for each cell.
+    """
+    t_utils = RailEnvTransitions()
+    transitions_templates_ = []
+    for i in range(len(t_utils.transitions)-1):  # don't include dead-ends
+        all_transitions = 0
+        for dir_ in range(4):
+            trans = t_utils.get_transitions_from_orientation(
+                     t_utils.transitions[i], dir_)
+            all_transitions |= (trans[0] << 3) | \
+                               (trans[1] << 2) | \
+                               (trans[2] << 1) | \
+                               (trans[3])
+        template = [int(x) for x in bin(all_transitions)[2:]]
+        template = [0]*(4-len(template)) + template
+        # add all rotations
+        for rot in [0, 90, 180, 270]:
+            transitions_templates_.append((template,
+                                          t_utils.rotate_transition(
+                                           t_utils.transitions[i],
+                                           rot)))
+            template = [template[-1]]+template[:-1]
+    def get_matching_templates(template):
+        ret = []
+        for i in range(len(transitions_templates_)):
+            is_match = True
+            for j in range(4):
+                if template[j] >= 0 and \
+                   template[j] != transitions_templates_[i][0][j]:
+                    is_match = False
+                    break
+            if is_match:
+                ret.append(transitions_templates_[i][1])
+        return ret
+    MAX_INSERTIONS = (width-2) * (height-2) * 10
+    attempt_number = 0
+    while attempt_number < MAX_ATTEMPTS_FROM_SCRATCH:
+        cells_to_fill = []
+        rail = []
+        for r in range(height):
+            rail.append([None]*width)
+            if r > 0 and r < height-1:
+                cells_to_fill = cells_to_fill \
+                                + [(r, c) for c in range(1, width-1)]
+        num_insertions = 0
+        while num_insertions < MAX_INSERTIONS and len(cells_to_fill) > 0:
+            cell = random.sample(cells_to_fill, 1)[0]
+            cells_to_fill.remove(cell)
+            row = cell[0]
+            col = cell[1]
+            # look at its neighbors and see what are the possible transitions
+            # that can be chosen from, if any.
+            valid_template = [-1, -1, -1, -1]
+            for el in [(0, 2, (-1, 0)),
+                       (1, 3, (0, 1)),
+                       (2, 0, (1, 0)),
+                       (3, 1, (0, -1))]:  # N, E, S, W
+                neigh_trans = rail[row+el[2][0]][col+el[2][1]]
+                if neigh_trans is not None:
+                    # check whether any facing direction k of the neighbor
+                    # has a transition moving to direction el[1]
+                    max_bit = 0
+                    for k in range(4):
+                        max_bit |= \
+                         t_utils.get_transition_from_orientation_to_direction(
+                          neigh_trans, k, el[1])
+                    if max_bit:
+                        valid_template[el[0]] = 1
+                    else:
+                        valid_template[el[0]] = 0
+            possible_cell_transitions = get_matching_templates(valid_template)
+            if len(possible_cell_transitions) == 0:  # NO VALID TRANSITIONS
+                # no cell can be filled in without violating some transitions
+                # can a dead-end solve the problem?
+                if valid_template.count(1) == 1:
+                    for k in range(4):
+                        if valid_template[k] == 1:
+                            rot = 0
+                            if k == 0:
+                                rot = 180
+                            elif k == 1:
+                                rot = 270
+                            elif k == 2:
+                                rot = 0
+                            elif k == 3:
+                                rot = 90
+                            rail[row][col] = t_utils.rotate_transition(
+                                              int('0000000000100000', 2), rot)
+                            num_insertions += 1
+                            break
+                else:
+                    # can I get valid transitions by removing a single
+                    # neighboring cell?
+                    bestk = -1
+                    besttrans = []
+                    for k in range(4):
+                        tmp_template = valid_template[:]
+                        tmp_template[k] = -1
+                        possible_cell_transitions = get_matching_templates(
+                                                     tmp_template)
+                        if len(possible_cell_transitions) > len(besttrans):
+                            besttrans = possible_cell_transitions
+                            bestk = k
+                    if bestk >= 0:
+                        # Replace the corresponding cell with None, append it
+                        # to cells to fill, fill in a transition in the current
+                        # cell.
+                        replace_row = row - 1
+                        replace_col = col
+                        if bestk == 1:
+                            replace_row = row
+                            replace_col = col + 1
+                        elif bestk == 2:
+                            replace_row = row + 1
+                            replace_col = col
+                        elif bestk == 3:
+                            replace_row = row
+                            replace_col = col - 1
+                        cells_to_fill.append((replace_row, replace_col))
+                        rail[replace_row][replace_col] = None
+                        rail[row][col] = random.sample(
+                                                     besttrans, 1)[0]
+                        num_insertions += 1
+                    else:
+                        print('WARNING: still nothing!')
+                        rail[row][col] = int('0000000000000000', 2)
+                        num_insertions += 1
+                        pass
+            else:
+                rail[row][col] = random.sample(
+                                             possible_cell_transitions, 1)[0]
+                num_insertions += 1
+        if num_insertions == MAX_INSERTIONS:
+            # Failed to generate a valid level; try again for a number of times
+            attempt_number += 1
+        else:
+            break
+    if attempt_number == MAX_ATTEMPTS_FROM_SCRATCH:
+        print('ERROR: failed to generate level')
+    # Finally pad the border of the map with dead-ends to avoid border issues;
+    # at most 1 transition in the neigh cell
+    for r in range(height):
+        # Check for transitions coming from [r][1] to WEST
+        max_bit = 0
+        neigh_trans = rail[r][1]
+        if neigh_trans is not None:
+            for k in range(4):
+                neigh_trans_from_direction = (neigh_trans >> ((3-k) * 4)) \
+                                             & (2**4-1)
+                max_bit = max_bit | (neigh_trans_from_direction & 1)
+        if max_bit:
+            rail[r][0] = t_utils.rotate_transition(
+                           int('0000000000100000', 2), 270)
+        else:
+            rail[r][0] = int('0000000000000000', 2)
+        # Check for transitions coming from [r][-2] to EAST
+        max_bit = 0
+        neigh_trans = rail[r][-2]
+        if neigh_trans is not None:
+            for k in range(4):
+                neigh_trans_from_direction = (neigh_trans >> ((3-k) * 4)) \
+                                             & (2**4-1)
+                max_bit = max_bit | (neigh_trans_from_direction & (1 << 2))
+        if max_bit:
+            rail[r][-1] = t_utils.rotate_transition(int('0000000000100000', 2),
+                                                    90)
+        else:
+            rail[r][-1] = int('0000000000000000', 2)
+    for c in range(width):
+        # Check for transitions coming from [1][c] to NORTH
+        max_bit = 0
+        neigh_trans = rail[1][c]
+        if neigh_trans is not None:
+            for k in range(4):
+                neigh_trans_from_direction = (neigh_trans >> ((3-k) * 4)) \
+                                              & (2**4-1)
+                max_bit = max_bit | (neigh_trans_from_direction & (1 << 3))
+        if max_bit:
+            rail[0][c] = int('0000000000100000', 2)
+        else:
+            rail[0][c] = int('0000000000000000', 2)
+        # Check for transitions coming from [-2][c] to SOUTH
+        max_bit = 0
+        neigh_trans = rail[-2][c]
+        if neigh_trans is not None:
+            for k in range(4):
+                neigh_trans_from_direction = (neigh_trans >> ((3-k) * 4)) \
+                                             & (2**4-1)
+                max_bit = max_bit | (neigh_trans_from_direction & (1 << 1))
+        if max_bit:
+            rail[-1][c] = t_utils.rotate_transition(
+                            int('0000000000100000', 2), 180)
+        else:
+            rail[-1][c] = int('0000000000000000', 2)
+    # For display only (failed levels): mark unfilled cells as empty
+    for r in range(height):
+        for c in range(width):
+            if rail[r][c] is None:
+                rail[r][c] = int('0000000000000000', 2)
+    return np.asarray(rail, dtype=np.uint16)