diff --git a/.gitignore b/.gitignore index 5cf2e905fe1b49ff7c6660efdc083c7cf69caf24..8ac21dd6ff8a775c07a2b1d7dd97db147ff5b357 100644 --- a/.gitignore +++ b/.gitignore @@ -71,6 +71,9 @@ target/ # Jupyter Notebook .ipynb_checkpoints +# PyCharm +.idea/ + # pyenv .python-version diff --git a/examples/temporary_example.py b/examples/temporary_example.py index ffc9a5de5a4c525a15588f2859477719c83ea9ef..cd6d42de0c86ea39b908d878e23e83efa7fff823 100644 --- a/examples/temporary_example.py +++ b/examples/temporary_example.py @@ -4,96 +4,10 @@ import matplotlib.pyplot as plt from flatland.core.env import RailEnv from flatland.utils.rail_env_generator import * +from flatland.utils.rendertools import * -random.seed(100) -np.random.seed(100) - - -def pyplot_draw_square(center, size, color): - x0 = center[0] - size/2 - x1 = center[0] + size/2 - y0 = center[1] - size/2 - y1 = center[1] + size/2 - plt.plot([x0, x1, x1, x0, x0], [y0, y0, y1, y1, y0], color=color) - - -def pyplot_render_env(env): - cell_size = 10 - - plt.figure() - - # Draw cells grid - grid_color = [0.95, 0.95, 0.95] - for r in range(env.height+1): - plt.plot([0, (env.width+1)*cell_size], - [-r*cell_size, -r*cell_size], color=grid_color) - for c in range(env.width+1): - plt.plot([c*cell_size, c*cell_size], - [0, -(env.height+1)*cell_size], color=grid_color) - - # Draw each cell independently - for r in range(env.height): - for c in range(env.width): - trans_ = env.rail[r][c] - - x0 = c*cell_size - x1 = (c+1)*cell_size - y0 = -r*cell_size - y1 = -(r+1)*cell_size - - coords = [((x0+x1) / 2.0, y0), (x1, (y0+y1) / 2.0), - ((x0+x1) / 2.0, y1), (x0, (y0+y1) / 2.0)] - - for orientation in range(4): - from_ori = (orientation + 2) % 4 - from_ = coords[from_ori] - - # Special Case 7, with a single bit; terminate at center - nbits = 0 - tmp = trans_ - - while tmp > 0: - nbits += (tmp & 1) - tmp = tmp >> 1 - - if nbits == 1: - from_ = ((x0+x1) / 2.0, (y0+y1) / 2.0) - - moves = env.t_utils.get_transitions_from_orientation( - env.rail[r][c], orientation) - for moves_i in range(4): - if moves[moves_i]: - to = coords[moves_i] - plt.plot([from_[0], to[0]], [from_[1], to[1]], 'k') - - # Draw each agent + its orientation + its target - cmap = plt.get_cmap('hsv', lut=env.number_of_agents+1) - for i in range(env.number_of_agents): - pyplot_draw_square((env.agents_position[i][1] * cell_size+cell_size/2, - -env.agents_position[i][0] * cell_size-cell_size/2), - cell_size / 8, cmap(i)) - for i in range(env.number_of_agents): - pyplot_draw_square((env.agents_target[i][1] * cell_size+cell_size/2, - -env.agents_target[i][0] * cell_size-cell_size/2), - cell_size / 3, [c for c in cmap(i)]) - - # orientation is a line connecting the center of the cell to the side - # of the square of the agent - new_position = env._new_position(env.agents_position[i], - env.agents_direction[i]) - new_position = ((new_position[0]+env.agents_position[i][0])/2 * - cell_size, - (new_position[1]+env.agents_position[i][1])/2 * - cell_size) - - plt.plot([env.agents_position[i][1] * cell_size + cell_size/2, - new_position[1] + cell_size/2], - [-env.agents_position[i][0] * cell_size-cell_size/2, - -new_position[0] - cell_size/2], color=cmap(i), linewidth=2.0) - - plt.xlim([0, env.width * cell_size]) - plt.ylim([-env.height * cell_size, 0]) - plt.show() +random.seed(1) +np.random.seed(1) # Example generate a random rail @@ -102,7 +16,8 @@ rail = generate_random_rail(20, 20) env = RailEnv(rail, number_of_agents=10) env.reset() -pyplot_render_env(env) +env_renderer = RenderTool(env) +env_renderer.renderEnv() # Example generate a rail given a manual specification, @@ -121,7 +36,8 @@ env.agents_position = [[1, 4]] env.agents_target = [[1, 1]] env.agents_direction = [1] -pyplot_render_env(env) +env_renderer = RenderTool(env) +env_renderer.renderEnv() print("Manual control: s=perform step, q=quit, [agent id] [1-2-3 action] \ @@ -148,4 +64,4 @@ for step in range(100): i = i+1 i += 1 - pyplot_render_env(env) + env_renderer.renderEnv() diff --git a/flatland/core/env.py b/flatland/core/env.py index 4a147d067d4510268a431b0e32ea291799ae70a0..2ecee638f1287e0ef27c1b2f405c87b03efcb576 100644 --- a/flatland/core/env.py +++ b/flatland/core/env.py @@ -108,7 +108,7 @@ class RailEnv: 0: do nothing 1: turn left and move to the next cell 2: move to the next cell in front of the agent - 3: turn righ tand move to the next cell + 3: turn right and move to the next cell Moving forward in a dead-end cell makes the agent turn 180 degrees and step to the cell it came from. @@ -276,6 +276,7 @@ class RailEnv: self.rail[pos[0]][pos[1]], reverse_direction, reverse_direction) + if valid_transition: direction = reverse_direction movement = direction @@ -285,7 +286,11 @@ class RailEnv: # Is it a legal move? 1) transition allows the movement in the # cell, 2) the new cell is not empty (case 0), 3) the cell is # free, i.e., no agent is currently in that cell - if self.rail[new_position[0]][new_position[1]] > 0: + if new_position[1] >= self.width or new_position[0] >= self.height or\ + new_position[0] < 0 or new_position[1] < 0: + new_cell_isValid = False + + elif self.rail[new_position[0]][new_position[1]] > 0: new_cell_isValid = True else: new_cell_isValid = False diff --git a/flatland/core/transitions.py b/flatland/core/transitions.py index 467ad67a117b992bb6ab02b6f70bc83da3557081..9c4f05e3bf394cc88b05dae3f08513c70fe52b20 100644 --- a/flatland/core/transitions.py +++ b/flatland/core/transitions.py @@ -438,14 +438,15 @@ class RailEnvTransitions(GridTransitions): transitions available as a function of the agent's orientation (north, east, south, west) """ - transition_list = [int('0000000000000000', 2), - int('1000000000100000', 2), - int('1001001000100000', 2), - int('1000010000100001', 2), - int('1001011000100001', 2), - int('1100110000110011', 2), - int('0101001000000010', 2), - int('0000000000100000', 2)] + + transition_list = [int('0000000000000000', 2), # empty cell - Case 0 + int('1000000000100000', 2), # Case 1 - straight + int('1001001000100000', 2), # Case 2 - simple switch + int('1000010000100001', 2), # Case 3 - diamond drossing + int('1001011000100001', 2), # Case 4 - single slip switch + int('1100110000110011', 2), # Case 5 - double slip switch + int('0101001000000010', 2), # Case 6 - symmetrical switch + int('0010000000000000', 2)] # Case 7 - dead end def __init__(self): super(RailEnvTransitions, self).__init__( diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py index 8a8c64eabc5aaf4be0e02381f1c41a4493bdf6bd..1a897d819c8579189d3752adb12436fce5845318 100644 --- a/flatland/utils/rendertools.py +++ b/flatland/utils/rendertools.py @@ -413,7 +413,13 @@ class RenderTool(object): plt.ylim([-env.height * cell_size, 0]) plt.xticks(np.linspace(0, env.width * cell_size, env.width+1)) - plt.yticks(np.linspace(-env.height * cell_size, 0, env.height+1)) + plt.yticks(np.linspace(-env.height * cell_size, 0, env.height+1), + np.abs(np.linspace(-env.height * cell_size, + 0, env.height+1))) + + plt.xlim([0, env.width * cell_size]) + plt.ylim([-env.height * cell_size, 0]) + plt.show() def _draw_square(self, center, size, color): x0 = center[0]-size/2 diff --git a/requirements_dev.txt b/requirements_dev.txt index 926f63047bb152a1bdbde3b0f370b2a57c3c84f7..70d99c53aead6b2ff46250a662c255bacda70c10 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -10,3 +10,5 @@ twine==1.12.1 pytest==3.8.2 pytest-runner==4.2 sphinx-rtd-theme==0.4.3 + +numpy==1.16.2 diff --git a/tests/test_environments.py b/tests/test_environments.py index ae4d5249e143769f41e607ccd477737ef0e440b9..d5e7dd9c9fc350f11482b684572858b34ac83829 100644 --- a/tests/test_environments.py +++ b/tests/test_environments.py @@ -1,11 +1,94 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- +from flatland.core.env import RailEnv +from flatland.core.transitions import GridTransitions +import numpy as np +import random + """Tests for `flatland` package.""" -def test_base_environment(): - """Test example Transition.""" - a = True - b = True - assert a == b + +def test_rail_environment_single_agent(): + + cells = [int('0000000000000000', 2), # empty cell - Case 0 + int('1000000000100000', 2), # Case 1 - straight + int('1001001000100000', 2), # Case 2 - simple switch + int('1000010000100001', 2), # Case 3 - diamond drossing + int('1001011000100001', 2), # Case 4 - single slip switch + int('1100110000110011', 2), # Case 5 - double slip switch + int('0101001000000010', 2), # Case 6 - symmetrical switch + int('0010000000000000', 2)] # Case 7 - dead end + + # We instantiate the following map on a 3x3 grid + # _ _ + # / \/ \ + # | | | + # \_/\_/ + + transitions = GridTransitions([], False) + vertical_line = cells[1] + south_symmetrical_switch = cells[6] + north_symmetrical_switch = transitions.rotate_transition(south_symmetrical_switch, 180) + south_east_turn = int('0100000000000010', 2) # Simple turn not in the base transitions ? + south_west_turn = transitions.rotate_transition(south_east_turn, 90) + # print(bytes(south_west_turn)) + north_east_turn = transitions.rotate_transition(south_east_turn, 270) + north_west_turn = transitions.rotate_transition(south_east_turn, 180) + + rail_map = np.array([[south_east_turn, south_symmetrical_switch, south_west_turn], + [vertical_line, vertical_line, vertical_line], + [north_east_turn, north_symmetrical_switch, north_west_turn]], + dtype=np.uint16) + + rail_env = RailEnv(rail_map, number_of_agents=1) + for _ in range(200): + _ = rail_env.reset() + + # We do not care about target for the moment + rail_env.agents_target[0] = [-1, -1] + + # Check that trains are always initialized at a consistent position / direction. + # They should always be able to go somewhere. + assert(transitions.get_transitions_from_orientation( + rail_map[rail_env.agents_position[0]], + rail_env.agents_direction[0]) != (0, 0, 0, 0)) + + initial_pos = rail_env.agents_position[0] + + valid_active_actions_done = 0 + pos = initial_pos + while valid_active_actions_done < 6: + # We randomly select an action + action = np.random.randint(4) + + _, _, _, _ = rail_env.step({0: action}) + + prev_pos = pos + pos = rail_env.agents_position[0] + if prev_pos != pos: + valid_active_actions_done += 1 + + # After 6 movements on this railway network, the train should be back to its original + # position. + assert(initial_pos[0] == rail_env.agents_position[0][0]) + + # We check that the train always attains its target after some time + for _ in range(10): + _ = rail_env.reset() + + done = False + while not done: + # We randomly select an action + action = np.random.randint(4) + + _, _, dones, _ = rail_env.step({0: action}) + + done = dones['__all__'] + + + + + +