Skip to content
Snippets Groups Projects
Commit 94293fa8 authored by gmollard's avatar gmollard
Browse files

Simple tests for rail_env

parent 62395962
No related branches found
No related tags found
No related merge requests found
......@@ -71,6 +71,9 @@ target/
# Jupyter Notebook
.ipynb_checkpoints
# PyCharm
.idea/
# pyenv
.python-version
......
......@@ -108,7 +108,7 @@ class RailEnv:
0: do nothing
1: turn left and move to the next cell
2: move to the next cell in front of the agent
3: turn righ tand move to the next cell
3: turn right and move to the next cell
Moving forward in a dead-end cell makes the agent turn 180 degrees and step
to the cell it came from.
......@@ -276,6 +276,7 @@ class RailEnv:
self.rail[pos[0]][pos[1]],
reverse_direction,
reverse_direction)
if valid_transition:
direction = reverse_direction
movement = direction
......@@ -285,7 +286,11 @@ class RailEnv:
# Is it a legal move? 1) transition allows the movement in the
# cell, 2) the new cell is not empty (case 0), 3) the cell is
# free, i.e., no agent is currently in that cell
if self.rail[new_position[0]][new_position[1]] > 0:
if new_position[1] >= self.width or new_position[0] >= self.height or\
new_position[0] < 0 or new_position[1] < 0:
new_cell_isValid = False
elif self.rail[new_position[0]][new_position[1]] > 0:
new_cell_isValid = True
else:
new_cell_isValid = False
......
......@@ -440,13 +440,13 @@ class RailEnvTransitions(GridTransitions):
"""
transition_list = [int('0000000000000000', 2), # empty cell - Case 0
int('1000000000100000', 2), # Case 1 - straight
int('1001001000000000', 2), # Case 2 - simple switch
int('1000010000100001', 2), # Case 3 - diamond drossing
int('1001011000100001', 2), # Case 4 - single slip switch
int('1100110000110011', 2), # Case 5 - double slip switch
int('0101001000000010', 2), # Case 6 - symmetrical switch
int('0010000000000000', 2)] # Case 7 - dead end
int('1000000000100000', 2), # Case 1 - straight
int('1001001000100000', 2), # Case 2 - simple switch
int('1000010000100001', 2), # Case 3 - diamond drossing
int('1001011000100001', 2), # Case 4 - single slip switch
int('1100110000110011', 2), # Case 5 - double slip switch
int('0101001000000010', 2), # Case 6 - symmetrical switch
int('0010000000000000', 2)] # Case 7 - dead end
def __init__(self):
super(RailEnvTransitions, self).__init__(
......
......@@ -11,3 +11,5 @@ twine==1.12.1
pytest==3.8.2
pytest-runner==4.2
sphinx-rtd-theme==0.4.3
numpy==1.16.2
......@@ -4,15 +4,17 @@
from flatland.core.env import RailEnv
from flatland.core.transitions import GridTransitions
import numpy as np
import random
"""Tests for `flatland` package."""
def test_rail_environment():
def test_rail_environment_single_agent():
cells = [int('0000000000000000', 2), # empty cell - Case 0
int('1000000000100000', 2), # Case 1 - straight
int('1001001000000000', 2), # Case 2 - simple switch
int('1001001000100000', 2), # Case 2 - simple switch
int('1000010000100001', 2), # Case 3 - diamond drossing
int('1001011000100001', 2), # Case 4 - single slip switch
int('1100110000110011', 2), # Case 5 - double slip switch
......@@ -29,8 +31,9 @@ def test_rail_environment():
vertical_line = cells[1]
south_symmetrical_switch = cells[6]
north_symmetrical_switch = transitions.rotate_transition(south_symmetrical_switch, 180)
south_east_turn = int('0100000000100000', 2) # Simple turn not in the base transitions ?
south_east_turn = int('0100000000000010', 2) # Simple turn not in the base transitions ?
south_west_turn = transitions.rotate_transition(south_east_turn, 90)
# print(bytes(south_west_turn))
north_east_turn = transitions.rotate_transition(south_east_turn, 270)
north_west_turn = transitions.rotate_transition(south_east_turn, 180)
......@@ -40,12 +43,52 @@ def test_rail_environment():
dtype=np.uint16)
rail_env = RailEnv(rail_map, number_of_agents=1)
for _ in range(200):
_ = rail_env.reset()
# We do not care about target for the moment
rail_env.agents_target[0] = [-1, -1]
# Check that trains are always initialized at a consistent position / direction.
# They should always be able to go somewhere.
for _ in range(1000):
obs = rail_env.reset()
# Check that trains are always initialized at a consistent position / direction.
# They should always be able to go somewhere.
assert(transitions.get_transitions_from_orientation(
rail_map[rail_env.agents_position[0]],
rail_env.agents_direction[0]) != (0, 0, 0, 0))
initial_pos = rail_env.agents_position[0]
valid_active_actions_done = 0
pos = initial_pos
while valid_active_actions_done < 6:
# We randomly select an action
action = np.random.randint(4)
_, _, _, _ = rail_env.step({0: action})
prev_pos = pos
pos = rail_env.agents_position[0]
if prev_pos != pos:
valid_active_actions_done += 1
# After 6 movements on this railway network, the train should be back to its original
# position.
assert(initial_pos[0] == rail_env.agents_position[0][0])
# We check that the train always attains its target after some time
for _ in range(200):
_ = rail_env.reset()
done = False
while not done:
# We randomly select an action
action = np.random.randint(4)
_, _, dones, _ = rail_env.step({0: action})
done = dones['__all__']
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment