Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • flatland/flatland
  • stefan_otte/flatland
  • jiaodaxiaozi/flatland
  • sfwatergit/flatland
  • utozx126/flatland
  • ChenKuanSun/flatland
  • ashivani/flatland
  • minhhoa/flatland
  • pranjal_dhole/flatland
  • darthgera123/flatland
  • rivesunder/flatland
  • thomaslecat/flatland
  • joel_joseph/flatland
  • kchour/flatland
  • alex_zharichenko/flatland
  • yoogottamk/flatland
  • troye_fang/flatland
  • elrichgro/flatland
  • jun_jin/flatland
  • nimishsantosh107/flatland
20 results
Show changes
"""Test speed initialization by a map of speeds and their corresponding ratios."""
import numpy as np
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.line_generators import speed_initialization_helper, sparse_line_generator
def test_speed_initialization_helper():
random_generator = np.random.RandomState()
random_generator.seed(10)
speed_ratio_map = {1: 0.3, 2: 0.4, 3: 0.3}
actual_speeds = speed_initialization_helper(10, speed_ratio_map, np_random=random_generator)
# seed makes speed_initialization_helper deterministic -> check generated speeds.
assert actual_speeds == [3, 1, 2, 3, 2, 1, 1, 3, 1, 1]
def test_rail_env_speed_intializer():
speed_ratio_map = {1: 0.3, 2: 0.4, 3: 0.1, 5: 0.2}
env = RailEnv(width=50, height=50,
rail_generator=sparse_rail_generator(), line_generator=sparse_line_generator(),
number_of_agents=10)
env.reset()
actual_speeds = list(map(lambda agent: agent.speed_counter.speed, env.agents))
expected_speed_set = set(speed_ratio_map.keys())
# check that the number of speeds generated is correct
assert len(actual_speeds) == env.get_num_agents()
# check that only the speeds defined are generated
assert all({(actual_speed in expected_speed_set) for actual_speed in actual_speeds})
"""Test Utils."""
from typing import List, Tuple, Optional
import numpy as np
from attr import attrs, attrib
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.malfunction_generators import MalfunctionParameters, malfunction_from_params
from flatland.envs.rail_env import RailEnvActions, RailEnv
from flatland.envs.rail_generators import RailGenerator
from flatland.envs.line_generators import LineGenerator
from flatland.utils.rendertools import RenderTool
from flatland.envs.persistence import RailEnvPersister
from flatland.envs.step_utils.states import TrainState
from flatland.envs.step_utils.speed_counter import SpeedCounter
@attrs
class Replay(object):
position = attrib(type=Tuple[int, int])
direction = attrib(type=Grid4TransitionsEnum)
action = attrib(type=RailEnvActions)
malfunction = attrib(default=0, type=int)
set_malfunction = attrib(default=None, type=Optional[int])
reward = attrib(default=None, type=Optional[float])
state = attrib(default=None, type=Optional[TrainState])
@attrs
class ReplayConfig(object):
replay = attrib(type=List[Replay])
target = attrib(type=Tuple[int, int])
speed = attrib(type=float)
initial_position = attrib(type=Tuple[int, int])
initial_direction = attrib(type=Grid4TransitionsEnum)
# ensure that env is working correctly with start/stop/invalidaction penalty different from 0
def set_penalties_for_replay(env: RailEnv):
env.step_penalty = -7
env.start_penalty = -13
env.stop_penalty = -19
env.invalid_action_penalty = -29
def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering: bool = False, activate_agents=True,
skip_reward_check=False, set_ready_to_depart=False, skip_action_required_check=False):
"""
Runs the replay configs and checks assertions.
*Initially*
- The `initial_position`, `initial_direction`, `target` and `speed` are taken from the `ReplayConfig` to initialize the agents.
*Before each step*
- `position` is verfified
- `direction` is verified
- `status` is verified (optionally, only if not `None` in `Replay`)
- `set_malfunction` is applied (optionally, only if not `None` in `Replay`)
- `malfunction` is verified
- `action` must only be provided if action_required from previous step (initally all True)
*Step*
- performed with the given `action`
*After each step*
- `reward` is verified after step
Parameters
----------
activate_agents: should the agents directly be activated when the environment is initially setup by `reset()`?
env: the environment; is `reset()` to set the agents' intial position, direction, target and speed
test_configs: the `ReplayConfig`s, one for each agent
rendering: should be rendered during replay?
"""
if rendering:
renderer = RenderTool(env)
renderer.render_env(show=True, frames=False, show_observations=False)
info_dict = {
'action_required': [True for _ in test_configs]
}
for step in range(len(test_configs[0].replay)):
if step == 0:
for a, test_config in enumerate(test_configs):
agent: EnvAgent = env.agents[a]
# set the initial position
agent.initial_position = test_config.initial_position
agent.initial_direction = test_config.initial_direction
agent.direction = test_config.initial_direction
agent.target = test_config.target
agent.speed_counter = SpeedCounter(speed=test_config.speed)
env.reset(False, False)
if set_ready_to_depart:
# Set all agents to ready to depart
for i_agent in range(len(env.agents)):
env.agents[i_agent].earliest_departure = 0
env.agents[i_agent]._set_state(TrainState.READY_TO_DEPART)
elif activate_agents:
for a_idx in range(len(env.agents)):
env.agents[a_idx].position = env.agents[a_idx].initial_position
env.agents[a_idx]._set_state(TrainState.MOVING)
def _assert(a, actual, expected, msg):
print("[{}] verifying {} on agent {}: actual={}, expected={}".format(step, msg, a, actual, expected))
assert (actual == expected) or (
np.allclose(actual, expected)), "[{}] agent {} {}: actual={}, expected={}".format(step, a, msg,
actual,
expected)
action_dict = {}
for a, test_config in enumerate(test_configs):
agent: EnvAgent = env.agents[a]
replay = test_config.replay[step]
# if not agent.position == replay.position:
# import pdb; pdb.set_trace()
_assert(a, agent.position, replay.position, 'position')
_assert(a, agent.direction, replay.direction, 'direction')
if replay.state is not None:
_assert(a, agent.state, replay.state, 'state')
if replay.action is not None:
if not skip_action_required_check:
assert info_dict['action_required'][
a] == True or agent.state == TrainState.READY_TO_DEPART, "[{}] agent {} expecting action_required={} or agent status READY_TO_DEPART".format(
step, a, True)
action_dict[a] = replay.action
else:
if not skip_action_required_check:
assert info_dict['action_required'][
a] == False, "[{}] agent {} expecting action_required={}, but found {}".format(
step, a, False, info_dict['action_required'][a])
if replay.set_malfunction is not None:
# As we force malfunctions on the agents we have to set a positive rate that the env
# recognizes the agent as potentially malfuncitoning
# We also set next malfunction to infitiy to avoid interference with our tests
env.agents[a].malfunction_handler._set_malfunction_down_counter(replay.set_malfunction)
_assert(a, agent.malfunction_handler.malfunction_down_counter, replay.malfunction, 'malfunction')
print(step)
_, rewards_dict, _, info_dict = env.step(action_dict)
# import pdb; pdb.set_trace()
if rendering:
renderer.render_env(show=True, show_observations=True)
for a, test_config in enumerate(test_configs):
replay = test_config.replay[step]
if not skip_reward_check:
_assert(a, rewards_dict[a], replay.reward, 'reward')
def create_and_save_env(file_name: str, line_generator: LineGenerator, rail_generator: RailGenerator):
stochastic_data = MalfunctionParameters(malfunction_rate=1000, # Rate of malfunction occurence
min_duration=15, # Minimal duration of malfunction
max_duration=50 # Max duration of malfunction
)
env = RailEnv(width=30,
height=30,
rail_generator=rail_generator,
line_generator=line_generator,
number_of_agents=10,
malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
remove_agents_at_target=True)
env.reset(True, True)
#env.save(file_name)
RailEnvPersister.save(env, file_name)
return env
[tox]
envlist = py36, py37, examples, notebooks, flake8
envlist = py37, py38, examples, docs, coverage
[travis]
python =
3.8: py38
3.7: py37
3.6: py36
[flake8]
max-line-length = 120
ignore = E121 E126 E123 E128 E133 E226 E241 E242 E704 W291 W293 W391 W503 W504 W505
[testenv:flake8]
basepython = python
basepython = python3.7
passenv = DISPLAY
deps =
-r{toxinidir}/requirements_dev.txt
......@@ -20,16 +21,15 @@ commands =
flake8 flatland tests examples benchmarks
[testenv:docs]
basepython = python
basepython = python3.7
whitelist_externals = make
passenv =
DISPLAY
HTTP_PROXY
HTTPS_PROXY
conda_deps =
cairosvg
pycairo
tk
graphviz
conda_channels :
conda-forge
anaconda
......@@ -41,7 +41,7 @@ commands =
make docs
[testenv:coverage]
basepython = python
basepython = python3.7
whitelist_externals = make
passenv =
DISPLAY
......@@ -49,8 +49,6 @@ passenv =
HTTP_PROXY
HTTPS_PROXY
conda_deps =
cairosvg
pycairo
tk
conda_channels :
conda-forge
......@@ -60,10 +58,10 @@ deps =
-r{toxinidir}/requirements_continuous_integration.txt
changedir = {toxinidir}
commands =
make coverage
python make_coverage.py
[testenv:benchmarks]
basepython = python
basepython = python3.7
setenv =
PYTHONPATH = {toxinidir}
passenv =
......@@ -76,12 +74,13 @@ whitelist_externals = sh
deps =
-r{toxinidir}/requirements_dev.txt
-r{toxinidir}/requirements_continuous_integration.txt
changedir = {toxinidir}
commands =
python --version
python benchmarks/benchmark_all_examples.py
python {toxinidir}/benchmarks/benchmark_all_examples.py
[testenv:profiling]
basepython = python
basepython = python3.7
setenv =
PYTHONPATH = {toxinidir}
passenv =
......@@ -91,8 +90,6 @@ passenv =
HTTP_PROXY
HTTPS_PROXY
conda_deps =
cairosvg
pycairo
tk
conda_channels :
conda-forge
......@@ -100,11 +97,13 @@ conda_channels :
deps =
-r{toxinidir}/requirements_dev.txt
-r{toxinidir}/requirements_continuous_integration.txt
changedir = {toxinidir}
commands =
python benchmarks/profile_all_examples.py
python {toxinidir}/benchmarks/profile_all_examples.py
[testenv:examples]
basepython = python
; TODO should examples be run with py36 and py37??
basepython = python3.7
setenv =
PYTHONPATH = {toxinidir}
passenv =
......@@ -114,8 +113,6 @@ passenv =
HTTP_PROXY
HTTPS_PROXY
conda_deps =
cairosvg
pycairo
tk
conda_channels :
conda-forge
......@@ -128,9 +125,11 @@ commands =
python {toxinidir}/benchmarks/run_all_examples.py
[testenv:notebooks]
basepython = python
; TODO should examples be run with py36 and py37??
basepython = python3.7
setenv =
PYTHONPATH = {toxinidir}
PYTHONPATH = {envdir}
;{toxinidir}
passenv =
DISPLAY
XAUTHORITY
......@@ -138,12 +137,12 @@ passenv =
HTTP_PROXY
HTTPS_PROXY
whitelist_externals = sh
bash
pwd
deps =
-r{toxinidir}/requirements_dev.txt
-r{toxinidir}/requirements_continuous_integration.txt
conda_deps =
cairosvg
pycairo
tk
conda_channels :
conda-forge
......@@ -151,6 +150,8 @@ conda_channels :
; run tests from subfolder to ensure that resources are accessed via resources and not via relative paths
changedir = {envtmpdir}/6f59bc68108c3895b1828abdd04b9a06
commands =
bash -c "pwd"
bash -c "echo $PYTHONPATH"
python -m jupyter nbextension install --py --sys-prefix widgetsnbextension
python -m jupyter nbextension enable --py --sys-prefix widgetsnbextension
python -m jupyter nbextension install --py --sys-prefix jpy_canvas
......@@ -158,7 +159,7 @@ commands =
python {toxinidir}/notebooks/run_all_notebooks.py
[testenv:start_jupyter]
basepython = python
basepython = python3.7
setenv =
PYTHONPATH = {toxinidir}
passenv =
......@@ -172,8 +173,6 @@ deps =
-r{toxinidir}/requirements_dev.txt
-r{toxinidir}/requirements_continuous_integration.txt
conda_deps =
cairosvg
pycairo
tk
conda_channels :
conda-forge
......@@ -186,8 +185,35 @@ commands =
python -m jupyter nbextension enable --py --sys-prefix jpy_canvas
python -m jupyter notebook
[testenv]
whitelist_externals = pip
[testenv:py37]
platform = linux|linux2|darwin
basepython = python3.7
setenv =
PYTHONPATH = {toxinidir}
passenv =
DISPLAY
XAUTHORITY
; HTTP_PROXY+HTTPS_PROXY required behind corporate proxies
HTTP_PROXY
HTTPS_PROXY
conda_deps =
tk
conda_channels :
conda-forge
anaconda
deps =
-r{toxinidir}/requirements_dev.txt
; run tests from subfolder to ensure that resources are accessed via resources and not via relative paths
changedir = {envtmpdir}/fefed3ba12bf1ed81dbcc20fb52706ea
commands =
python --version
python -m pytest --basetemp={envtmpdir} {toxinidir}
[testenv:py38]
platform = linux|linux2|darwin
basepython = python3.8
setenv =
PYTHONPATH = {toxinidir}
passenv =
......@@ -197,8 +223,6 @@ passenv =
HTTP_PROXY
HTTPS_PROXY
conda_deps =
cairosvg
pycairo
tk
conda_channels :
conda-forge
......