From ef60e4ce5df50ea490bdf9ae0363d960e9cc8fb5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mattias=20Ljungstr=C3=B6m?= <ml@mljx.io> Date: Tue, 9 Apr 2019 19:55:28 +0200 Subject: [PATCH] refactor GridTransition into Grid4, Grid8Transition. --- flatland/core/env.py | 12 +- flatland/core/transitions.py | 386 +++++++++++++++++---------- flatland/utils/rail_env_generator.py | 8 +- flatland/utils/rendertools.py | 7 +- tests/test_environments.py | 22 +- tests/test_transitions.py | 32 +-- 6 files changed, 278 insertions(+), 189 deletions(-) diff --git a/flatland/core/env.py b/flatland/core/env.py index 2ecee63..bf6df54 100644 --- a/flatland/core/env.py +++ b/flatland/core/env.py @@ -144,9 +144,7 @@ class RailEnv: self.agents_handles = list(range(self.number_of_agents)) - self.t_utils = RailEnvTransitions() - # TODO : bad hack for pylint 80 characters per line; shortened function - self.gtfotd = self.t_utils.get_transition_from_orientation_to_direction + self.trans = RailEnvTransitions() def get_agent_handles(self): return self.agents_handles @@ -177,7 +175,7 @@ class RailEnv: valid_movements = [] for direction in range(4): position = self.agents_position[i] - moves = self.t_utils.get_transitions_from_orientation( + moves = self.trans.get_transitions( self.rail[position[0]][position[1]], direction) for move_index in range(4): if moves[move_index]: @@ -272,7 +270,7 @@ class RailEnv: elif direction == 3: reverse_direction = 1 - valid_transition = self.gtfotd( + valid_transition = self.trans.get_transition( self.rail[pos[0]][pos[1]], reverse_direction, reverse_direction) @@ -295,7 +293,7 @@ class RailEnv: else: new_cell_isValid = False - transition_isValid = self.gtfotd( + transition_isValid = self.trans.get_transition( self.rail[pos[0]][pos[1]], direction, movement) @@ -364,7 +362,7 @@ class RailEnv: return 1 if node not in visited: visited.add(node) - moves = self.t_utils.get_transitions_from_orientation( + moves = self.trans.get_transitions( self.rail[node[0][0]][node[0][1]], node[1]) for move_index in range(4): if moves[move_index]: diff --git a/flatland/core/transitions.py b/flatland/core/transitions.py index 9c4f05e..78e76f5 100644 --- a/flatland/core/transitions.py +++ b/flatland/core/transitions.py @@ -14,7 +14,7 @@ class Transitions: `orientation' and moving into direction `direction') """ - def get_transitions_from_orientation(self, cell_transition, orientation): + def get_transitions(self, cell_transition, orientation): """ Return a tuple of transitions available in a cell specified by `cell_transition' for an agent facing direction `orientation' @@ -39,8 +39,7 @@ class Transitions: """ raise NotImplementedError() - def set_transitions_from_orientation(self, cell_transition, orientation, - new_transitions): + def set_transitions(self, cell_transition, orientation, new_transitions): """ Return a `cell_transition' specification where the transitions available for an agent facing direction `orientation' are replaced @@ -68,8 +67,7 @@ class Transitions: """ raise NotImplementedError() - def get_transition_from_orientation_to_direction(self, cell_transition, - orientation, direction): + def get_transition(self, cell_transition, orientation, direction): """ Return the status of whether an agent oriented in directions `orientation' and inside a cell with transitions `cell_transition' @@ -96,11 +94,8 @@ class Transitions: """ raise NotImplementedError() - def set_transition_from_orientation_to_direction(self, - cell_transition, - orientation, - direction, - new_transition): + def set_transition(self, cell_transition, orientation, direction, + new_transition): """ Return a `cell_transition' specification where the status of whether an agent oriented in direction `orientation' and inside @@ -133,15 +128,14 @@ class Transitions: raise NotImplementedError() -class GridTransitions(Transitions): +class Grid4Transitions(Transitions): """ - GridTransitions class derived from Transitions. + Grid4Transitions class derived from Transitions. Special case of `Transitions' over a 2D-grid (FlatLand). Transitions are possible to neighboring cells on the grid if allowed. GridTransitions keeps track of valid transitions supplied as `transitions' - list, each represented as a bitmap of 16 (allow_diagonal_transitions=False) - or 64 bits (allow_diagonal_transitions=True). + list, each represented as a bitmap of 16 bits. Whether a transition is allowed or not depends on which direction an agent inside the cell is facing (0=North, 1=East, 2=South, 3=West) and which @@ -163,19 +157,10 @@ class GridTransitions(Transitions): In the example, the agent can move from North to South and viceversa. """ - def __init__(self, - transitions, - allow_diagonal_transitions=False - ): - - if allow_diagonal_transitions: - self.number_of_cell_neighbors = 8 - else: - self.number_of_cell_neighbors = 4 - + def __init__(self, transitions): self.transitions = transitions - def get_transitions_from_orientation(self, cell_transition, orientation): + def get_transitions(self, cell_transition, orientation): """ Get the 4 possible transitions ((N,E,S,W), 4 elements tuple if no diagonal transitions allowed) available for an agent oriented @@ -185,7 +170,7 @@ class GridTransitions(Transitions): Parameters ---------- cell_transition : int - 16 or 64 bits used to encode the valid transitions for a cell. + 16 bits used to encode the valid transitions for a cell. orientation : int Orientation of the agent inside the cell. @@ -195,28 +180,10 @@ class GridTransitions(Transitions): List of the validity of transitions in the cell. """ - if self.number_of_cell_neighbors == 4: - bits = (cell_transition >> ((3-orientation)*4)) - cell_transition = ((bits >> 3) & 1, (bits >> 2) & 1, - (bits >> 1) & 1, (bits) & 1) - elif self.number_of_cell_neighbors == 8: - bits = (cell_transition >> ((7-orientation)*8)) - cell_transition = ( - (bits >> 7) & 1, - (bits >> 6) & 1, - (bits >> 5) & 1, - (bits >> 4) & 1, - (bits >> 3) & 1, - (bits >> 2) & 1, - (bits >> 1) & 1, - (bits) & 1) - else: - raise NotImplementedError() + bits = (cell_transition >> ((3-orientation)*4)) + return ((bits >> 3) & 1, (bits >> 2) & 1, (bits >> 1) & 1, (bits) & 1) - return cell_transition - - def set_transitions_from_orientation(self, cell_transition, orientation, - new_transitions): + def set_transitions(self, cell_transition, orientation, new_transitions): """ Set the possible transitions (e.g., (N,E,S,W), 4 elements tuple if no diagonal transitions allowed) available for an agent @@ -227,7 +194,7 @@ class GridTransitions(Transitions): Parameters ---------- cell_transition : int - 16 or 64 bits used to encode the valid transitions for a cell. + 16 bits used to encode the valid transitions for a cell. orientation : int Orientation of the agent inside the cell. new_transitions : tuple @@ -241,43 +208,22 @@ class GridTransitions(Transitions): `orientation'. """ - if self.number_of_cell_neighbors == 4: - mask = (1 << ((4-orientation)*4)) - (1 << ((3-orientation)*4)) - negmask = ~mask - - new_transitions = \ - (new_transitions[0] & 1) << 3 | \ - (new_transitions[1] & 1) << 2 | \ - (new_transitions[2] & 1) << 1 | \ - (new_transitions[3] & 1) - - cell_transition = \ - ( - cell_transition & negmask) | \ - (new_transitions << ((3-orientation)*4)) - elif self.number_of_cell_neighbors == 8: - mask = (1 << ((8-orientation)*8)) - (1 << ((7-orientation)*8)) - negmask = ~mask - - new_transitions = \ - (new_transitions[0] & 1) << 7 | \ - (new_transitions[1] & 1) << 6 | \ - (new_transitions[2] & 1) << 5 | \ - (new_transitions[3] & 1) << 4 | \ - (new_transitions[4] & 1) << 3 | \ - (new_transitions[5] & 1) << 2 | \ - (new_transitions[6] & 1) << 1 | \ - (new_transitions[7] & 1) - - cell_transition = (cell_transition & negmask) | ( - new_transitions << ((7-orientation)*8)) - else: - raise NotImplementedError() + mask = (1 << ((4-orientation)*4)) - (1 << ((3-orientation)*4)) + negmask = ~mask + + new_transitions = \ + (new_transitions[0] & 1) << 3 | \ + (new_transitions[1] & 1) << 2 | \ + (new_transitions[2] & 1) << 1 | \ + (new_transitions[3] & 1) + + cell_transition = \ + (cell_transition & negmask) | \ + (new_transitions << ((3-orientation)*4)) return cell_transition - def get_transition_from_orientation_to_direction(self, cell_transition, - orientation, direction): + def get_transition(self, cell_transition, orientation, direction): """ Get the transition bit (1 value) that determines whether an agent oriented in direction `orientation' and inside a cell with transitions @@ -287,7 +233,7 @@ class GridTransitions(Transitions): Parameters ---------- cell_transition : int - 16 or 64 bits used to encode the valid transitions for a cell. + 16 bits used to encode the valid transitions for a cell. orientation : int Orientation of the agent inside the cell. direction : int @@ -299,14 +245,10 @@ class GridTransitions(Transitions): Validity of the requested transition: 0/1 allowed/not allowed. """ - return ((cell_transition >> - ((self.number_of_cell_neighbors-1-orientation) * - self.number_of_cell_neighbors)) >> - (self.number_of_cell_neighbors-1-direction)) & 1 - - def set_transition_from_orientation_to_direction(self, cell_transition, - orientation, direction, - new_transition): + return ((cell_transition >> ((4-1-orientation) * 4)) >> (4-1-direction)) & 1 + + def set_transition(self, cell_transition, orientation, direction, + new_transition): """ Set the transition bit (1 value) that determines whether an agent oriented in direction `orientation' and inside a cell with transitions @@ -316,7 +258,7 @@ class GridTransitions(Transitions): Parameters ---------- cell_transition : int - 16 or 64 bits used to encode the valid transitions for a cell. + 16 bits used to encode the valid transitions for a cell. orientation : int Orientation of the agent inside the cell. direction : int @@ -333,34 +275,27 @@ class GridTransitions(Transitions): """ if new_transition: - cell_transition |= \ - (1 << ((self.number_of_cell_neighbors-1-orientation) * - self.number_of_cell_neighbors + - (self.number_of_cell_neighbors - 1 - direction))) + cell_transition |= (1 << ((4-1-orientation) * 4 + + (4 - 1 - direction))) else: cell_transition &= \ - ~(1 << ((self.number_of_cell_neighbors-1-orientation) * - self.number_of_cell_neighbors + - (self.number_of_cell_neighbors - 1 - direction))) + ~(1 << ((4-1-orientation) * 4 + + (4 - 1 - direction))) return cell_transition def rotate_transition(self, cell_transition, rotation=0): """ - Clockwise-rotate a 16-bit or 64-bit transition bitmap by - rotation={0, 90, 180, 270} degrees in diagonal steps are not allowed, - or by rotation={0, 45, 90, 135, 180, 225, 270, 315} degrees if \ - they are. + Clockwise-rotate a 16-bit transition bitmap by + rotation={0, 90, 180, 270} degrees. Parameters ---------- cell_transition : int - 16 or 64 bits used to encode the valid transitions for a cell. + 16 bits used to encode the valid transitions for a cell. rotation : int Angle by which to clock-wise rotate the transition bits in - `cell_transition' by. I.e., rotation={0, 90, 180, 270} degrees in - diagonal steps are not allowed, or by - rotation={0, 45, 90, 135, 180, 225, 270, 315} degrees if they are. + `cell_transition' by. I.e., rotation={0, 90, 180, 270} degrees. Returns ------- @@ -369,48 +304,210 @@ class GridTransitions(Transitions): with the equivalent bitmap after rotation. """ - if self.number_of_cell_neighbors == 4: - # Rotate the individual bits in each block - value = cell_transition - rotation = rotation // 90 - for i in range(4): - block_tuple = self.get_transitions_from_orientation(value, i) - block_tuple = block_tuple[( - 4-rotation):] + block_tuple[:(4-rotation)] - value = self.set_transitions_from_orientation( - value, i, block_tuple) - - # Rotate the 4bits blocks - value = ((value & (2**(rotation*4)-1)) << - ((4-rotation)*4)) | (value >> (rotation*4)) - - cell_transition = value - - elif self.number_of_cell_neighbors == 8: - # TODO: WARNING: this part of the function has never been tested! - - # Rotate the individual bits in each block - value = cell_transition - rotation = rotation // 45 - for i in range(8): - block_tuple = self.get_transitions_from_orientation(value, i) - block_tuple = block_tuple[rotation:] + block_tuple[:rotation] - value = self.set_transitions_from_orientation( - value, i, block_tuple) - - # Rotate the 8bits blocks - value = ((value & (2**(rotation*8)-1)) << - ((8-rotation)*8)) | (value >> (rotation*8)) - - cell_transition = value + # Rotate the individual bits in each block + value = cell_transition + rotation = rotation // 90 + for i in range(4): + block_tuple = self.get_transitions(value, i) + block_tuple = block_tuple[( + 4-rotation):] + block_tuple[:(4-rotation)] + value = self.set_transitions(value, i, block_tuple) + + # Rotate the 4-bits blocks + value = ((value & (2**(rotation*4)-1)) << + ((4-rotation)*4)) | (value >> (rotation*4)) + + cell_transition = value + return cell_transition + + +class Grid8Transitions(Transitions): + """ + Grid8Transitions class derived from Transitions. + Special case of `Transitions' over a 2D-grid (FlatLand). + Transitions are possible to neighboring cells on the grid if allowed. + GridTransitions keeps track of valid transitions supplied as `transitions' + list, each represented as a bitmap of 64 bits. + + 0=North, 1=North-East, etc. + + """ + + def __init__(self, transitions): + self.transitions = transitions + + def get_transitions(self, cell_transition, orientation): + """ + Get the 8 possible transitions. + + Parameters + ---------- + cell_transition : int + 64 bits used to encode the valid transitions for a cell. + orientation : int + Orientation of the agent inside the cell. + + Returns + ------- + tuple + List of the validity of transitions in the cell. + + """ + bits = (cell_transition >> ((7-orientation)*8)) + cell_transition = ( + (bits >> 7) & 1, + (bits >> 6) & 1, + (bits >> 5) & 1, + (bits >> 4) & 1, + (bits >> 3) & 1, + (bits >> 2) & 1, + (bits >> 1) & 1, + (bits) & 1) + + return cell_transition + + def set_transitions(self, cell_transition, orientation, new_transitions): + """ + Set the possible transitions. + + Parameters + ---------- + cell_transition : int + 64 bits used to encode the valid transitions for a cell. + orientation : int + Orientation of the agent inside the cell. + new_transitions : tuple + Tuple of new transitions validitiy for the cell. + + Returns + ------- + int + An updated bitmap that replaces the original transitions validity + of `cell_transition' with `new_transitions', for the appropriate + `orientation'. + + """ + mask = (1 << ((8-orientation)*8)) - (1 << ((7-orientation)*8)) + negmask = ~mask + + new_transitions = \ + (new_transitions[0] & 1) << 7 | \ + (new_transitions[1] & 1) << 6 | \ + (new_transitions[2] & 1) << 5 | \ + (new_transitions[3] & 1) << 4 | \ + (new_transitions[4] & 1) << 3 | \ + (new_transitions[5] & 1) << 2 | \ + (new_transitions[6] & 1) << 1 | \ + (new_transitions[7] & 1) + + cell_transition = (cell_transition & negmask) | ( + new_transitions << ((7-orientation)*8)) + + return cell_transition + + def get_transition(self, cell_transition, orientation, direction): + """ + Get the transition bit (1 value) that determines whether an agent + oriented in direction `orientation' and inside a cell with transitions + `cell_transition' can move to the cell in direction `direction' + relative to the current cell. + + Parameters + ---------- + cell_transition : int + 64 bits used to encode the valid transitions for a cell. + orientation : int + Orientation of the agent inside the cell. + direction : int + Direction of movement whose validity is to be tested. + + Returns + ------- + int + Validity of the requested transition: 0/1 allowed/not allowed. + + """ + return ((cell_transition >> ((8-1-orientation) * 8)) >> + (8-1-direction)) & 1 + + def set_transition(self, cell_transition, orientation, direction, + new_transition): + """ + Set the transition bit (1 value) that determines whether an agent + oriented in direction `orientation' and inside a cell with transitions + `cell_transition' can move to the cell in direction `direction' + relative to the current cell. + + Parameters + ---------- + cell_transition : int + 64 bits used to encode the valid transitions for a cell. + orientation : int + Orientation of the agent inside the cell. + direction : int + Direction of movement whose validity is to be tested. + new_transition : int + Validity of the requested transition: 0/1 allowed/not allowed. + + Returns + ------- + int + An updated bitmap that replaces the original transitions validity + of `cell_transition' with `new_transitions', for the appropriate + `orientation'. + + """ + if new_transition: + cell_transition |= (1 << ((8-1-orientation) * 8 + + (8 - 1 - direction))) else: - raise NotImplementedError() + cell_transition &= ~(1 << ((8-1-orientation) * 8 + + (8 - 1 - direction))) + + return cell_transition + + def rotate_transition(self, cell_transition, rotation=0): + """ + Clockwise-rotate a 64-bit transition bitmap by + rotation={0, 45, 90, 135, 180, 225, 270, 315} degrees. + + Parameters + ---------- + cell_transition : int + 64 bits used to encode the valid transitions for a cell. + rotation : int + Angle by which to clock-wise rotate the transition bits in + `cell_transition' by. I.e., rotation={0, 45, 90, 135, 180, + 225, 270, 315} degrees. + + Returns + ------- + int + An updated bitmap that replaces the original transitions bits + with the equivalent bitmap after rotation. + + """ + # TODO: WARNING: this part of the function has never been tested! + + # Rotate the individual bits in each block + value = cell_transition + rotation = rotation // 45 + for i in range(8): + block_tuple = self.get_transitions(value, i) + block_tuple = block_tuple[rotation:] + block_tuple[:rotation] + value = self.set_transitions(value, i, block_tuple) + + # Rotate the 8bits blocks + value = ((value & (2**(rotation*8)-1)) << + ((8-rotation)*8)) | (value >> (rotation*8)) + + cell_transition = value return cell_transition -class RailEnvTransitions(GridTransitions): +class RailEnvTransitions(Grid4Transitions): """ Special case of `GridTransitions' over a 2D-grid, with a pre-defined set of transitions mimicking the types of real Swiss rail connections. @@ -450,6 +547,5 @@ class RailEnvTransitions(GridTransitions): def __init__(self): super(RailEnvTransitions, self).__init__( - transitions=self.transition_list, - allow_diagonal_transitions=False + transitions=self.transition_list ) diff --git a/flatland/utils/rail_env_generator.py b/flatland/utils/rail_env_generator.py index df180e7..4d73d52 100644 --- a/flatland/utils/rail_env_generator.py +++ b/flatland/utils/rail_env_generator.py @@ -2,8 +2,8 @@ The rail_env_generator module defines provides utilities to generate env bitmaps for the RailEnv environment. """ -import numpy as np import random +import numpy as np from flatland.core.transitions import RailEnvTransitions @@ -82,8 +82,7 @@ def generate_random_rail(width, height): for i in range(len(t_utils.transitions)-1): # don't include dead-ends all_transitions = 0 for dir_ in range(4): - trans = t_utils.get_transitions_from_orientation( - t_utils.transitions[i], dir_) + trans = t_utils.get_transitions(t_utils.transitions[i], dir_) all_transitions |= (trans[0] << 3) | \ (trans[1] << 2) | \ (trans[2] << 1) | \ @@ -148,8 +147,7 @@ def generate_random_rail(width, height): max_bit = 0 for k in range(4): max_bit |= \ - t_utils.get_transition_from_orientation_to_direction( - neigh_trans, k, el[1]) + t_utils.get_transition(neigh_trans, k, el[1]) if max_bit: valid_template[el[0]] = 1 diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py index 1a897d8..af5f68f 100644 --- a/flatland/utils/rendertools.py +++ b/flatland/utils/rendertools.py @@ -58,7 +58,7 @@ class RenderTool(object): # transition for next cell oTrans = self.env.rail[rcNext[0]][rcNext[1]] tbTrans = RailEnvTransitions. \ - get_transitions_from_orientation(oTrans, iDir) + get_transitions(oTrans, iDir) giTrans = np.where(tbTrans)[0] # RC list of transitions gTransRCAg = self.__class__.gTransRC[giTrans] @@ -106,7 +106,7 @@ class RenderTool(object): # TODO: suggest we provide an accessor in RailEnv oTrans = self.env.rail[rcPos] # transition for current cell - tbTrans = rt.RETrans.get_transitions_from_orientation(oTrans, iDir) + tbTrans = rt.RETrans.get_transitions(oTrans, iDir) giTrans = np.where(tbTrans)[0] # RC list of transitions # HACK: workaround dead-end transitions @@ -363,8 +363,7 @@ class RenderTool(object): # renderer.translate(c * CELL_PIXELS, r * CELL_PIXELS) if True: - tMoves = RETrans.get_transitions_from_orientation( - oCell, orientation) + tMoves = RETrans.get_transitions(oCell, orientation) # to_ori = (orientation + 2) % 4 for to_ori in range(4): diff --git a/tests/test_environments.py b/tests/test_environments.py index d5e7dd9..32f8784 100644 --- a/tests/test_environments.py +++ b/tests/test_environments.py @@ -2,24 +2,22 @@ # -*- coding: utf-8 -*- from flatland.core.env import RailEnv -from flatland.core.transitions import GridTransitions +from flatland.core.transitions import Grid4Transitions import numpy as np -import random """Tests for `flatland` package.""" - def test_rail_environment_single_agent(): cells = [int('0000000000000000', 2), # empty cell - Case 0 - int('1000000000100000', 2), # Case 1 - straight - int('1001001000100000', 2), # Case 2 - simple switch - int('1000010000100001', 2), # Case 3 - diamond drossing - int('1001011000100001', 2), # Case 4 - single slip switch - int('1100110000110011', 2), # Case 5 - double slip switch - int('0101001000000010', 2), # Case 6 - symmetrical switch - int('0010000000000000', 2)] # Case 7 - dead end + int('1000000000100000', 2), # Case 1 - straight + int('1001001000100000', 2), # Case 2 - simple switch + int('1000010000100001', 2), # Case 3 - diamond drossing + int('1001011000100001', 2), # Case 4 - single slip switch + int('1100110000110011', 2), # Case 5 - double slip switch + int('0101001000000010', 2), # Case 6 - symmetrical switch + int('0010000000000000', 2)] # Case 7 - dead end # We instantiate the following map on a 3x3 grid # _ _ @@ -27,7 +25,7 @@ def test_rail_environment_single_agent(): # | | | # \_/\_/ - transitions = GridTransitions([], False) + transitions = Grid4Transitions([]) vertical_line = cells[1] south_symmetrical_switch = cells[6] north_symmetrical_switch = transitions.rotate_transition(south_symmetrical_switch, 180) @@ -51,7 +49,7 @@ def test_rail_environment_single_agent(): # Check that trains are always initialized at a consistent position / direction. # They should always be able to go somewhere. - assert(transitions.get_transitions_from_orientation( + assert(transitions.get_transitions( rail_map[rail_env.agents_position[0]], rail_env.agents_direction[0]) != (0, 0, 0, 0)) diff --git a/tests/test_transitions.py b/tests/test_transitions.py index 2c59add..f68b836 100644 --- a/tests/test_transitions.py +++ b/tests/test_transitions.py @@ -2,7 +2,7 @@ # -*- coding: utf-8 -*- """Tests for `flatland` package.""" -from flatland.core.transitions import RailEnvTransitions, GridTransitions +from flatland.core.transitions import RailEnvTransitions, Grid8Transitions def test_valid_railenv_transitions(): @@ -14,36 +14,36 @@ def test_valid_railenv_transitions(): # 'W': 3} for i in range(2): - assert(rail_env_trans.get_transitions_from_orientation( + assert(rail_env_trans.get_transitions( int('1100110000110011', 2), i) == (1, 1, 0, 0)) - assert(rail_env_trans.get_transitions_from_orientation( + assert(rail_env_trans.get_transitions( int('1100110000110011', 2), 2+i) == (0, 0, 1, 1)) no_transition_cell = int('0000000000000000', 2) for i in range(4): - assert(rail_env_trans.get_transitions_from_orientation( + assert(rail_env_trans.get_transitions( no_transition_cell, i) == (0, 0, 0, 0)) # Facing south, going south - north_south_transition = rail_env_trans.set_transitions_from_orientation( + north_south_transition = rail_env_trans.set_transitions( no_transition_cell, 2, (0, 0, 1, 0)) - assert(rail_env_trans.set_transition_from_orientation_to_direction( + assert(rail_env_trans.set_transition( north_south_transition, 2, 2, 0) == no_transition_cell) - assert(rail_env_trans.get_transition_from_orientation_to_direction( + assert(rail_env_trans.get_transition( north_south_transition, 2, 2)) # Facing north, going east south_east_transition = \ - rail_env_trans.set_transition_from_orientation_to_direction( + rail_env_trans.set_transition( no_transition_cell, 0, 1, 1) - assert(rail_env_trans.get_transition_from_orientation_to_direction( + assert(rail_env_trans.get_transition( south_east_transition, 0, 1)) # The opposite transitions are not feasible - assert(not rail_env_trans.get_transition_from_orientation_to_direction( + assert(not rail_env_trans.get_transition( north_south_transition, 2, 0)) - assert(not rail_env_trans.get_transition_from_orientation_to_direction( + assert(not rail_env_trans.get_transition( south_east_transition, 2, 1)) east_west_transition = rail_env_trans.rotate_transition( @@ -52,10 +52,10 @@ def test_valid_railenv_transitions(): south_east_transition, 180) # Facing west, going west - assert(rail_env_trans.get_transition_from_orientation_to_direction( + assert(rail_env_trans.get_transition( east_west_transition, 3, 3)) # Facing south, going west - assert(rail_env_trans.get_transition_from_orientation_to_direction( + assert(rail_env_trans.get_transition( north_west_transition, 2, 3)) assert(south_east_transition == rail_env_trans.rotate_transition( @@ -63,16 +63,16 @@ def test_valid_railenv_transitions(): def test_diagonal_transitions(): - diagonal_trans_env = GridTransitions([], True) + diagonal_trans_env = Grid8Transitions([]) # Facing north, going north-east south_northeast_transition = int('01000000' + '0'*8*7, 2) - assert(diagonal_trans_env.get_transitions_from_orientation( + assert(diagonal_trans_env.get_transitions( south_northeast_transition, 0) == (0, 1, 0, 0, 0, 0, 0, 0)) # Allowing transition from north to southwest: Facing south, going SW north_southwest_transition = \ - diagonal_trans_env.set_transitions_from_orientation( + diagonal_trans_env.set_transitions( int('0' * 64, 2), 4, (0, 0, 0, 0, 0, 1, 0, 0)) assert(diagonal_trans_env.rotate_transition( -- GitLab