Commit e1e0947e authored by u214892's avatar u214892
Browse files

refactoring transitions_map

parent 3383b56b
Pipeline #1356 passed with stage
in 6 minutes and 55 seconds
from enum import IntEnum
from typing import Type
import numpy as np
......@@ -218,7 +219,7 @@ class Grid4Transitions(Transitions):
cell_transition = value
return cell_transition
def get_direction_enum(self) -> IntEnum:
def get_direction_enum(self) -> Type[Grid4TransitionsEnum]:
return Grid4TransitionsEnum
def has_deadend(self, cell_transition):
......
......@@ -7,6 +7,7 @@ from importlib_resources import path
from numpy import array
from flatland.core.grid.grid4 import Grid4Transitions
from flatland.core.transitions import Transitions
class TransitionMap:
......@@ -110,7 +111,7 @@ class GridTransitionMap(TransitionMap):
GridTransitionMap implements utility functions.
"""
def __init__(self, width, height, transitions=Grid4Transitions([])):
def __init__(self, width, height, transitions: Transitions = Grid4Transitions([])):
"""
Builder for GridTransitionMap object.
......@@ -132,7 +133,25 @@ class GridTransitionMap(TransitionMap):
self.grid = np.zeros((height, width), dtype=self.transitions.get_type())
def get_transitions(self, cell_id):
def get_full_transitions(self, row, column):
"""
Returns the full transitions for the cell at (row, column) in the format transition_map's transitions.
Parameters
----------
row: int
column: int
(row,column) specifies the cell in this transition map.
Returns
-------
self.transitions.get_type()
The cell content int the format of this map's Transitions.
"""
return self.grid[row][column]
def get_transitions(self, row, column, orientation):
"""
Return a tuple of transitions available in a cell specified by
`cell_id' (e.g., a tuple of size of the maximum number of transitions,
......@@ -150,15 +169,10 @@ class GridTransitionMap(TransitionMap):
Returns
-------
tuple
List of the validity of transitions in the cell.
List of the validity of transitions in the cell as given by the maps transitions.
"""
assert len(cell_id) in (2, 3), \
'GridTransitionMap.get_transitions() ERROR: cell_id tuple must have length 2 or 3.'
if len(cell_id) == 3:
return self.transitions.get_transitions(self.grid[cell_id[0]][cell_id[1]], cell_id[2])
elif len(cell_id) == 2:
return self.grid[cell_id[0]][cell_id[1]]
return self.transitions.get_transitions(self.grid[row][column], orientation)
def set_transitions(self, cell_id, new_transitions):
"""
......@@ -308,7 +322,7 @@ class GridTransitionMap(TransitionMap):
grcPos = array(rcPos)
grcMax = self.grid.shape
binTrans = self.get_transitions(rcPos) # 16bit integer - all trans in/out
binTrans = self.get_full_transitions(*rcPos) # 16bit integer - all trans in/out
lnBinTrans = array([binTrans >> 8, binTrans & 0xff], dtype=np.uint8) # 2 x uint8
g2binTrans = np.unpackbits(lnBinTrans).reshape(4, 4) # 4x4 x uint8 binary(0,1)
gDirOut = g2binTrans.any(axis=0) # outbound directions as boolean array (4)
......@@ -328,7 +342,7 @@ class GridTransitionMap(TransitionMap):
# Get the transitions out of gPos2, using iDirOut as the inbound direction
# if there are no available transitions, ie (0,0,0,0), then rcPos is invalid
t4Trans2 = self.get_transitions((*gPos2, iDirOut))
t4Trans2 = self.get_transitions(*gPos2, iDirOut)
if any(t4Trans2):
continue
else:
......
......@@ -75,7 +75,7 @@ def get_rnd_agents_pos_tgt_dir_on_rail(rail, num_agents):
return 1
if node not in visited:
visited.add(node)
moves = rail.get_transitions((node[0][0], node[0][1], node[1]))
moves = rail.get_transitions(node[0][0], node[0][1], node[1])
for move_index in range(4):
if moves[move_index]:
stack.append((get_new_position(node[0], move_index),
......@@ -84,7 +84,7 @@ def get_rnd_agents_pos_tgt_dir_on_rail(rail, num_agents):
# If cell is a dead-end, append previous node with reversed
# orientation!
nbits = 0
tmp = rail.get_transitions((node[0][0], node[0][1]))
tmp = rail.get_full_transitions(node[0][0], node[0][1])
while tmp > 0:
nbits += (tmp & 1)
tmp = tmp >> 1
......@@ -96,7 +96,7 @@ def get_rnd_agents_pos_tgt_dir_on_rail(rail, num_agents):
valid_positions = []
for r in range(rail.height):
for c in range(rail.width):
if rail.get_transitions((r, c)) > 0:
if rail.get_full_transitions(r, c) > 0:
valid_positions.append((r, c))
re_generate = True
......@@ -116,7 +116,7 @@ def get_rnd_agents_pos_tgt_dir_on_rail(rail, num_agents):
valid_movements = []
for direction in range(4):
position = agents_position[i]
moves = rail.get_transitions((position[0], position[1], direction))
moves = rail.get_transitions(position[0], position[1], direction)
for move_index in range(4):
if moves[move_index]:
valid_movements.append((direction, move_index))
......
......@@ -253,7 +253,7 @@ class TreeObsForRailEnv(ObservationBuilder):
if handle > len(self.env.agents):
print("ERROR: obs _get - handle ", handle, " len(agents)", len(self.env.agents))
agent = self.env.agents[handle] # TODO: handle being treated as index
possible_transitions = self.env.rail.get_transitions((*agent.position, agent.direction))
possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
num_transitions = np.count_nonzero(possible_transitions)
# Root node - current position
......@@ -383,8 +383,8 @@ class TreeObsForRailEnv(ObservationBuilder):
last_is_target = True
break
cell_transitions = self.env.rail.get_transitions((*position, direction))
total_transitions = bin(self.env.rail.get_transitions(position)).count("1")
cell_transitions = self.env.rail.get_transitions(*position, direction)
total_transitions = bin(self.env.rail.get_full_transitions(*position)).count("1")
num_transitions = np.count_nonzero(cell_transitions)
exploring = False
# Detect Switches that can only be used by other agents.
......@@ -394,7 +394,7 @@ class TreeObsForRailEnv(ObservationBuilder):
if num_transitions == 1:
# Check if dead-end, or if we can go forward along direction
nbits = 0
tmp = self.env.rail.get_transitions(tuple(position))
tmp = self.env.rail.get_full_transitions(*position)
while tmp > 0:
nbits += (tmp & 1)
tmp = tmp >> 1
......@@ -469,7 +469,7 @@ class TreeObsForRailEnv(ObservationBuilder):
# Start from the current orientation, and see which transitions are available;
# organize them as [left, forward, right, back], relative to the current orientation
# Get the possible transitions
possible_transitions = self.env.rail.get_transitions((*position, direction))
possible_transitions = self.env.rail.get_transitions(*position, direction)
for branch_direction in [(direction + 4 + i) % 4 for i in range(-1, 3)]:
if last_is_dead_end and self.env.rail.get_transition((*position, direction),
(branch_direction + 2) % 4):
......@@ -572,7 +572,7 @@ class GlobalObsForRailEnv(ObservationBuilder):
self.rail_obs = np.zeros((self.env.height, self.env.width, 16))
for i in range(self.rail_obs.shape[0]):
for j in range(self.rail_obs.shape[1]):
bitlist = [int(digit) for digit in bin(self.env.rail.get_transitions((i, j)))[2:]]
bitlist = [int(digit) for digit in bin(self.env.rail.get_full_transitions(i, j))[2:]]
bitlist = [0] * (16 - len(bitlist)) + bitlist
self.rail_obs[i, j] = np.array(bitlist)
......@@ -630,7 +630,7 @@ class GlobalObsForRailEnvDirectionDependent(ObservationBuilder):
self.rail_obs = np.zeros((self.env.height, self.env.width, 16))
for i in range(self.rail_obs.shape[0]):
for j in range(self.rail_obs.shape[1]):
bitlist = [int(digit) for digit in bin(self.env.rail.get_transitions((i, j)))[2:]]
bitlist = [int(digit) for digit in bin(self.env.rail.get_full_transitions(i, j))[2:]]
bitlist = [0] * (16 - len(bitlist)) + bitlist
self.rail_obs[i, j] = np.array(bitlist)
......@@ -701,7 +701,7 @@ class LocalObsForRailEnv(ObservationBuilder):
self.env.width + 2 * self.view_radius, 16))
for i in range(self.env.height):
for j in range(self.env.width):
bitlist = [int(digit) for digit in bin(self.env.rail.get_transitions((i, j)))[2:]]
bitlist = [int(digit) for digit in bin(self.env.rail.get_full_transitions(i, j))[2:]]
bitlist = [0] * (16 - len(bitlist)) + bitlist
self.rail_obs[i + self.view_radius, j + self.view_radius] = np.array(bitlist)
......
......@@ -131,7 +131,7 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
prediction[index] = [index, *agent.position, agent.direction, RailEnvActions.STOP_MOVING]
continue
# Take shortest possible path
cell_transitions = self.env.rail.get_transitions((*agent.position, agent.direction))
cell_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
new_position = None
new_direction = None
......
......@@ -322,7 +322,7 @@ class RailEnv(Environment):
new_position,
np.clip(new_position, [0, 0], [self.height - 1, self.width - 1]))
and # check the new position has some transitions (ie is not an empty cell)
self.rail.get_transitions(new_position) > 0)
self.rail.get_full_transitions(*new_position) > 0)
# If transition validity hasn't been checked yet.
if transition_isValid is None:
......@@ -338,7 +338,7 @@ class RailEnv(Environment):
def check_action(self, agent, action):
transition_isValid = None
possible_transitions = self.rail.get_transitions((*agent.position, agent.direction))
possible_transitions = self.rail.get_transitions(*agent.position, agent.direction)
num_transitions = np.count_nonzero(possible_transitions)
new_direction = agent.direction
......
......@@ -494,7 +494,7 @@ class EditorModel(object):
if len(lrcStroke) >= 2:
# If the first cell in a stroke is empty, add a deadend to cell 0
if self.env.rail.get_transitions(lrcStroke[0]) == 0:
if self.env.rail.get_full_transitions(*lrcStroke[0]) == 0:
self.mod_rail_2cells(lrcStroke, bAddRemove, iCellToMod=0)
# Add transitions for groups of 3 cells
......@@ -504,7 +504,7 @@ class EditorModel(object):
# If final cell empty, insert deadend:
if len(lrcStroke) == 2:
if self.env.rail.get_transitions(lrcStroke[1]) == 0:
if self.env.rail.get_full_transitions(*lrcStroke[1]) == 0:
self.mod_rail_2cells(lrcStroke, bAddRemove, iCellToMod=1)
# now empty out the final two cells from the queue
......@@ -752,7 +752,7 @@ class EditorModel(object):
self.log(*args, **kwargs)
def debug_cell(self, rcCell):
binTrans = self.env.rail.get_transitions(rcCell)
binTrans = self.env.rail.get_full_transitions(*rcCell)
sbinTrans = format(binTrans, "#018b")[2:]
self.debug("cell ",
rcCell,
......
......@@ -86,7 +86,7 @@ class RenderTool(object):
for visit in lVisits:
# transition for next cell
tbTrans = self.env.rail.get_transitions((*visit.rc, visit.iDir))
tbTrans = self.env.rail.get_transitions(*visit.rc, visit.iDir)
giTrans = np.where(tbTrans)[0] # RC list of transitions
gTransRCAg = rt.gTransRC[giTrans]
self.plotTrans(visit.rc, gTransRCAg, depth=str(visit.iDepth), color=color)
......@@ -125,7 +125,7 @@ class RenderTool(object):
)
"""
tbTrans = self.env.rail.get_transitions((*rcPos, iDir))
tbTrans = self.env.rail.get_transitions(*rcPos, iDir)
giTrans = np.where(tbTrans)[0] # RC list of transitions
# HACK: workaround dead-end transitions
......@@ -459,7 +459,7 @@ class RenderTool(object):
xyCentre = array([x0, y1]) + cell_size / 2
# cell transition values
oCell = env.rail.get_transitions((r, c))
oCell = env.rail.get_full_transitions(r, c)
bCellValid = env.rail.cell_neighbours_valid((r, c), check_this_cell=True)
......@@ -482,7 +482,7 @@ class RenderTool(object):
from_ori = (orientation + 2) % 4 # 0123=NESW -> 2301=SWNE
from_xy = coords[from_ori]
tMoves = env.rail.get_transitions((r, c, orientation))
tMoves = env.rail.get_transitions(r, c, orientation)
for to_ori in range(4):
to_xy = coords[to_ori]
......
......@@ -5,19 +5,42 @@ from flatland.core.transition_map import GridTransitionMap
def test_grid4_get_transitions():
grid4_map = GridTransitionMap(2, 2, Grid4Transitions([]))
assert grid4_map.get_transitions((0, 0, Grid4TransitionsEnum.NORTH)) == (0, 0, 0, 0)
assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.NORTH) == (0, 0, 0, 0)
assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.EAST) == (0, 0, 0, 0)
assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.SOUTH) == (0, 0, 0, 0)
assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.WEST) == (0, 0, 0, 0)
assert grid4_map.get_full_transitions(0, 0) == 0
grid4_map.set_transition((0, 0, Grid4TransitionsEnum.NORTH), Grid4TransitionsEnum.NORTH, 1)
assert grid4_map.get_transitions((0, 0, Grid4TransitionsEnum.NORTH)) == (1, 0, 0, 0)
assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.NORTH) == (1, 0, 0, 0)
assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.EAST) == (0, 0, 0, 0)
assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.SOUTH) == (0, 0, 0, 0)
assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.WEST) == (0, 0, 0, 0)
assert grid4_map.get_full_transitions(0, 0) == pow(2, 15) # the most significant bit is on
grid4_map.set_transition((0, 0, Grid4TransitionsEnum.NORTH), Grid4TransitionsEnum.WEST, 1)
assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.NORTH) == (1, 0, 0, 1)
assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.EAST) == (0, 0, 0, 0)
assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.SOUTH) == (0, 0, 0, 0)
assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.WEST) == (0, 0, 0, 0)
# the most significant and the fourth most significant bits are on
assert grid4_map.get_full_transitions(0, 0) == pow(2, 15) + pow(2, 12)
grid4_map.set_transition((0, 0, Grid4TransitionsEnum.NORTH), Grid4TransitionsEnum.NORTH, 0)
assert grid4_map.get_transitions((0, 0, Grid4TransitionsEnum.NORTH)) == (0, 0, 0, 0)
assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.NORTH) == (0, 0, 0, 1)
assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.EAST) == (0, 0, 0, 0)
assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.SOUTH) == (0, 0, 0, 0)
assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.WEST) == (0, 0, 0, 0)
# the fourth most significant bits are on
assert grid4_map.get_full_transitions(0, 0) == pow(2, 12)
def test_grid8_set_transitions():
grid8_map = GridTransitionMap(2, 2, Grid8Transitions([]))
assert grid8_map.get_transitions((0, 0, Grid8TransitionsEnum.NORTH)) == (0, 0, 0, 0, 0, 0, 0, 0)
assert grid8_map.get_transitions(0, 0, Grid8TransitionsEnum.NORTH) == (0, 0, 0, 0, 0, 0, 0, 0)
grid8_map.set_transition((0, 0, Grid8TransitionsEnum.NORTH), Grid8TransitionsEnum.NORTH, 1)
assert grid8_map.get_transitions((0, 0, Grid8TransitionsEnum.NORTH)) == (1, 0, 0, 0, 0, 0, 0, 0)
assert grid8_map.get_transitions(0, 0, Grid8TransitionsEnum.NORTH) == (1, 0, 0, 0, 0, 0, 0, 0)
grid8_map.set_transition((0, 0, Grid8TransitionsEnum.NORTH), Grid8TransitionsEnum.NORTH, 0)
assert grid8_map.get_transitions((0, 0, Grid8TransitionsEnum.NORTH)) == (0, 0, 0, 0, 0, 0, 0, 0)
assert grid8_map.get_transitions(0, 0, Grid8TransitionsEnum.NORTH) == (0, 0, 0, 0, 0, 0, 0, 0)
# TODO GridTransitionMap
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment