From 39b33de10492aafaf15b3128294f37a135fe1596 Mon Sep 17 00:00:00 2001 From: u214892 <u214892@sbb.ch> Date: Wed, 19 Jun 2019 15:43:17 +0200 Subject: [PATCH] #62 first steps unit test coverage --- flatland/core/transitions.py | 8 ++------ flatland/envs/generators.py | 20 ++++++++++++-------- tests/test_flatland_core_transitions.py | 25 +++++++++++++++++++++++++ 3 files changed, 39 insertions(+), 14 deletions(-) diff --git a/flatland/core/transitions.py b/flatland/core/transitions.py index 6c38a39..1c3c924 100644 --- a/flatland/core/transitions.py +++ b/flatland/core/transitions.py @@ -556,12 +556,8 @@ class RailEnvTransitions(Grid4Transitions): In the example, the agent can move from North to South and viceversa. """ - """ - transitions[] is indexed by case type/id, and returns the 4x4-bit [NESW] - transitions available as a function of the agent's orientation - (north, east, south, west) - """ - + # Contains the basic transitions; + # the set of all valid transitions is obtained by successive 90-degree rotation of one of these basic transitions. transition_list = [int('0000000000000000', 2), # empty cell - Case 0 int('1000000000100000', 2), # Case 1 - straight int('1001001000100000', 2), # Case 2 - simple switch diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py index f644bc1..08d99c2 100644 --- a/flatland/envs/generators.py +++ b/flatland/envs/generators.py @@ -153,7 +153,7 @@ def rail_from_manual_specifications_generator(rail_spec): Parameters ------- rail_spec : list of list of tuples - List (rows) of lists (columns) of tuples, each specifying a cell for + List (rows) of lists (columns) of tuples, each specifying a rail_spec_of_cell for the RailEnv environment as (cell_type, rotation), with rotation being clock-wise and in [0, 90, 180, 270]. @@ -161,23 +161,27 @@ def rail_from_manual_specifications_generator(rail_spec): ------- function Generator function that always returns a GridTransitionMap object with - the matrix of correct 16-bit bitmaps for each cell. + the matrix of correct 16-bit bitmaps for each rail_spec_of_cell. """ def generator(width, height, num_agents, num_resets=0): - t_utils = RailEnvTransitions() + rail_env_transitions = RailEnvTransitions() height = len(rail_spec) width = len(rail_spec[0]) - rail = GridTransitionMap(width=width, height=height, transitions=t_utils) + rail = GridTransitionMap(width=width, height=height, transitions=rail_env_transitions) for r in range(height): for c in range(width): - cell = rail_spec[r][c] - if cell[0] < 0 or cell[0] >= len(t_utils.transitions): - print("ERROR - invalid cell type=", cell[0]) + rail_spec_of_cell = rail_spec[r][c] + index_basic_type_of_cell_ = rail_spec_of_cell[0] + rotation_cell_ = rail_spec_of_cell[1] + if index_basic_type_of_cell_ < 0 or index_basic_type_of_cell_ >= len(rail_env_transitions.transitions): + print("ERROR - invalid rail_spec_of_cell type=", index_basic_type_of_cell_) return [] - rail.set_transitions((r, c), t_utils.rotate_transition(t_utils.transitions[cell[0]], cell[1])) + basic_type_of_cell_ = rail_env_transitions.transitions[index_basic_type_of_cell_] + effective_transition_cell = rail_env_transitions.rotate_transition(basic_type_of_cell_, rotation_cell_) + rail.set_transitions((r, c), effective_transition_cell) agents_position, agents_direction, agents_target = get_rnd_agents_pos_tgt_dir_on_rail( rail, diff --git a/tests/test_flatland_core_transitions.py b/tests/test_flatland_core_transitions.py index 9d02553..c32b038 100644 --- a/tests/test_flatland_core_transitions.py +++ b/tests/test_flatland_core_transitions.py @@ -8,6 +8,31 @@ from flatland.core.transitions import RailEnvTransitions, Grid8Transitions from flatland.envs.env_utils import validate_new_transition +def test_rotate_railenv_transition(): + rail_env_transitions = RailEnvTransitions() + transition_cycles = [ + # empty cell - Case 0 + [int('0000000000000000', 2), int('0000000000000000', 2), int('0000000000000000', 2), int('0000000000000000', 2)], + # Case 1 - straight + [int('1000000000100000', 2), int('0000000100000100', 2)], + ] + + for cycle in transition_cycles: + for i in range(4): + assert rail_env_transitions.rotate_transition(cycle[0], i) == cycle[i % len(cycle)] + + # + # int('1001001000100000', 2), # Case 2 - simple switch + # int('1000010000100001', 2), # Case 3 - diamond drossing + # int('1001011000100001', 2), # Case 4 - single slip + # int('1100110000110011', 2), # Case 5 - double slip + # int('0101001000000010', 2), # Case 6 - symmetrical + # int('0010000000000000', 2), # Case 7 - dead end + # int('0100000000000010', 2), # Case 1b (8) - simple turn right + # int('0001001000000000', 2), # Case 1c (9) - simple turn left + # int('1100000000100010', 2)] # Case 2b (10) - simple switch mirrored + + def test_is_valid_railenv_transitions(): rail_env_trans = RailEnvTransitions() transition_list = rail_env_trans.transitions -- GitLab