diff --git a/flatland/core/transition_map.py b/flatland/core/transition_map.py index bd6e79a4fb5d7ea544619ebb00899ec863b106d1..bce795e4bdff3829c7c8e7ab7474701856f9a239 100644 --- a/flatland/core/transition_map.py +++ b/flatland/core/transition_map.py @@ -3,6 +3,7 @@ TransitionMap and derived classes. """ import numpy as np +from numpy import array from .transitions import Grid4Transitions, Grid8Transitions, RailEnvTransitions @@ -297,6 +298,71 @@ class GridTransitionMap(TransitionMap): 0:min(self.width, new_width)] = new_grid[0:min(self.height, new_height), 0:min(self.width, new_width)] + def is_cell_valid(self, rcPos): + cell_transition = self.grid[tuple(rcPos)] + + if not self.transitions.is_valid(cell_transition): + return False + else: + return True + + def cell_neighbours_valid(self, rcPos, check_this_cell=False): + """ + Check validity of cell at rcPos = tuple(row, column) + Checks that: + - surrounding cells have inbound transitions for all the + outbound transitions of this cell. + + These are NOT checked - see transition.is_valid: + - all transitions have the mirror transitions (N->E <=> W->S) + - Reverse transitions (N -> S) only exist for a dead-end + - a cell contains either no dead-ends or exactly one + + Returns: True (valid) or False (invalid) + + """ + cell_transition = self.grid[tuple(rcPos)] + + if check_this_cell: + if not self.transitions.is_valid(cell_transition): + return False + + gDir2dRC = self.transitions.gDir2dRC # [[-1,0] = N, [0,1]=E, etc] + grcPos = array(rcPos) + grcMax = self.grid.shape + + binTrans = self.get_transitions(rcPos) # 16bit integer - all trans in/out + lnBinTrans = array([binTrans >> 8, binTrans & 0xff], dtype=np.uint8) # 2 x uint8 + g2binTrans = np.unpackbits(lnBinTrans).reshape(4, 4) # 4x4 x uint8 binary(0,1) + # gDirIn = g2binTrans.any(axis=1) # inbound directions as boolean array (4) + gDirOut = g2binTrans.any(axis=0) # outbound directions as boolean array (4) + giDirOut = np.argwhere(gDirOut)[:, 0] # valid outbound directions as array of int + + # loop over available outbound directions (indices) for rcPos + for iDirOut in giDirOut: + gdRC = gDir2dRC[iDirOut] # row,col increment + gPos2 = grcPos + gdRC # next cell in that direction + + # Check the adjacent cell is within bounds + # if not, then this transition is invalid! + if np.any(gPos2 < 0): + return False + if np.any(gPos2 >= grcMax): + return False + + # Get the transitions out of gPos2, using iDirOut as the inbound direction + # if there are no available transitions, ie (0,0,0,0), then rcPos is invalid + t4Trans2 = self.get_transitions((*gPos2, iDirOut)) + if any(t4Trans2): + continue + else: + return False + + return True + + def cell_repr(self, rcPos): + return self.transitions.repr(self.get_transitions(rcPos)) + # TODO: GIACOMO: is it better to provide those methods with lists of cell_ids # (most general implementation) or to make Grid-class specific methods for # slicing over the 3 dimensions? I'd say both perhaps. diff --git a/flatland/core/transitions.py b/flatland/core/transitions.py index cdc657cf0e10fc41e0a2bbec465bf39979e04819..4b1874a880a086324f7da8aa48e31529ca6985a3 100644 --- a/flatland/core/transitions.py +++ b/flatland/core/transitions.py @@ -4,6 +4,8 @@ derived GridTransitions class, which allows for the specification of possible transitions over a 2D grid. """ +import numpy as np + class Transitions: """ @@ -159,6 +161,11 @@ class Grid4Transitions(Transitions): def __init__(self, transitions): self.transitions = transitions + self.sDirs = "NESW" + self.lsDirs = list(self.sDirs) + + # row,col delta for each direction + self.gDir2dRC = np.array([[-1, 0], [0, 1], [1, 0], [0, -1]]) def get_transitions(self, cell_transition, orientation): """ @@ -216,6 +223,7 @@ class Grid4Transitions(Transitions): (new_transitions[1] & 1) << 2 | \ (new_transitions[2] & 1) << 1 | \ (new_transitions[3] & 1) + # new_transitions = np.packbits((0, 0, 0, 0) + new_transitions) # alternative cell_transition = (cell_transition & negmask) | (new_transitions << ((3 - orientation) * 4)) @@ -559,6 +567,40 @@ class RailEnvTransitions(Grid4Transitions): print("S", format(cell_transition >> (1*4) & 0xF, '04b')) print("W", format(cell_transition >> (0*4) & 0xF, '04b')) + def repr(self, cell_transition, version=0): + """ + Provide a string representation of the cell transitions. + This class doesn't represent an individual cell, + but a way of interpreting the contents of a cell. + So using the ad hoc name repr rather than __repr__. + """ + # binary format string without leading 0b + sbinTrans = format(cell_transition, "#018b")[2:] + if version == 0: + sRepr = " ".join([ + "{}:{}".format(sDir, sbinTrans[i:i+4]) + for i, sDir in + zip( + range(0, len(sbinTrans), 4), + self.lsDirs # NESW + )]) + return sRepr + + if version == 1: + lsRepr = [] + for iDirIn in range(0, 4): + sDirTrans = sbinTrans[iDirIn*4:iDirIn*4+4] + if sDirTrans == "0000": + continue + sDirsOut = [ + self.lsDirs[iDirOut] + for iDirOut in range(0, 4) + if sDirTrans[iDirOut] == "1" + ] + lsRepr.append(self.lsDirs[iDirIn] + ":" + "".join(sDirsOut)) + + return ", ".join(lsRepr) + def is_valid(self, cell_transition): """ Checks if a cell transition is a valid cell setup. @@ -578,3 +620,12 @@ class RailEnvTransitions(Grid4Transitions): return True return False + def has_deadend(self, cell_transition): + binDeadends = 0b0010000110000100 + if cell_transition & binDeadends > 0: + return True + else: + return False + + # def remove_deadends(self, cell_transition) + \ No newline at end of file diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index d82f1703e44da8c01034c0920d98fc013639b6bd..d8f1af4724c33772dd2b12b28f4d9e392a87b728 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -212,11 +212,14 @@ class RailEnv(Environment): iAgent = self.number_of_agents - self.agents_position.append(tuple(rcPos)) # ensure it's a tuple not a list - self.agents_handles.append(max(self.agents_handles + [-1]) + 1) # max(handles) + 1, starting at 0 - if iDir is None: iDir = self.pick_agent_direction(rcPos, rcTarget) + if iDir is None: + print("Error picking agent direction at pos:", rcPos) + return None + + self.agents_position.append(tuple(rcPos)) # ensure it's a tuple not a list + self.agents_handles.append(max(self.agents_handles + [-1]) + 1) # max(handles) + 1, starting at 0 self.agents_direction.append(iDir) self.agents_target.append(rcPos) # set the target to the origin initially self.number_of_agents += 1 diff --git a/flatland/utils/editor.py b/flatland/utils/editor.py index b7b09b771ba49630d6a90a8f517b07dbbea747fc..a09b131eb6e72ac3a4445b65da4d3cbca4b4d6c1 100644 --- a/flatland/utils/editor.py +++ b/flatland/utils/editor.py @@ -17,6 +17,7 @@ from flatland.envs.rail_env import RailEnv, random_rail_generator from flatland.core.env_observation_builder import TreeObsForRailEnv import flatland.utils.rendertools as rt from examples.play_model import Player +from flatland.envs.env_utils import mirror class View(object): @@ -105,7 +106,7 @@ class JupEditor(object): self.redraw() def event_handler(self, wid, event): - """Mouse motion event handler + """Mouse motion event handler for drawing. """ x = event['canvasX'] y = event['canvasY'] @@ -161,6 +162,11 @@ class JupEditor(object): while len(rcHistory) >= 3: rc3Cells = array(rcHistory[:3]) # the 3 cells rcMiddle = rc3Cells[1] # the middle cell which we will update + + # Save the original state of the cell + oTransrcMiddle = self.env.rail.get_transitions(rcMiddle) + sTransrcMiddle = self.env.rail.cell_repr(rcMiddle) + # get the 2 row, col deltas between the 3 cells, eg [-1,0] = North rc2Trans = np.diff(rc3Cells, axis=0) @@ -181,21 +187,49 @@ class JupEditor(object): if len(liTrans) == 2: # Set the transition env.rail.set_transition((*rcMiddle, liTrans[0]), liTrans[1], bTransition) - # iValCell = env.rail.transitions.set_transition( - # env.rail.grid[tuple(rcMiddle)], liTrans[0], liTrans[1], bTransition) # Also set the reverse transition - # iValCell = env.rail.transitions.set_transition( - # iValCell, - # (liTrans[1] + 2) % 4, # use the reversed outbound transition for inbound - # (liTrans[0] + 2) % 4, # use the reversed inbound transition for outbound - # bTransition) - - # Write the cell transition value back into the grid - # env.rail.grid[tuple(rcMiddle)] = iValCell - + # use the reversed outbound transition for inbound + # and the reversed inbound transition for outbound + env.rail.set_transition((*rcMiddle, mirror(liTrans[1])), mirror(liTrans[0]), bTransition) + + bValid = env.rail.is_cell_valid(rcMiddle) + if not bValid: + # Reset cell transition values + env.rail.grid[tuple(rcMiddle)] = oTransrcMiddle + + self.log(rcMiddle, "Orig:", sTransrcMiddle, "Mod:", self.env.rail.cell_repr(rcMiddle)) rcHistory.pop(0) # remove the last-but-one - + + # If final cell empty, insert deadend: + if len(rcHistory) == 2 and (self.env.rail.get_transitions(rcHistory[1]) == 0): + rc2Cells = array(rcHistory[:2]) # the 2 cells + rcFinal = rc2Cells[1] # the final cell which we will update + + # get the row, col delta between the 2 cells, eg [-1,0] = North + rc2Trans = np.diff(rc2Cells, axis=0) + + # get the direction index for the 2 transitions + liTrans = [] + for rcTrans in rc2Trans: + iTrans = np.argwhere(np.all(self.gRCTrans - rcTrans == 0, axis=1)) + if len(iTrans) > 0: + iTrans = iTrans[0][0] + liTrans.append(iTrans) + + # check that we have one transition + if len(liTrans) == 1: + # Set the transition as a deadend + env.rail.set_transition((*rcFinal, liTrans[0]), mirror(liTrans[0]), bTransition) + + bValid = env.rail.is_cell_valid(rcMiddle) + if not bValid: + # Reset cell transition values + env.rail.grid[tuple(rcMiddle)] = oTransrcMiddle + + self.log(rcMiddle, "Orig:", sTransrcMiddle, "Mod:", self.env.rail.cell_repr(rcMiddle)) + rcHistory.pop(0) # remove the last-but-one + self.redraw() bRedrawn = True diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py index 6c5a17555beb95ce80d023d0c10376650f06ec3b..29a337239fc7d4df737b4482c81ceabff4600e7e 100644 --- a/flatland/utils/rendertools.py +++ b/flatland/utils/rendertools.py @@ -208,6 +208,7 @@ class RenderTool(object): xyDir = np.matmul(rcDir, rt.grc2xy) # agent direction in xy xyPos = np.matmul(rcPos - rcDir / 2, rt.grc2xy) + rt.xyHalf + print("Agent:", rcPos, iDir, rcDir, xyDir, xyPos) self.gl.scatter(*xyPos, color=color, marker="o", s=100) # agent location xyDirLine = array([xyPos, xyPos + xyDir/2]).T # line for agent orient. @@ -523,6 +524,8 @@ class RenderTool(object): # cell transition values oCell = env.rail.get_transitions((r, c)) + bCellValid = env.rail.cell_neighbours_valid((r, c)) + # Special Case 7, with a single bit; terminate at center nbits = 0 tmp = oCell @@ -535,6 +538,10 @@ class RenderTool(object): # it's a dead env. bDeadEnd = nbits == 1 + if not bCellValid: + print("invalid:", r, c) + self.gl.scatter(*xyCentre, color="r", s=50) + for orientation in range(4): # ori is where we're heading from_ori = (orientation + 2) % 4 # 0123=NESW -> 2301=SWNE from_xy = coords[from_ori]