diff --git a/.gitignore b/.gitignore index 5cf2e905fe1b49ff7c6660efdc083c7cf69caf24..8ac21dd6ff8a775c07a2b1d7dd97db147ff5b357 100644 --- a/.gitignore +++ b/.gitignore @@ -71,6 +71,9 @@ target/ # Jupyter Notebook .ipynb_checkpoints +# PyCharm +.idea/ + # pyenv .python-version diff --git a/flatland/core/env.py b/flatland/core/env.py index 4a147d067d4510268a431b0e32ea291799ae70a0..2ecee638f1287e0ef27c1b2f405c87b03efcb576 100644 --- a/flatland/core/env.py +++ b/flatland/core/env.py @@ -108,7 +108,7 @@ class RailEnv: 0: do nothing 1: turn left and move to the next cell 2: move to the next cell in front of the agent - 3: turn righ tand move to the next cell + 3: turn right and move to the next cell Moving forward in a dead-end cell makes the agent turn 180 degrees and step to the cell it came from. @@ -276,6 +276,7 @@ class RailEnv: self.rail[pos[0]][pos[1]], reverse_direction, reverse_direction) + if valid_transition: direction = reverse_direction movement = direction @@ -285,7 +286,11 @@ class RailEnv: # Is it a legal move? 
1) transition allows the movement in the # cell, 2) the new cell is not empty (case 0), 3) the cell is # free, i.e., no agent is currently in that cell - if self.rail[new_position[0]][new_position[1]] > 0: + if new_position[1] >= self.width or new_position[0] >= self.height or\ + new_position[0] < 0 or new_position[1] < 0: + new_cell_isValid = False + + elif self.rail[new_position[0]][new_position[1]] > 0: new_cell_isValid = True else: new_cell_isValid = False diff --git a/flatland/core/transitions.py b/flatland/core/transitions.py index 467ad67a117b992bb6ab02b6f70bc83da3557081..9c4f05e3bf394cc88b05dae3f08513c70fe52b20 100644 --- a/flatland/core/transitions.py +++ b/flatland/core/transitions.py @@ -438,14 +438,15 @@ class RailEnvTransitions(GridTransitions): transitions available as a function of the agent's orientation (north, east, south, west) """ - transition_list = [int('0000000000000000', 2), - int('1000000000100000', 2), - int('1001001000100000', 2), - int('1000010000100001', 2), - int('1001011000100001', 2), - int('1100110000110011', 2), - int('0101001000000010', 2), - int('0000000000100000', 2)] + + transition_list = [int('0000000000000000', 2), # empty cell - Case 0 + int('1000000000100000', 2), # Case 1 - straight + int('1001001000100000', 2), # Case 2 - simple switch + int('1000010000100001', 2), # Case 3 - diamond crossing + int('1001011000100001', 2), # Case 4 - single slip switch + int('1100110000110011', 2), # Case 5 - double slip switch + int('0101001000000010', 2), # Case 6 - symmetrical switch + int('0010000000000000', 2)] # Case 7 - dead end def __init__(self): super(RailEnvTransitions, self).__init__( diff --git a/requirements_dev.txt b/requirements_dev.txt index 926f63047bb152a1bdbde3b0f370b2a57c3c84f7..70d99c53aead6b2ff46250a662c255bacda70c10 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -10,3 +10,5 @@ twine==1.12.1 pytest==3.8.2 pytest-runner==4.2 sphinx-rtd-theme==0.4.3 + +numpy==1.16.2 diff --git 
a/tests/test_environments.py b/tests/test_environments.py
index ae4d5249e143769f41e607ccd477737ef0e440b9..435002b5af080293f4a739b12f6e5256c8780e78 100644
--- a/tests/test_environments.py
+++ b/tests/test_environments.py
@@ -1,11 +1,111 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
 """Tests for `flatland` package."""
 
+import random
+
+import numpy as np
+
+from flatland.core.env import RailEnv
+from flatland.core.transitions import GridTransitions
+
 
