From e1e0947e1299abac425c01bab3721fa128478324 Mon Sep 17 00:00:00 2001 From: u214892 <u214892@sbb.ch> Date: Fri, 5 Jul 2019 21:48:34 +0200 Subject: [PATCH] refactoring transitions_map --- flatland/core/grid/grid4.py | 3 +- flatland/core/transition_map.py | 36 +++++++++++++++------- flatland/envs/grid4_generators_utils.py | 8 ++--- flatland/envs/observations.py | 16 +++++----- flatland/envs/predictions.py | 2 +- flatland/envs/rail_env.py | 4 +-- flatland/utils/editor.py | 6 ++-- flatland/utils/rendertools.py | 8 ++--- tests/test_flatland_core_transition_map.py | 35 +++++++++++++++++---- 9 files changed, 78 insertions(+), 40 deletions(-) diff --git a/flatland/core/grid/grid4.py b/flatland/core/grid/grid4.py index 714123ed..b4b5b17c 100644 --- a/flatland/core/grid/grid4.py +++ b/flatland/core/grid/grid4.py @@ -1,4 +1,5 @@ from enum import IntEnum +from typing import Type import numpy as np @@ -218,7 +219,7 @@ class Grid4Transitions(Transitions): cell_transition = value return cell_transition - def get_direction_enum(self) -> IntEnum: + def get_direction_enum(self) -> Type[Grid4TransitionsEnum]: return Grid4TransitionsEnum def has_deadend(self, cell_transition): diff --git a/flatland/core/transition_map.py b/flatland/core/transition_map.py index cb09a628..5e0f6cd7 100644 --- a/flatland/core/transition_map.py +++ b/flatland/core/transition_map.py @@ -7,6 +7,7 @@ from importlib_resources import path from numpy import array from flatland.core.grid.grid4 import Grid4Transitions +from flatland.core.transitions import Transitions class TransitionMap: @@ -110,7 +111,7 @@ class GridTransitionMap(TransitionMap): GridTransitionMap implements utility functions. """ - def __init__(self, width, height, transitions=Grid4Transitions([])): + def __init__(self, width, height, transitions: Transitions = Grid4Transitions([])): """ Builder for GridTransitionMap object. @@ -132,7 +133,25 @@ class GridTransitionMap(TransitionMap): self.grid = np.zeros((height, width), dtype=self.transitions.get_type()) - def get_transitions(self, cell_id): + def get_full_transitions(self, row, column): + """ + Returns the full transitions for the cell at (row, column) in the format transition_map's transitions. + + Parameters + ---------- + row: int + column: int + (row,column) specifies the cell in this transition map. + + Returns + ------- + self.transitions.get_type() + The cell content int the format of this map's Transitions. + + """ + return self.grid[row][column] + + def get_transitions(self, row, column, orientation): """ Return a tuple of transitions available in a cell specified by `cell_id' (e.g., a tuple of size of the maximum number of transitions, @@ -150,15 +169,10 @@ class GridTransitionMap(TransitionMap): Returns ------- tuple - List of the validity of transitions in the cell. + List of the validity of transitions in the cell as given by the maps transitions. """ - assert len(cell_id) in (2, 3), \ - 'GridTransitionMap.get_transitions() ERROR: cell_id tuple must have length 2 or 3.' - if len(cell_id) == 3: - return self.transitions.get_transitions(self.grid[cell_id[0]][cell_id[1]], cell_id[2]) - elif len(cell_id) == 2: - return self.grid[cell_id[0]][cell_id[1]] + return self.transitions.get_transitions(self.grid[row][column], orientation) def set_transitions(self, cell_id, new_transitions): """ @@ -308,7 +322,7 @@ class GridTransitionMap(TransitionMap): grcPos = array(rcPos) grcMax = self.grid.shape - binTrans = self.get_transitions(rcPos) # 16bit integer - all trans in/out + binTrans = self.get_full_transitions(*rcPos) # 16bit integer - all trans in/out lnBinTrans = array([binTrans >> 8, binTrans & 0xff], dtype=np.uint8) # 2 x uint8 g2binTrans = np.unpackbits(lnBinTrans).reshape(4, 4) # 4x4 x uint8 binary(0,1) gDirOut = g2binTrans.any(axis=0) # outbound directions as boolean array (4) @@ -328,7 +342,7 @@ class GridTransitionMap(TransitionMap): # Get the transitions out of gPos2, using iDirOut as the inbound direction # if there are no available transitions, ie (0,0,0,0), then rcPos is invalid - t4Trans2 = self.get_transitions((*gPos2, iDirOut)) + t4Trans2 = self.get_transitions(*gPos2, iDirOut) if any(t4Trans2): continue else: diff --git a/flatland/envs/grid4_generators_utils.py b/flatland/envs/grid4_generators_utils.py index 4b2ab8cb..dedd76b6 100644 --- a/flatland/envs/grid4_generators_utils.py +++ b/flatland/envs/grid4_generators_utils.py @@ -75,7 +75,7 @@ def get_rnd_agents_pos_tgt_dir_on_rail(rail, num_agents): return 1 if node not in visited: visited.add(node) - moves = rail.get_transitions((node[0][0], node[0][1], node[1])) + moves = rail.get_transitions(node[0][0], node[0][1], node[1]) for move_index in range(4): if moves[move_index]: stack.append((get_new_position(node[0], move_index), @@ -84,7 +84,7 @@ def get_rnd_agents_pos_tgt_dir_on_rail(rail, num_agents): # If cell is a dead-end, append previous node with reversed # orientation! nbits = 0 - tmp = rail.get_transitions((node[0][0], node[0][1])) + tmp = rail.get_full_transitions(node[0][0], node[0][1]) while tmp > 0: nbits += (tmp & 1) tmp = tmp >> 1 @@ -96,7 +96,7 @@ def get_rnd_agents_pos_tgt_dir_on_rail(rail, num_agents): valid_positions = [] for r in range(rail.height): for c in range(rail.width): - if rail.get_transitions((r, c)) > 0: + if rail.get_full_transitions(r, c) > 0: valid_positions.append((r, c)) re_generate = True @@ -116,7 +116,7 @@ def get_rnd_agents_pos_tgt_dir_on_rail(rail, num_agents): valid_movements = [] for direction in range(4): position = agents_position[i] - moves = rail.get_transitions((position[0], position[1], direction)) + moves = rail.get_transitions(position[0], position[1], direction) for move_index in range(4): if moves[move_index]: valid_movements.append((direction, move_index)) diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 1f02d518..add983c0 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -253,7 +253,7 @@ class TreeObsForRailEnv(ObservationBuilder): if handle > len(self.env.agents): print("ERROR: obs _get - handle ", handle, " len(agents)", len(self.env.agents)) agent = self.env.agents[handle] # TODO: handle being treated as index - possible_transitions = self.env.rail.get_transitions((*agent.position, agent.direction)) + possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction) num_transitions = np.count_nonzero(possible_transitions) # Root node - current position @@ -383,8 +383,8 @@ class TreeObsForRailEnv(ObservationBuilder): last_is_target = True break - cell_transitions = self.env.rail.get_transitions((*position, direction)) - total_transitions = bin(self.env.rail.get_transitions(position)).count("1") + cell_transitions = self.env.rail.get_transitions(*position, direction) + total_transitions = bin(self.env.rail.get_full_transitions(*position)).count("1") num_transitions = np.count_nonzero(cell_transitions) exploring = False # Detect Switches that can only be used by other agents. @@ -394,7 +394,7 @@ class TreeObsForRailEnv(ObservationBuilder): if num_transitions == 1: # Check if dead-end, or if we can go forward along direction nbits = 0 - tmp = self.env.rail.get_transitions(tuple(position)) + tmp = self.env.rail.get_full_transitions(*position) while tmp > 0: nbits += (tmp & 1) tmp = tmp >> 1 @@ -469,7 +469,7 @@ class TreeObsForRailEnv(ObservationBuilder): # Start from the current orientation, and see which transitions are available; # organize them as [left, forward, right, back], relative to the current orientation # Get the possible transitions - possible_transitions = self.env.rail.get_transitions((*position, direction)) + possible_transitions = self.env.rail.get_transitions(*position, direction) for branch_direction in [(direction + 4 + i) % 4 for i in range(-1, 3)]: if last_is_dead_end and self.env.rail.get_transition((*position, direction), (branch_direction + 2) % 4): @@ -572,7 +572,7 @@ class GlobalObsForRailEnv(ObservationBuilder): self.rail_obs = np.zeros((self.env.height, self.env.width, 16)) for i in range(self.rail_obs.shape[0]): for j in range(self.rail_obs.shape[1]): - bitlist = [int(digit) for digit in bin(self.env.rail.get_transitions((i, j)))[2:]] + bitlist = [int(digit) for digit in bin(self.env.rail.get_full_transitions(i, j))[2:]] bitlist = [0] * (16 - len(bitlist)) + bitlist self.rail_obs[i, j] = np.array(bitlist) @@ -630,7 +630,7 @@ class GlobalObsForRailEnvDirectionDependent(ObservationBuilder): self.rail_obs = np.zeros((self.env.height, self.env.width, 16)) for i in range(self.rail_obs.shape[0]): for j in range(self.rail_obs.shape[1]): - bitlist = [int(digit) for digit in bin(self.env.rail.get_transitions((i, j)))[2:]] + bitlist = [int(digit) for digit in bin(self.env.rail.get_full_transitions(i, j))[2:]] bitlist = [0] * (16 - len(bitlist)) + bitlist self.rail_obs[i, j] = np.array(bitlist) @@ -701,7 +701,7 @@ class LocalObsForRailEnv(ObservationBuilder): self.env.width + 2 * self.view_radius, 16)) for i in range(self.env.height): for j in range(self.env.width): - bitlist = [int(digit) for digit in bin(self.env.rail.get_transitions((i, j)))[2:]] + bitlist = [int(digit) for digit in bin(self.env.rail.get_full_transitions(i, j))[2:]] bitlist = [0] * (16 - len(bitlist)) + bitlist self.rail_obs[i + self.view_radius, j + self.view_radius] = np.array(bitlist) diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py index d471596b..2605e84c 100644 --- a/flatland/envs/predictions.py +++ b/flatland/envs/predictions.py @@ -131,7 +131,7 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder): prediction[index] = [index, *agent.position, agent.direction, RailEnvActions.STOP_MOVING] continue # Take shortest possible path - cell_transitions = self.env.rail.get_transitions((*agent.position, agent.direction)) + cell_transitions = self.env.rail.get_transitions(*agent.position, agent.direction) new_position = None new_direction = None diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index c7050550..4e8832ec 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -322,7 +322,7 @@ class RailEnv(Environment): new_position, np.clip(new_position, [0, 0], [self.height - 1, self.width - 1])) and # check the new position has some transitions (ie is not an empty cell) - self.rail.get_transitions(new_position) > 0) + self.rail.get_full_transitions(*new_position) > 0) # If transition validity hasn't been checked yet. if transition_isValid is None: @@ -338,7 +338,7 @@ class RailEnv(Environment): def check_action(self, agent, action): transition_isValid = None - possible_transitions = self.rail.get_transitions((*agent.position, agent.direction)) + possible_transitions = self.rail.get_transitions(*agent.position, agent.direction) num_transitions = np.count_nonzero(possible_transitions) new_direction = agent.direction diff --git a/flatland/utils/editor.py b/flatland/utils/editor.py index dd0c4d4b..ea4056fa 100644 --- a/flatland/utils/editor.py +++ b/flatland/utils/editor.py @@ -494,7 +494,7 @@ class EditorModel(object): if len(lrcStroke) >= 2: # If the first cell in a stroke is empty, add a deadend to cell 0 - if self.env.rail.get_transitions(lrcStroke[0]) == 0: + if self.env.rail.get_full_transitions(*lrcStroke[0]) == 0: self.mod_rail_2cells(lrcStroke, bAddRemove, iCellToMod=0) # Add transitions for groups of 3 cells @@ -504,7 +504,7 @@ class EditorModel(object): # If final cell empty, insert deadend: if len(lrcStroke) == 2: - if self.env.rail.get_transitions(lrcStroke[1]) == 0: + if self.env.rail.get_full_transitions(*lrcStroke[1]) == 0: self.mod_rail_2cells(lrcStroke, bAddRemove, iCellToMod=1) # now empty out the final two cells from the queue @@ -752,7 +752,7 @@ class EditorModel(object): self.log(*args, **kwargs) def debug_cell(self, rcCell): - binTrans = self.env.rail.get_transitions(rcCell) + binTrans = self.env.rail.get_full_transitions(*rcCell) sbinTrans = format(binTrans, "#018b")[2:] self.debug("cell ", rcCell, diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py index a148a5ae..6fa68f02 100644 --- a/flatland/utils/rendertools.py +++ b/flatland/utils/rendertools.py @@ -86,7 +86,7 @@ class RenderTool(object): for visit in lVisits: # transition for next cell - tbTrans = self.env.rail.get_transitions((*visit.rc, visit.iDir)) + tbTrans = self.env.rail.get_transitions(*visit.rc, visit.iDir) giTrans = np.where(tbTrans)[0] # RC list of transitions gTransRCAg = rt.gTransRC[giTrans] self.plotTrans(visit.rc, gTransRCAg, depth=str(visit.iDepth), color=color) @@ -125,7 +125,7 @@ class RenderTool(object): ) """ - tbTrans = self.env.rail.get_transitions((*rcPos, iDir)) + tbTrans = self.env.rail.get_transitions(*rcPos, iDir) giTrans = np.where(tbTrans)[0] # RC list of transitions # HACK: workaround dead-end transitions @@ -459,7 +459,7 @@ class RenderTool(object): xyCentre = array([x0, y1]) + cell_size / 2 # cell transition values - oCell = env.rail.get_transitions((r, c)) + oCell = env.rail.get_full_transitions(r, c) bCellValid = env.rail.cell_neighbours_valid((r, c), check_this_cell=True) @@ -482,7 +482,7 @@ class RenderTool(object): from_ori = (orientation + 2) % 4 # 0123=NESW -> 2301=SWNE from_xy = coords[from_ori] - tMoves = env.rail.get_transitions((r, c, orientation)) + tMoves = env.rail.get_transitions(r, c, orientation) for to_ori in range(4): to_xy = coords[to_ori] diff --git a/tests/test_flatland_core_transition_map.py b/tests/test_flatland_core_transition_map.py index 8013e912..a4142316 100644 --- a/tests/test_flatland_core_transition_map.py +++ b/tests/test_flatland_core_transition_map.py @@ -5,19 +5,42 @@ from flatland.core.transition_map import GridTransitionMap def test_grid4_get_transitions(): grid4_map = GridTransitionMap(2, 2, Grid4Transitions([])) - assert grid4_map.get_transitions((0, 0, Grid4TransitionsEnum.NORTH)) == (0, 0, 0, 0) + assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.NORTH) == (0, 0, 0, 0) + assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.EAST) == (0, 0, 0, 0) + assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.SOUTH) == (0, 0, 0, 0) + assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.WEST) == (0, 0, 0, 0) + assert grid4_map.get_full_transitions(0, 0) == 0 + grid4_map.set_transition((0, 0, Grid4TransitionsEnum.NORTH), Grid4TransitionsEnum.NORTH, 1) - assert grid4_map.get_transitions((0, 0, Grid4TransitionsEnum.NORTH)) == (1, 0, 0, 0) + assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.NORTH) == (1, 0, 0, 0) + assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.EAST) == (0, 0, 0, 0) + assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.SOUTH) == (0, 0, 0, 0) + assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.WEST) == (0, 0, 0, 0) + assert grid4_map.get_full_transitions(0, 0) == pow(2, 15) # the most significant bit is on + + grid4_map.set_transition((0, 0, Grid4TransitionsEnum.NORTH), Grid4TransitionsEnum.WEST, 1) + assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.NORTH) == (1, 0, 0, 1) + assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.EAST) == (0, 0, 0, 0) + assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.SOUTH) == (0, 0, 0, 0) + assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.WEST) == (0, 0, 0, 0) + # the most significant and the fourth most significant bits are on + assert grid4_map.get_full_transitions(0, 0) == pow(2, 15) + pow(2, 12) + grid4_map.set_transition((0, 0, Grid4TransitionsEnum.NORTH), Grid4TransitionsEnum.NORTH, 0) - assert grid4_map.get_transitions((0, 0, Grid4TransitionsEnum.NORTH)) == (0, 0, 0, 0) + assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.NORTH) == (0, 0, 0, 1) + assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.EAST) == (0, 0, 0, 0) + assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.SOUTH) == (0, 0, 0, 0) + assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.WEST) == (0, 0, 0, 0) + # the fourth most significant bits are on + assert grid4_map.get_full_transitions(0, 0) == pow(2, 12) def test_grid8_set_transitions(): grid8_map = GridTransitionMap(2, 2, Grid8Transitions([])) - assert grid8_map.get_transitions((0, 0, Grid8TransitionsEnum.NORTH)) == (0, 0, 0, 0, 0, 0, 0, 0) + assert grid8_map.get_transitions(0, 0, Grid8TransitionsEnum.NORTH) == (0, 0, 0, 0, 0, 0, 0, 0) grid8_map.set_transition((0, 0, Grid8TransitionsEnum.NORTH), Grid8TransitionsEnum.NORTH, 1) - assert grid8_map.get_transitions((0, 0, Grid8TransitionsEnum.NORTH)) == (1, 0, 0, 0, 0, 0, 0, 0) + assert grid8_map.get_transitions(0, 0, Grid8TransitionsEnum.NORTH) == (1, 0, 0, 0, 0, 0, 0, 0) grid8_map.set_transition((0, 0, Grid8TransitionsEnum.NORTH), Grid8TransitionsEnum.NORTH, 0) - assert grid8_map.get_transitions((0, 0, Grid8TransitionsEnum.NORTH)) == (0, 0, 0, 0, 0, 0, 0, 0) + assert grid8_map.get_transitions(0, 0, Grid8TransitionsEnum.NORTH) == (0, 0, 0, 0, 0, 0, 0, 0) # TODO GridTransitionMap -- GitLab