diff --git a/.gitignore b/.gitignore index 84229f457bb049aca8293530623abac81690b239..ca0c210c45d3b490676f2fc7716a5b781cd5c654 100644 --- a/.gitignore +++ b/.gitignore @@ -71,6 +71,9 @@ target/ # Jupyter Notebook .ipynb_checkpoints +# PyCharm +.idea/ + # pyenv .python-version diff --git a/flatland/core/env.py b/flatland/core/env.py index 4a147d067d4510268a431b0e32ea291799ae70a0..2ecee638f1287e0ef27c1b2f405c87b03efcb576 100644 --- a/flatland/core/env.py +++ b/flatland/core/env.py @@ -108,7 +108,7 @@ class RailEnv: 0: do nothing 1: turn left and move to the next cell 2: move to the next cell in front of the agent - 3: turn righ tand move to the next cell + 3: turn right and move to the next cell Moving forward in a dead-end cell makes the agent turn 180 degrees and step to the cell it came from. @@ -276,6 +276,7 @@ class RailEnv: self.rail[pos[0]][pos[1]], reverse_direction, reverse_direction) + if valid_transition: direction = reverse_direction movement = direction @@ -285,7 +286,11 @@ class RailEnv: # Is it a legal move? 
1) transition allows the movement in the # cell, 2) the new cell is not empty (case 0), 3) the cell is # free, i.e., no agent is currently in that cell - if self.rail[new_position[0]][new_position[1]] > 0: + if new_position[1] >= self.width or new_position[0] >= self.height or\ + new_position[0] < 0 or new_position[1] < 0: + new_cell_isValid = False + + elif self.rail[new_position[0]][new_position[1]] > 0: new_cell_isValid = True else: new_cell_isValid = False diff --git a/flatland/core/transitions.py b/flatland/core/transitions.py index dc4ffe82b6153573ba8528a8017b205dbd8da20d..9c4f05e3bf394cc88b05dae3f08513c70fe52b20 100644 --- a/flatland/core/transitions.py +++ b/flatland/core/transitions.py @@ -440,13 +440,13 @@ class RailEnvTransitions(GridTransitions): """ transition_list = [int('0000000000000000', 2), # empty cell - Case 0 - int('1000000000100000', 2), # Case 1 - straight - int('1001001000000000', 2), # Case 2 - simple switch - int('1000010000100001', 2), # Case 3 - diamond drossing - int('1001011000100001', 2), # Case 4 - single slip switch - int('1100110000110011', 2), # Case 5 - double slip switch - int('0101001000000010', 2), # Case 6 - symmetrical switch - int('0010000000000000', 2)] # Case 7 - dead end + int('1000000000100000', 2), # Case 1 - straight + int('1001001000100000', 2), # Case 2 - simple switch + int('1000010000100001', 2), # Case 3 - diamond crossing + int('1001011000100001', 2), # Case 4 - single slip switch + int('1100110000110011', 2), # Case 5 - double slip switch + int('0101001000000010', 2), # Case 6 - symmetrical switch + int('0010000000000000', 2)] # Case 7 - dead end def __init__(self): super(RailEnvTransitions, self).__init__( diff --git a/requirements_dev.txt b/requirements_dev.txt index cba711f85d729b9477cc9e4493fb96df7503802d..ad6c0add1a9a9b907c6b614e07a84aeeaac9ab72 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -11,3 +11,5 @@ twine==1.12.1 pytest==3.8.2 pytest-runner==4.2 sphinx-rtd-theme==0.4.3 + 
+numpy==1.16.2 diff --git a/tests/test_environments.py b/tests/test_environments.py index e95c398c15994ce9ff1481839a481377b9a5d262..435002b5af080293f4a739b12f6e5256c8780e78 100644 --- a/tests/test_environments.py +++ b/tests/test_environments.py @@ -4,15 +4,17 @@ from flatland.core.env import RailEnv from flatland.core.transitions import GridTransitions import numpy as np +import random """Tests for `flatland` package.""" -def test_rail_environment(): +def test_rail_environment_single_agent(): + cells = [int('0000000000000000', 2), # empty cell - Case 0 int('1000000000100000', 2), # Case 1 - straight - int('1001001000000000', 2), # Case 2 - simple switch + int('1001001000100000', 2), # Case 2 - simple switch int('1000010000100001', 2), # Case 3 - diamond drossing int('1001011000100001', 2), # Case 4 - single slip switch int('1100110000110011', 2), # Case 5 - double slip switch @@ -29,8 +31,9 @@ def test_rail_environment(): vertical_line = cells[1] south_symmetrical_switch = cells[6] north_symmetrical_switch = transitions.rotate_transition(south_symmetrical_switch, 180) - south_east_turn = int('0100000000100000', 2) # Simple turn not in the base transitions ? + south_east_turn = int('0100000000000010', 2) # Simple turn not in the base transitions ? south_west_turn = transitions.rotate_transition(south_east_turn, 90) + # print(bytes(south_west_turn)) north_east_turn = transitions.rotate_transition(south_east_turn, 270) north_west_turn = transitions.rotate_transition(south_east_turn, 180) @@ -40,12 +43,52 @@ def test_rail_environment(): dtype=np.uint16) rail_env = RailEnv(rail_map, number_of_agents=1) + for _ in range(200): + _ = rail_env.reset() + + # We do not care about target for the moment + rail_env.agents_target[0] = [-1, -1] - # Check that trains are always initialized at a consistent position / direction. - # They should always be able to go somewhere. 
- for _ in range(1000): - obs = rail_env.reset() + # Check that trains are always initialized at a consistent position / direction. + # They should always be able to go somewhere. assert(transitions.get_transitions_from_orientation( rail_map[rail_env.agents_position[0]], rail_env.agents_direction[0]) != (0, 0, 0, 0)) + initial_pos = rail_env.agents_position[0] + + valid_active_actions_done = 0 + pos = initial_pos + while valid_active_actions_done < 6: + # We randomly select an action + action = np.random.randint(4) + + _, _, _, _ = rail_env.step({0: action}) + + prev_pos = pos + pos = rail_env.agents_position[0] + if prev_pos != pos: + valid_active_actions_done += 1 + + # After 6 movements on this railway network, the train should be back to its original + # position. + assert(initial_pos[0] == rail_env.agents_position[0][0]) + + # We check that the train always attains its target after some time + for _ in range(200): + _ = rail_env.reset() + + done = False + while not done: + # We randomly select an action + action = np.random.randint(4) + + _, _, dones, _ = rail_env.step({0: action}) + + done = dones['__all__'] + + + + + +