-def test_base_environment():
-    """Test example Transition."""
-    a = True
-    b = True
-    assert a == b
+def test_rail_environment_single_agent():
+    """Drive one agent with random actions on a small two-loop rail map.
+
+    Checks that the agent always starts on a cell it can leave, that six
+    moves bring it back to the row it started from, and that random
+    actions always end the episode.
+    """
+    # Seed the generators so a failing sequence of actions can be replayed.
+    random.seed(0)
+    np.random.seed(0)
+
+    # Upper bound on the steps of one episode: fail instead of hanging when
+    # the agent stops moving or the episode never finishes.
+    max_steps = 10000
+
+    cells = [int('0000000000000000', 2),  # empty cell - Case 0
+             int('1000000000100000', 2),  # Case 1 - straight
+             int('1001001000100000', 2),  # Case 2 - simple switch
+             int('1000010000100001', 2),  # Case 3 - diamond crossing
+             int('1001011000100001', 2),  # Case 4 - single slip switch
+             int('1100110000110011', 2),  # Case 5 - double slip switch
+             int('0101001000000010', 2),  # Case 6 - symmetrical switch
+             int('0010000000000000', 2)]  # Case 7 - dead end
+
+    # We instantiate the following map on a 3x3 grid
+    #  _  _
+    # / \/ \
+    # | |  |
+    # \_/\_/
+
+    transitions = GridTransitions([], False)
+    vertical_line = cells[1]
+    south_symmetrical_switch = cells[6]
+    north_symmetrical_switch = transitions.rotate_transition(south_symmetrical_switch, 180)
+    south_east_turn = int('0100000000000010', 2)  # Simple turn not in the base transitions ?
+    south_west_turn = transitions.rotate_transition(south_east_turn, 90)
+    north_east_turn = transitions.rotate_transition(south_east_turn, 270)
+    north_west_turn = transitions.rotate_transition(south_east_turn, 180)
+
+    rail_map = np.array([[south_east_turn, south_symmetrical_switch, south_west_turn],
+                         [vertical_line, vertical_line, vertical_line],
+                         [north_east_turn, north_symmetrical_switch, north_west_turn]],
+                        dtype=np.uint16)
+
+    rail_env = RailEnv(rail_map, number_of_agents=1)
+    for _ in range(200):
+        rail_env.reset()
+
+        # We do not care about target for the moment
+        rail_env.agents_target[0] = [-1, -1]
+
+        initial_pos = rail_env.agents_position[0]
+
+        # Check that trains are always initialized at a consistent position / direction.
+        # They should always be able to go somewhere.
+        # tuple() makes numpy pick the single cell even when the position is a
+        # list; a list index would select whole rows of the map instead.
+        assert transitions.get_transitions_from_orientation(
+            rail_map[tuple(initial_pos)],
+            rail_env.agents_direction[0]) != (0, 0, 0, 0)
+
+        valid_active_actions_done = 0
+        steps = 0
+        pos = initial_pos
+        while valid_active_actions_done < 6:
+            assert steps < max_steps, "agent did not complete 6 moves"
+            steps += 1
+
+            # We randomly select an action
+            action = np.random.randint(4)
+
+            _, _, _, _ = rail_env.step({0: action})
+
+            prev_pos = pos
+            pos = rail_env.agents_position[0]
+            if prev_pos != pos:
+                valid_active_actions_done += 1
+
+        # After 6 movements on this railway network, the train should be back on
+        # the row it started from (only the row index is compared).
+        assert initial_pos[0] == rail_env.agents_position[0][0]
+
+    # We check that the train always attains its target after some time
+    for _ in range(200):
+        rail_env.reset()
+
+        done = False
+        steps = 0
+        while not done:
+            assert steps < max_steps, "episode did not finish"
+            steps += 1
+
+            # We randomly select an action
+            action = np.random.randint(4)
+
+            _, _, dones, _ = rail_env.step({0: action})
+
+            done = dones['__all__']