Skip to content
Snippets Groups Projects
Commit 3b46bd85 authored by Christian Eichenberger's avatar Christian Eichenberger :badminton:
Browse files

Merge branch '62-unit-test-coverage-and-code-cleanup' into 'master'

refactoring transitions_map

See merge request !96
parents 82c15121 34e4611a
No related branches found
No related tags found
No related merge requests found
from enum import IntEnum
from typing import Type
import numpy as np
......@@ -218,7 +219,7 @@ class Grid4Transitions(Transitions):
cell_transition = value
return cell_transition
def get_direction_enum(self) -> IntEnum:
def get_direction_enum(self) -> Type[Grid4TransitionsEnum]:
return Grid4TransitionsEnum
def has_deadend(self, cell_transition):
......
......@@ -7,6 +7,7 @@ from importlib_resources import path
from numpy import array
from flatland.core.grid.grid4 import Grid4Transitions
from flatland.core.transitions import Transitions
class TransitionMap:
......@@ -110,7 +111,7 @@ class GridTransitionMap(TransitionMap):
GridTransitionMap implements utility functions.
"""
def __init__(self, width, height, transitions=Grid4Transitions([])):
def __init__(self, width, height, transitions: Transitions = Grid4Transitions([])):
"""
Builder for GridTransitionMap object.
......@@ -132,7 +133,25 @@ class GridTransitionMap(TransitionMap):
self.grid = np.zeros((height, width), dtype=self.transitions.get_type())
def get_transitions(self, cell_id):
def get_full_transitions(self, row, column):
"""
Returns the full transitions for the cell at (row, column) in the format transition_map's transitions.
Parameters
----------
row: int
column: int
(row,column) specifies the cell in this transition map.
Returns
-------
self.transitions.get_type()
The cell content int the format of this map's Transitions.
"""
return self.grid[row][column]
def get_transitions(self, row, column, orientation):
"""
Return a tuple of transitions available in a cell specified by
`cell_id' (e.g., a tuple of size of the maximum number of transitions,
......@@ -150,15 +169,10 @@ class GridTransitionMap(TransitionMap):
Returns
-------
tuple
List of the validity of transitions in the cell.
List of the validity of transitions in the cell as given by the maps transitions.
"""
assert len(cell_id) in (2, 3), \
'GridTransitionMap.get_transitions() ERROR: cell_id tuple must have length 2 or 3.'
if len(cell_id) == 3:
return self.transitions.get_transitions(self.grid[cell_id[0]][cell_id[1]], cell_id[2])
elif len(cell_id) == 2:
return self.grid[cell_id[0]][cell_id[1]]
return self.transitions.get_transitions(self.grid[row][column], orientation)
def set_transitions(self, cell_id, new_transitions):
"""
......@@ -308,7 +322,7 @@ class GridTransitionMap(TransitionMap):
grcPos = array(rcPos)
grcMax = self.grid.shape
binTrans = self.get_transitions(rcPos) # 16bit integer - all trans in/out
binTrans = self.get_full_transitions(*rcPos) # 16bit integer - all trans in/out
lnBinTrans = array([binTrans >> 8, binTrans & 0xff], dtype=np.uint8) # 2 x uint8
g2binTrans = np.unpackbits(lnBinTrans).reshape(4, 4) # 4x4 x uint8 binary(0,1)
gDirOut = g2binTrans.any(axis=0) # outbound directions as boolean array (4)
......@@ -328,7 +342,7 @@ class GridTransitionMap(TransitionMap):
# Get the transitions out of gPos2, using iDirOut as the inbound direction
# if there are no available transitions, ie (0,0,0,0), then rcPos is invalid
t4Trans2 = self.get_transitions((*gPos2, iDirOut))
t4Trans2 = self.get_transitions(*gPos2, iDirOut)
if any(t4Trans2):
continue
else:
......
......@@ -75,7 +75,7 @@ def get_rnd_agents_pos_tgt_dir_on_rail(rail, num_agents):
return 1
if node not in visited:
visited.add(node)
moves = rail.get_transitions((node[0][0], node[0][1], node[1]))
moves = rail.get_transitions(node[0][0], node[0][1], node[1])
for move_index in range(4):
if moves[move_index]:
stack.append((get_new_position(node[0], move_index),
......@@ -84,7 +84,7 @@ def get_rnd_agents_pos_tgt_dir_on_rail(rail, num_agents):
# If cell is a dead-end, append previous node with reversed
# orientation!
nbits = 0
tmp = rail.get_transitions((node[0][0], node[0][1]))
tmp = rail.get_full_transitions(node[0][0], node[0][1])
while tmp > 0:
nbits += (tmp & 1)
tmp = tmp >> 1
......@@ -96,7 +96,7 @@ def get_rnd_agents_pos_tgt_dir_on_rail(rail, num_agents):
valid_positions = []
for r in range(rail.height):
for c in range(rail.width):
if rail.get_transitions((r, c)) > 0:
if rail.get_full_transitions(r, c) > 0:
valid_positions.append((r, c))
re_generate = True
......@@ -116,7 +116,7 @@ def get_rnd_agents_pos_tgt_dir_on_rail(rail, num_agents):
valid_movements = []
for direction in range(4):
position = agents_position[i]
moves = rail.get_transitions((position[0], position[1], direction))
moves = rail.get_transitions(position[0], position[1], direction)
for move_index in range(4):
if moves[move_index]:
valid_movements.append((direction, move_index))
......
......@@ -252,7 +252,7 @@ class TreeObsForRailEnv(ObservationBuilder):
if handle > len(self.env.agents):
print("ERROR: obs _get - handle ", handle, " len(agents)", len(self.env.agents))
agent = self.env.agents[handle] # TODO: handle being treated as index
possible_transitions = self.env.rail.get_transitions((*agent.position, agent.direction))
possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
num_transitions = np.count_nonzero(possible_transitions)
# Root node - current position
......@@ -382,8 +382,8 @@ class TreeObsForRailEnv(ObservationBuilder):
last_is_target = True
break
cell_transitions = self.env.rail.get_transitions((*position, direction))
total_transitions = bin(self.env.rail.get_transitions(position)).count("1")
cell_transitions = self.env.rail.get_transitions(*position, direction)
total_transitions = bin(self.env.rail.get_full_transitions(*position)).count("1")
num_transitions = np.count_nonzero(cell_transitions)
exploring = False
# Detect Switches that can only be used by other agents.
......@@ -393,7 +393,7 @@ class TreeObsForRailEnv(ObservationBuilder):
if num_transitions == 1:
# Check if dead-end, or if we can go forward along direction
nbits = 0
tmp = self.env.rail.get_transitions(tuple(position))
tmp = self.env.rail.get_full_transitions(*position)
while tmp > 0:
nbits += (tmp & 1)
tmp = tmp >> 1
......@@ -468,7 +468,7 @@ class TreeObsForRailEnv(ObservationBuilder):
# Start from the current orientation, and see which transitions are available;
# organize them as [left, forward, right, back], relative to the current orientation
# Get the possible transitions
possible_transitions = self.env.rail.get_transitions((*position, direction))
possible_transitions = self.env.rail.get_transitions(*position, direction)
for branch_direction in [(direction + 4 + i) % 4 for i in range(-1, 3)]:
if last_is_dead_end and self.env.rail.get_transition((*position, direction),
(branch_direction + 2) % 4):
......@@ -571,7 +571,7 @@ class GlobalObsForRailEnv(ObservationBuilder):
self.rail_obs = np.zeros((self.env.height, self.env.width, 16))
for i in range(self.rail_obs.shape[0]):
for j in range(self.rail_obs.shape[1]):
bitlist = [int(digit) for digit in bin(self.env.rail.get_transitions((i, j)))[2:]]
bitlist = [int(digit) for digit in bin(self.env.rail.get_full_transitions(i, j))[2:]]
bitlist = [0] * (16 - len(bitlist)) + bitlist
self.rail_obs[i, j] = np.array(bitlist)
......@@ -629,7 +629,7 @@ class GlobalObsForRailEnvDirectionDependent(ObservationBuilder):
self.rail_obs = np.zeros((self.env.height, self.env.width, 16))
for i in range(self.rail_obs.shape[0]):
for j in range(self.rail_obs.shape[1]):
bitlist = [int(digit) for digit in bin(self.env.rail.get_transitions((i, j)))[2:]]
bitlist = [int(digit) for digit in bin(self.env.rail.get_full_transitions(i, j))[2:]]
bitlist = [0] * (16 - len(bitlist)) + bitlist
self.rail_obs[i, j] = np.array(bitlist)
......@@ -700,7 +700,7 @@ class LocalObsForRailEnv(ObservationBuilder):
self.env.width + 2 * self.view_radius, 16))
for i in range(self.env.height):
for j in range(self.env.width):
bitlist = [int(digit) for digit in bin(self.env.rail.get_transitions((i, j)))[2:]]
bitlist = [int(digit) for digit in bin(self.env.rail.get_full_transitions(i, j))[2:]]
bitlist = [0] * (16 - len(bitlist)) + bitlist
self.rail_obs[i + self.view_radius, j + self.view_radius] = np.array(bitlist)
......
......@@ -131,7 +131,7 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
prediction[index] = [index, *agent.position, agent.direction, RailEnvActions.STOP_MOVING]
continue
# Take shortest possible path
cell_transitions = self.env.rail.get_transitions((*agent.position, agent.direction))
cell_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
new_position = None
new_direction = None
......
......@@ -326,7 +326,7 @@ class RailEnv(Environment):
new_position,
np.clip(new_position, [0, 0], [self.height - 1, self.width - 1]))
and # check the new position has some transitions (ie is not an empty cell)
self.rail.get_transitions(new_position) > 0)
self.rail.get_full_transitions(*new_position) > 0)
# If transition validity hasn't been checked yet.
if transition_valid is None:
......@@ -342,7 +342,7 @@ class RailEnv(Environment):
def check_action(self, agent, action):
transition_valid = None
possible_transitions = self.rail.get_transitions((*agent.position, agent.direction))
possible_transitions = self.rail.get_transitions(*agent.position, agent.direction)
num_transitions = np.count_nonzero(possible_transitions)
new_direction = agent.direction
......
......@@ -494,7 +494,7 @@ class EditorModel(object):
if len(lrcStroke) >= 2:
# If the first cell in a stroke is empty, add a deadend to cell 0
if self.env.rail.get_transitions(lrcStroke[0]) == 0:
if self.env.rail.get_full_transitions(*lrcStroke[0]) == 0:
self.mod_rail_2cells(lrcStroke, bAddRemove, iCellToMod=0)
# Add transitions for groups of 3 cells
......@@ -504,7 +504,7 @@ class EditorModel(object):
# If final cell empty, insert deadend:
if len(lrcStroke) == 2:
if self.env.rail.get_transitions(lrcStroke[1]) == 0:
if self.env.rail.get_full_transitions(*lrcStroke[1]) == 0:
self.mod_rail_2cells(lrcStroke, bAddRemove, iCellToMod=1)
# now empty out the final two cells from the queue
......@@ -752,7 +752,7 @@ class EditorModel(object):
self.log(*args, **kwargs)
def debug_cell(self, rcCell):
binTrans = self.env.rail.get_transitions(rcCell)
binTrans = self.env.rail.get_full_transitions(*rcCell)
sbinTrans = format(binTrans, "#018b")[2:]
self.debug("cell ",
rcCell,
......
......@@ -86,7 +86,7 @@ class RenderTool(object):
for visit in lVisits:
# transition for next cell
tbTrans = self.env.rail.get_transitions((*visit.rc, visit.iDir))
tbTrans = self.env.rail.get_transitions(*visit.rc, visit.iDir)
giTrans = np.where(tbTrans)[0] # RC list of transitions
gTransRCAg = rt.gTransRC[giTrans]
self.plotTrans(visit.rc, gTransRCAg, depth=str(visit.iDepth), color=color)
......@@ -125,7 +125,7 @@ class RenderTool(object):
)
"""
tbTrans = self.env.rail.get_transitions((*rcPos, iDir))
tbTrans = self.env.rail.get_transitions(*rcPos, iDir)
giTrans = np.where(tbTrans)[0] # RC list of transitions
# HACK: workaround dead-end transitions
......@@ -459,7 +459,7 @@ class RenderTool(object):
xyCentre = array([x0, y1]) + cell_size / 2
# cell transition values
oCell = env.rail.get_transitions((r, c))
oCell = env.rail.get_full_transitions(r, c)
bCellValid = env.rail.cell_neighbours_valid((r, c), check_this_cell=True)
......@@ -482,7 +482,7 @@ class RenderTool(object):
from_ori = (orientation + 2) % 4 # 0123=NESW -> 2301=SWNE
from_xy = coords[from_ori]
tMoves = env.rail.get_transitions((r, c, orientation))
tMoves = env.rail.get_transitions(r, c, orientation)
for to_ori in range(4):
to_xy = coords[to_ori]
......
......@@ -5,19 +5,42 @@ from flatland.core.transition_map import GridTransitionMap
def test_grid4_get_transitions():
grid4_map = GridTransitionMap(2, 2, Grid4Transitions([]))
assert grid4_map.get_transitions((0, 0, Grid4TransitionsEnum.NORTH)) == (0, 0, 0, 0)
assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.NORTH) == (0, 0, 0, 0)
assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.EAST) == (0, 0, 0, 0)
assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.SOUTH) == (0, 0, 0, 0)
assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.WEST) == (0, 0, 0, 0)
assert grid4_map.get_full_transitions(0, 0) == 0
grid4_map.set_transition((0, 0, Grid4TransitionsEnum.NORTH), Grid4TransitionsEnum.NORTH, 1)
assert grid4_map.get_transitions((0, 0, Grid4TransitionsEnum.NORTH)) == (1, 0, 0, 0)
assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.NORTH) == (1, 0, 0, 0)
assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.EAST) == (0, 0, 0, 0)
assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.SOUTH) == (0, 0, 0, 0)
assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.WEST) == (0, 0, 0, 0)
assert grid4_map.get_full_transitions(0, 0) == pow(2, 15) # the most significant bit is on
grid4_map.set_transition((0, 0, Grid4TransitionsEnum.NORTH), Grid4TransitionsEnum.WEST, 1)
assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.NORTH) == (1, 0, 0, 1)
assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.EAST) == (0, 0, 0, 0)
assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.SOUTH) == (0, 0, 0, 0)
assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.WEST) == (0, 0, 0, 0)
# the most significant and the fourth most significant bits are on
assert grid4_map.get_full_transitions(0, 0) == pow(2, 15) + pow(2, 12)
grid4_map.set_transition((0, 0, Grid4TransitionsEnum.NORTH), Grid4TransitionsEnum.NORTH, 0)
assert grid4_map.get_transitions((0, 0, Grid4TransitionsEnum.NORTH)) == (0, 0, 0, 0)
assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.NORTH) == (0, 0, 0, 1)
assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.EAST) == (0, 0, 0, 0)
assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.SOUTH) == (0, 0, 0, 0)
assert grid4_map.get_transitions(0, 0, Grid4TransitionsEnum.WEST) == (0, 0, 0, 0)
# the fourth most significant bits are on
assert grid4_map.get_full_transitions(0, 0) == pow(2, 12)
def test_grid8_set_transitions():
grid8_map = GridTransitionMap(2, 2, Grid8Transitions([]))
assert grid8_map.get_transitions((0, 0, Grid8TransitionsEnum.NORTH)) == (0, 0, 0, 0, 0, 0, 0, 0)
assert grid8_map.get_transitions(0, 0, Grid8TransitionsEnum.NORTH) == (0, 0, 0, 0, 0, 0, 0, 0)
grid8_map.set_transition((0, 0, Grid8TransitionsEnum.NORTH), Grid8TransitionsEnum.NORTH, 1)
assert grid8_map.get_transitions((0, 0, Grid8TransitionsEnum.NORTH)) == (1, 0, 0, 0, 0, 0, 0, 0)
assert grid8_map.get_transitions(0, 0, Grid8TransitionsEnum.NORTH) == (1, 0, 0, 0, 0, 0, 0, 0)
grid8_map.set_transition((0, 0, Grid8TransitionsEnum.NORTH), Grid8TransitionsEnum.NORTH, 0)
assert grid8_map.get_transitions((0, 0, Grid8TransitionsEnum.NORTH)) == (0, 0, 0, 0, 0, 0, 0, 0)
assert grid8_map.get_transitions(0, 0, Grid8TransitionsEnum.NORTH) == (0, 0, 0, 0, 0, 0, 0, 0)
# TODO GridTransitionMap
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment