Commit c93115d5 authored by hagrid67's avatar hagrid67
Browse files

Add repr functions to transition_map and transitions,

add cell_neighbours_valid, is_cell_valid to transition_map
add invalid cell detection to editor and rendertools
parent f5af7c8a
Pipeline #482 passed with stage
in 2 minutes and 42 seconds
......@@ -3,6 +3,7 @@ TransitionMap and derived classes.
"""
import numpy as np
from numpy import array
from .transitions import Grid4Transitions, Grid8Transitions, RailEnvTransitions
......@@ -297,6 +298,71 @@ class GridTransitionMap(TransitionMap):
0:min(self.width, new_width)] = new_grid[0:min(self.height, new_height),
0:min(self.width, new_width)]
def is_cell_valid(self, rcPos):
cell_transition = self.grid[tuple(rcPos)]
if not self.transitions.is_valid(cell_transition):
return False
else:
return True
def cell_neighbours_valid(self, rcPos, check_this_cell=False):
"""
Check validity of cell at rcPos = tuple(row, column)
Checks that:
- surrounding cells have inbound transitions for all the
outbound transitions of this cell.
These are NOT checked - see transition.is_valid:
- all transitions have the mirror transitions (N->E <=> W->S)
- Reverse transitions (N -> S) only exist for a dead-end
- a cell contains either no dead-ends or exactly one
Returns: True (valid) or False (invalid)
"""
cell_transition = self.grid[tuple(rcPos)]
if check_this_cell:
if not self.transitions.is_valid(cell_transition):
return False
gDir2dRC = self.transitions.gDir2dRC # [[-1,0] = N, [0,1]=E, etc]
grcPos = array(rcPos)
grcMax = self.grid.shape
binTrans = self.get_transitions(rcPos) # 16bit integer - all trans in/out
lnBinTrans = array([binTrans >> 8, binTrans & 0xff], dtype=np.uint8) # 2 x uint8
g2binTrans = np.unpackbits(lnBinTrans).reshape(4, 4) # 4x4 x uint8 binary(0,1)
# gDirIn = g2binTrans.any(axis=1) # inbound directions as boolean array (4)
gDirOut = g2binTrans.any(axis=0) # outbound directions as boolean array (4)
giDirOut = np.argwhere(gDirOut)[:, 0] # valid outbound directions as array of int
# loop over available outbound directions (indices) for rcPos
for iDirOut in giDirOut:
gdRC = gDir2dRC[iDirOut] # row,col increment
gPos2 = grcPos + gdRC # next cell in that direction
# Check the adjacent cell is within bounds
# if not, then this transition is invalid!
if np.any(gPos2 < 0):
return False
if np.any(gPos2 >= grcMax):
return False
# Get the transitions out of gPos2, using iDirOut as the inbound direction
# if there are no available transitions, ie (0,0,0,0), then rcPos is invalid
t4Trans2 = self.get_transitions((*gPos2, iDirOut))
if any(t4Trans2):
continue
else:
return False
return True
def cell_repr(self, rcPos):
return self.transitions.repr(self.get_transitions(rcPos))
# TODO: GIACOMO: is it better to provide those methods with lists of cell_ids
# (most general implementation) or to make Grid-class specific methods for
# slicing over the 3 dimensions? I'd say both perhaps.
......
......@@ -4,6 +4,8 @@ derived GridTransitions class, which allows for the specification of
possible transitions over a 2D grid.
"""
import numpy as np
class Transitions:
"""
......@@ -159,6 +161,11 @@ class Grid4Transitions(Transitions):
def __init__(self, transitions):
self.transitions = transitions
self.sDirs = "NESW"
self.lsDirs = list(self.sDirs)
# row,col delta for each direction
self.gDir2dRC = np.array([[-1, 0], [0, 1], [1, 0], [0, -1]])
def get_transitions(self, cell_transition, orientation):
"""
......@@ -216,6 +223,7 @@ class Grid4Transitions(Transitions):
(new_transitions[1] & 1) << 2 | \
(new_transitions[2] & 1) << 1 | \
(new_transitions[3] & 1)
# new_transitions = np.packbits((0, 0, 0, 0) + new_transitions) # alternative
cell_transition = (cell_transition & negmask) | (new_transitions << ((3 - orientation) * 4))
......@@ -559,6 +567,40 @@ class RailEnvTransitions(Grid4Transitions):
print("S", format(cell_transition >> (1*4) & 0xF, '04b'))
print("W", format(cell_transition >> (0*4) & 0xF, '04b'))
def repr(self, cell_transition, version=0):
"""
Provide a string representation of the cell transitions.
This class doesn't represent an individual cell,
but a way of interpreting the contents of a cell.
So using the ad hoc name repr rather than __repr__.
"""
# binary format string without leading 0b
sbinTrans = format(cell_transition, "#018b")[2:]
if version == 0:
sRepr = " ".join([
"{}:{}".format(sDir, sbinTrans[i:i+4])
for i, sDir in
zip(
range(0, len(sbinTrans), 4),
self.lsDirs # NESW
)])
return sRepr
if version == 1:
lsRepr = []
for iDirIn in range(0, 4):
sDirTrans = sbinTrans[iDirIn*4:iDirIn*4+4]
if sDirTrans == "0000":
continue
sDirsOut = [
self.lsDirs[iDirOut]
for iDirOut in range(0, 4)
if sDirTrans[iDirOut] == "1"
]
lsRepr.append(self.lsDirs[iDirIn] + ":" + "".join(sDirsOut))
return ", ".join(lsRepr)
def is_valid(self, cell_transition):
"""
Checks if a cell transition is a valid cell setup.
......@@ -578,3 +620,12 @@ class RailEnvTransitions(Grid4Transitions):
return True
return False
def has_deadend(self, cell_transition):
binDeadends = 0b0010000110000100
if cell_transition & binDeadends > 0:
return True
else:
return False
# def remove_deadends(self, cell_transition)
\ No newline at end of file
......@@ -212,11 +212,14 @@ class RailEnv(Environment):
iAgent = self.number_of_agents
self.agents_position.append(tuple(rcPos)) # ensure it's a tuple not a list
self.agents_handles.append(max(self.agents_handles + [-1]) + 1) # max(handles) + 1, starting at 0
if iDir is None:
iDir = self.pick_agent_direction(rcPos, rcTarget)
if iDir is None:
print("Error picking agent direction at pos:", rcPos)
return None
self.agents_position.append(tuple(rcPos)) # ensure it's a tuple not a list
self.agents_handles.append(max(self.agents_handles + [-1]) + 1) # max(handles) + 1, starting at 0
self.agents_direction.append(iDir)
self.agents_target.append(rcPos) # set the target to the origin initially
self.number_of_agents += 1
......
......@@ -17,6 +17,7 @@ from flatland.envs.rail_env import RailEnv, random_rail_generator
from flatland.core.env_observation_builder import TreeObsForRailEnv
import flatland.utils.rendertools as rt
from examples.play_model import Player
from flatland.envs.env_utils import mirror
class View(object):
......@@ -105,7 +106,7 @@ class JupEditor(object):
self.redraw()
def event_handler(self, wid, event):
"""Mouse motion event handler
"""Mouse motion event handler for drawing.
"""
x = event['canvasX']
y = event['canvasY']
......@@ -161,6 +162,11 @@ class JupEditor(object):
while len(rcHistory) >= 3:
rc3Cells = array(rcHistory[:3]) # the 3 cells
rcMiddle = rc3Cells[1] # the middle cell which we will update
# Save the original state of the cell
oTransrcMiddle = self.env.rail.get_transitions(rcMiddle)
sTransrcMiddle = self.env.rail.cell_repr(rcMiddle)
# get the 2 row, col deltas between the 3 cells, eg [-1,0] = North
rc2Trans = np.diff(rc3Cells, axis=0)
......@@ -181,21 +187,49 @@ class JupEditor(object):
if len(liTrans) == 2:
# Set the transition
env.rail.set_transition((*rcMiddle, liTrans[0]), liTrans[1], bTransition)
# iValCell = env.rail.transitions.set_transition(
# env.rail.grid[tuple(rcMiddle)], liTrans[0], liTrans[1], bTransition)
# Also set the reverse transition
# iValCell = env.rail.transitions.set_transition(
# iValCell,
# (liTrans[1] + 2) % 4, # use the reversed outbound transition for inbound
# (liTrans[0] + 2) % 4, # use the reversed inbound transition for outbound
# bTransition)
# Write the cell transition value back into the grid
# env.rail.grid[tuple(rcMiddle)] = iValCell
# use the reversed outbound transition for inbound
# and the reversed inbound transition for outbound
env.rail.set_transition((*rcMiddle, mirror(liTrans[1])), mirror(liTrans[0]), bTransition)
bValid = env.rail.is_cell_valid(rcMiddle)
if not bValid:
# Reset cell transition values
env.rail.grid[tuple(rcMiddle)] = oTransrcMiddle
self.log(rcMiddle, "Orig:", sTransrcMiddle, "Mod:", self.env.rail.cell_repr(rcMiddle))
rcHistory.pop(0) # remove the last-but-one
# If final cell empty, insert deadend:
if len(rcHistory) == 2 and (self.env.rail.get_transitions(rcHistory[1]) == 0):
rc2Cells = array(rcHistory[:2]) # the 2 cells
rcFinal = rc2Cells[1] # the final cell which we will update
# get the row, col delta between the 2 cells, eg [-1,0] = North
rc2Trans = np.diff(rc2Cells, axis=0)
# get the direction index for the 2 transitions
liTrans = []
for rcTrans in rc2Trans:
iTrans = np.argwhere(np.all(self.gRCTrans - rcTrans == 0, axis=1))
if len(iTrans) > 0:
iTrans = iTrans[0][0]
liTrans.append(iTrans)
# check that we have one transition
if len(liTrans) == 1:
# Set the transition as a deadend
env.rail.set_transition((*rcFinal, liTrans[0]), mirror(liTrans[0]), bTransition)
bValid = env.rail.is_cell_valid(rcMiddle)
if not bValid:
# Reset cell transition values
env.rail.grid[tuple(rcMiddle)] = oTransrcMiddle
self.log(rcMiddle, "Orig:", sTransrcMiddle, "Mod:", self.env.rail.cell_repr(rcMiddle))
rcHistory.pop(0) # remove the last-but-one
self.redraw()
bRedrawn = True
......
......@@ -208,6 +208,7 @@ class RenderTool(object):
xyDir = np.matmul(rcDir, rt.grc2xy) # agent direction in xy
xyPos = np.matmul(rcPos - rcDir / 2, rt.grc2xy) + rt.xyHalf
print("Agent:", rcPos, iDir, rcDir, xyDir, xyPos)
self.gl.scatter(*xyPos, color=color, marker="o", s=100) # agent location
xyDirLine = array([xyPos, xyPos + xyDir/2]).T # line for agent orient.
......@@ -523,6 +524,8 @@ class RenderTool(object):
# cell transition values
oCell = env.rail.get_transitions((r, c))
bCellValid = env.rail.cell_neighbours_valid((r, c))
# Special Case 7, with a single bit; terminate at center
nbits = 0
tmp = oCell
......@@ -535,6 +538,10 @@ class RenderTool(object):
# it's a dead env.
bDeadEnd = nbits == 1
if not bCellValid:
print("invalid:", r, c)
self.gl.scatter(*xyCentre, color="r", s=50)
for orientation in range(4): # ori is where we're heading
from_ori = (orientation + 2) % 4 # 0123=NESW -> 2301=SWNE
from_xy = coords[from_ori]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment