Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • flatland/flatland
  • stefan_otte/flatland
  • jiaodaxiaozi/flatland
  • sfwatergit/flatland
  • utozx126/flatland
  • ChenKuanSun/flatland
  • ashivani/flatland
  • minhhoa/flatland
  • pranjal_dhole/flatland
  • darthgera123/flatland
  • rivesunder/flatland
  • thomaslecat/flatland
  • joel_joseph/flatland
  • kchour/flatland
  • alex_zharichenko/flatland
  • yoogottamk/flatland
  • troye_fang/flatland
  • elrichgro/flatland
  • jun_jin/flatland
  • nimishsantosh107/flatland
20 results
Show changes
import pytest
@pytest.mark.skip(reason="Only for testing pettingzoo interface and wrappers")
def test_petting_zoo_interface_env():
    """Run the same fixed policy on a wrapped rail env and on its PettingZoo interface.

    Phase one steps a ``SkipNoChoiceCellsWrapper``-wrapped rail environment
    directly for ``total_episodes`` episodes, giving every agent action ``2``.
    Phase two wraps that environment with ``flatland_env.env`` and repeats the
    run through ``agent_iter()`` / ``last()`` / ``step()``.  Every rendered
    frame is collected and, when ``save`` is true, written as one GIF per
    episode into the ``flatland_pettingzoo`` directory.  Finally the two logs
    of issued actions are compared over their common prefix.
    """
    import numpy as np
    import os
    # ``import PIL`` on its own does not load the ``Image`` submodule, so
    # ``PIL.Image`` would only resolve if some other module imported it first.
    import PIL.Image
    import shutil

    from flatland.contrib.interface import flatland_env
    from flatland.contrib.utils import env_generators
    from flatland.envs.observations import TreeObsForRailEnv
    from flatland.envs.predictions import ShortestPathPredictorForRailEnv
    from flatland.contrib.wrappers.flatland_wrappers import SkipNoChoiceCellsWrapper
    from flatland.contrib.wrappers.flatland_wrappers import ShortestPathActionWrapper  # noqa

    # Custom observation builder with predictor
    observation_builder = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv(30))
    seed = 11
    save = True
    np.random.seed(seed)
    experiment_name = "flatland_pettingzoo"
    total_episodes = 2

    if save:
        # Start from an empty output directory. A failure here is reported but
        # deliberately not fatal (best effort, as before).
        try:
            if os.path.isdir(experiment_name):
                shutil.rmtree(experiment_name)
            os.mkdir(experiment_name)
        except OSError as e:
            print("Error: %s - %s." % (e.filename, e.strerror))

    # NOTE(review): this first environment is replaced by the next assignment
    # without being used; kept because its side effects are not visible from
    # here - confirm whether it can be dropped.
    rail_env = env_generators.sparse_env_small(seed, observation_builder)
    rail_env = env_generators.small_v0(seed, observation_builder)
    rail_env.reset(random_seed=seed)

    # For Shortest Path Action Wrapper, change action to 1
    # rail_env = ShortestPathActionWrapper(rail_env)
    rail_env = SkipNoChoiceCellsWrapper(rail_env, accumulate_skipped_rewards=False, discounting=0.0)

    ep_no = 0
    frame_list = []
    all_actions_env = []
    all_actions_pettingzoo_env = []

    # Phase one: drive the wrapped rail environment directly.
    while ep_no < total_episodes:
        action_dict = {}
        # Choose an action for each agent
        for a in range(rail_env.get_num_agents()):
            # Alternative policy: env_generators.get_shortest_path_action(rail_env, a)
            action = 2
            all_actions_env.append(action)
            action_dict[a] = action

        # Do the environment step; only the done flags are needed afterwards.
        _, _, dones, _ = rail_env.step(action_dict)
        frame_list.append(PIL.Image.fromarray(rail_env.render(mode="rgb_array")))

        if dones['__all__']:
            completion = env_generators.perc_completion(rail_env)
            print("Final Agents Completed:", completion)
            ep_no += 1
            if save and frame_list:
                frame_list[0].save(f"{experiment_name}{os.sep}out_{ep_no}.gif", save_all=True,
                                   append_images=frame_list[1:], duration=3, loop=0)
            frame_list = []
            rail_env.reset(random_seed=seed + ep_no)

    # Phase two: the same policy through the PettingZoo interface.  The code
    # between the two sphinx markers is left exactly as it was; presumably it
    # is extracted into the documentation - confirm before editing it.
    # __sphinx_doc_begin__
    env = flatland_env.env(environment=rail_env)
    seed = 11
    env.reset(random_seed=seed)
    step = 0
    ep_no = 0
    frame_list = []
    while ep_no < total_episodes:
        for agent in env.agent_iter():
            obs, reward, done, info = env.last()
            # act = env_generators.get_shortest_path_action(env.environment, get_agent_handle(agent))
            act = 2
            all_actions_pettingzoo_env.append(act)
            env.step(act)
            frame_list.append(PIL.Image.fromarray(env.render(mode='rgb_array')))
            step += 1
        # __sphinx_doc_end__
        completion = env_generators.perc_completion(env)
        print("Final Agents Completed:", completion)
        ep_no += 1
        # An episode that yielded no agent produces no frame; skip the GIF
        # instead of failing on frame_list[0].
        if save and frame_list:
            frame_list[0].save(f"{experiment_name}{os.sep}pettyzoo_out_{ep_no}.gif", save_all=True,
                               append_images=frame_list[1:], duration=3, loop=0)
        frame_list = []
        env.close()
        env.reset(random_seed=seed + ep_no)

    # The two runs may issue a different number of actions; compare the part
    # both of them cover.
    min_len = min(len(all_actions_pettingzoo_env), len(all_actions_env))
    assert all_actions_pettingzoo_env[:min_len] == all_actions_env[:min_len], "actions do not match"
if __name__ == "__main__":
    # Allow running this file directly: hand it to pytest (verbose, output
    # not captured) and propagate pytest's exit status to the shell.
    import sys

    import pytest

    exit_status = pytest.main(["-sv", __file__])
    sys.exit(exit_status)
...@@ -3,24 +3,20 @@ import numpy as np ...@@ -3,24 +3,20 @@ import numpy as np
from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import rail_from_grid_transition_map from flatland.envs.rail_generators import rail_from_grid_transition_map, sparse_rail_generator
from flatland.envs.schedule_generators import random_schedule_generator from flatland.envs.line_generators import sparse_line_generator
from flatland.utils.simple_rail import make_simple_rail2 from flatland.utils.simple_rail import make_simple_rail2
def test_random_seeding(): def ndom_seeding():
# Set fixed malfunction duration for this test # Set fixed malfunction duration for this test
rail, rail_map = make_simple_rail2() rail, rail_map, optionals = make_simple_rail2()
# Move target to unreachable position in order to not interfere with test # Move target to unreachable position in order to not interfere with test
for idx in range(100): for idx in range(100):
env = RailEnv(width=25, env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail, optionals),
height=30, line_generator=sparse_line_generator(seed=12), number_of_agents=10)
rail_generator=rail_from_grid_transition_map(rail), env.reset(True, True, random_seed=1)
schedule_generator=random_schedule_generator(seed=12),
number_of_agents=10
)
env.reset(True, True, False, random_seed=1)
env.agents[0].target = (0, 0) env.agents[0].target = (0, 0)
for step in range(10): for step in range(10):
...@@ -48,29 +44,20 @@ def test_random_seeding(): ...@@ -48,29 +44,20 @@ def test_random_seeding():
def test_seeding_and_observations(): def test_seeding_and_observations():
# Test if two different instances diverge with different observations # Test if two different instances diverge with different observations
rail, rail_map = make_simple_rail2() rail, rail_map, optionals = make_simple_rail2()
optionals['agents_hints']['num_agents'] = 10
# Make two separate envs with different observation builders # Make two separate envs with different observation builders
# Global Observation # Global Observation
env = RailEnv(width=25, env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail, optionals),
height=30, line_generator=sparse_line_generator(seed=12), number_of_agents=10,
rail_generator=rail_from_grid_transition_map(rail), obs_builder_object=GlobalObsForRailEnv())
schedule_generator=random_schedule_generator(seed=12),
number_of_agents=10,
obs_builder_object=GlobalObsForRailEnv()
)
# Tree Observation # Tree Observation
env2 = RailEnv(width=25, env2 = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail, optionals),
height=30, line_generator=sparse_line_generator(seed=12), number_of_agents=10,
rail_generator=rail_from_grid_transition_map(rail), obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
schedule_generator=random_schedule_generator(seed=12),
number_of_agents=10,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
)
env.reset(False, False, False, random_seed=12)
env2.reset(False, False, False, random_seed=12)
env.reset(False, False, random_seed=12)
env2.reset(False, False, random_seed=12)
# Check that both environments produce the same initial start positions # Check that both environments produce the same initial start positions
assert env.agents[0].initial_position == env2.agents[0].initial_position assert env.agents[0].initial_position == env2.agents[0].initial_position
assert env.agents[1].initial_position == env2.agents[1].initial_position assert env.agents[1].initial_position == env2.agents[1].initial_position
...@@ -90,9 +77,7 @@ def test_seeding_and_observations(): ...@@ -90,9 +77,7 @@ def test_seeding_and_observations():
action_dict[a] = action action_dict[a] = action
env.step(action_dict) env.step(action_dict)
env2.step(action_dict) env2.step(action_dict)
# Check that both environments end up in the same position # Check that both environments end up in the same position
assert env.agents[0].position == env2.agents[0].position assert env.agents[0].position == env2.agents[0].position
assert env.agents[1].position == env2.agents[1].position assert env.agents[1].position == env2.agents[1].position
assert env.agents[2].position == env2.agents[2].position assert env.agents[2].position == env2.agents[2].position
...@@ -109,8 +94,8 @@ def test_seeding_and_observations(): ...@@ -109,8 +94,8 @@ def test_seeding_and_observations():
def test_seeding_and_malfunction(): def test_seeding_and_malfunction():
# Test if two different instances diverge with different observations # Test if two different instances diverge with different observations
rail, rail_map = make_simple_rail2() rail, rail_map, optionals = make_simple_rail2()
optionals['agents_hints']['num_agents'] = 10
stochastic_data = {'prop_malfunction': 0.4, stochastic_data = {'prop_malfunction': 0.4,
'malfunction_rate': 2, 'malfunction_rate': 2,
'min_duration': 10, 'min_duration': 10,
...@@ -118,27 +103,17 @@ def test_seeding_and_malfunction(): ...@@ -118,27 +103,17 @@ def test_seeding_and_malfunction():
# Make two separate envs and see if they exhibit the same malfunctions # Make two separate envs and see if they exhibit the same malfunctions
# Global Observation # Global Observation
for tests in range(1, 100): for tests in range(1, 100):
env = RailEnv(width=25, env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail, optionals),
height=30, line_generator=sparse_line_generator(), number_of_agents=10,
rail_generator=rail_from_grid_transition_map(rail), obs_builder_object=GlobalObsForRailEnv())
schedule_generator=random_schedule_generator(),
number_of_agents=10,
obs_builder_object=GlobalObsForRailEnv(),
stochastic_data=stochastic_data, # Malfunction data generator
)
# Tree Observation # Tree Observation
env2 = RailEnv(width=25, env2 = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail, optionals),
height=30, line_generator=sparse_line_generator(), number_of_agents=10,
rail_generator=rail_from_grid_transition_map(rail), obs_builder_object=GlobalObsForRailEnv())
schedule_generator=random_schedule_generator(),
number_of_agents=10, env.reset(True, False, random_seed=tests)
obs_builder_object=GlobalObsForRailEnv(), env2.reset(True, False, random_seed=tests)
stochastic_data=stochastic_data, # Malfunction data generator
)
env.reset(True, False, True, random_seed=tests)
env2.reset(True, False, True, random_seed=tests)
# Check that both environments produce the same initial start positions # Check that both environments produce the same initial start positions
assert env.agents[0].initial_position == env2.agents[0].initial_position assert env.agents[0].initial_position == env2.agents[0].initial_position
...@@ -158,8 +133,8 @@ def test_seeding_and_malfunction(): ...@@ -158,8 +133,8 @@ def test_seeding_and_malfunction():
action = np.random.randint(4) action = np.random.randint(4)
action_dict[a] = action action_dict[a] = action
# print("----------------------") # print("----------------------")
# print(env.agents[a].malfunction_data, env.agents[a].status) # print(env.agents[a].malfunction_handler, env.agents[a].status)
# print(env2.agents[a].malfunction_data, env2.agents[a].status) # print(env2.agents[a].malfunction_handler, env2.agents[a].status)
_, reward1, done1, _ = env.step(action_dict) _, reward1, done1, _ = env.step(action_dict)
_, reward2, done2, _ = env2.step(action_dict) _, reward2, done2, _ = env2.step(action_dict)
...@@ -178,3 +153,66 @@ def test_seeding_and_malfunction(): ...@@ -178,3 +153,66 @@ def test_seeding_and_malfunction():
assert env.agents[7].position == env2.agents[7].position assert env.agents[7].position == env2.agents[7].position
assert env.agents[8].position == env2.agents[8].position assert env.agents[8].position == env2.agents[8].position
assert env.agents[9].position == env2.agents[9].position assert env.agents[9].position == env2.agents[9].position
def test_reproducability_env():
    """
    Check that the generated rail grid is the same whether or not numpy's
    global random generator has been seeded and used before the env reset.
    """
    speed_ratios = {1.: 1.,  # Fast passenger train
                    1. / 2.: 0.,  # Fast freight train
                    1. / 3.: 0.,  # Slow commuter train
                    1. / 4.: 0.}  # Slow freight train

    def build_env():
        # Both environments under test are configured identically.
        rail_generator = sparse_rail_generator(max_num_cities=5,
                                               max_rails_between_cities=3,
                                               seed=10,  # Random seed
                                               grid_mode=True)
        line_generator = sparse_line_generator(speed_ratios)
        return RailEnv(width=25, height=30, rail_generator=rail_generator,
                       line_generator=line_generator, number_of_agents=1)

    reference_env = build_env()
    reference_env.reset(True, True, random_seed=1)

    # Transition values of the 30 x 25 grid expected for the seeds used above.
    reference_grid = [
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 16386, 1025, 4608, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [16386, 17411, 1025, 5633, 17411, 3089, 1025, 1097, 5633, 17411, 1025, 5633, 1025, 1025, 1025, 1025, 5633, 17411, 1025, 1025, 1025, 5633, 17411, 1025, 4608],
        [32800, 32800, 0, 72, 3089, 5633, 1025, 17411, 1097, 2064, 0, 72, 1025, 1025, 1025, 1025, 1097, 3089, 1025, 1025, 1025, 1097, 3089, 1025, 37408],
        [32800, 32800, 0, 0, 0, 72, 1025, 2064, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800],
        [32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800],
        [32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800],
        [32800, 32872, 4608, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16386, 34864],
        [32800, 32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 32800],
        [32800, 32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 32800],
        [32800, 32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 32800],
        [32800, 32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 32800],
        [32800, 32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 32800],
        [32800, 32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 32800],
        [72, 37408, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 32800],
        [0, 49186, 2064, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 72, 37408],
        [0, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800],
        [0, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800],
        [0, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800],
        [0, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800],
        [0, 32872, 1025, 5633, 17411, 1025, 1025, 1025, 5633, 17411, 1025, 1025, 1025, 1025, 1025, 1025, 5633, 17411, 1025, 1025, 1025, 5633, 17411, 1025, 34864],
        [0, 72, 1025, 1097, 3089, 1025, 1025, 1025, 1097, 3089, 1025, 1025, 1025, 1025, 1025, 1025, 1097, 3089, 1025, 1025, 1025, 1097, 3089, 1025, 2064],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    ]
    assert reference_env.rail.grid.tolist() == reference_grid

    # Second run: seed and consume numpy's global generator between building
    # the env and resetting it, then expect exactly the same grid.
    disturbed_env = build_env()
    np.random.seed(1)
    for _ in range(10):
        np.random.randn()
    disturbed_env.reset(True, True, random_seed=1)
    assert disturbed_env.rail.grid.tolist() == reference_grid
...@@ -2,30 +2,28 @@ ...@@ -2,30 +2,28 @@
import numpy as np import numpy as np
from flatland.envs.rail_env import RailEnv from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import complex_rail_generator from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.schedule_generators import speed_initialization_helper, complex_schedule_generator from flatland.envs.line_generators import speed_initialization_helper, sparse_line_generator
def test_speed_initialization_helper(): def test_speed_initialization_helper():
np.random.seed(1) random_generator = np.random.RandomState()
random_generator.seed(10)
speed_ratio_map = {1: 0.3, 2: 0.4, 3: 0.3} speed_ratio_map = {1: 0.3, 2: 0.4, 3: 0.3}
actual_speeds = speed_initialization_helper(10, speed_ratio_map) actual_speeds = speed_initialization_helper(10, speed_ratio_map, np_random=random_generator)
# seed makes speed_initialization_helper deterministic -> check generated speeds. # seed makes speed_initialization_helper deterministic -> check generated speeds.
assert actual_speeds == [2, 3, 1, 2, 1, 1, 1, 2, 2, 2] assert actual_speeds == [3, 1, 2, 3, 2, 1, 1, 3, 1, 1]
def test_rail_env_speed_intializer(): def test_rail_env_speed_intializer():
speed_ratio_map = {1: 0.3, 2: 0.4, 3: 0.1, 5: 0.2} speed_ratio_map = {1: 0.3, 2: 0.4, 3: 0.1, 5: 0.2}
env = RailEnv(width=50, env = RailEnv(width=50, height=50,
height=50, rail_generator=sparse_rail_generator(), line_generator=sparse_line_generator(),
rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999,
seed=1),
schedule_generator=complex_schedule_generator(),
number_of_agents=10) number_of_agents=10)
env.reset() env.reset()
actual_speeds = list(map(lambda agent: agent.speed_data['speed'], env.agents)) actual_speeds = list(map(lambda agent: agent.speed_counter.speed, env.agents))
expected_speed_set = set(speed_ratio_map.keys()) expected_speed_set = set(speed_ratio_map.keys())
......
...@@ -5,10 +5,15 @@ import numpy as np ...@@ -5,10 +5,15 @@ import numpy as np
from attr import attrs, attrib from attr import attrs, attrib
from flatland.core.grid.grid4 import Grid4TransitionsEnum from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs.agent_utils import EnvAgent, RailAgentStatus from flatland.envs.agent_utils import EnvAgent
from flatland.envs.malfunction_generators import MalfunctionParameters, malfunction_from_params
from flatland.envs.rail_env import RailEnvActions, RailEnv from flatland.envs.rail_env import RailEnvActions, RailEnv
from flatland.envs.rail_generators import RailGenerator
from flatland.envs.line_generators import LineGenerator
from flatland.utils.rendertools import RenderTool from flatland.utils.rendertools import RenderTool
from flatland.envs.persistence import RailEnvPersister
from flatland.envs.step_utils.states import TrainState
from flatland.envs.step_utils.speed_counter import SpeedCounter
@attrs @attrs
class Replay(object): class Replay(object):
...@@ -18,7 +23,7 @@ class Replay(object): ...@@ -18,7 +23,7 @@ class Replay(object):
malfunction = attrib(default=0, type=int) malfunction = attrib(default=0, type=int)
set_malfunction = attrib(default=None, type=Optional[int]) set_malfunction = attrib(default=None, type=Optional[int])
reward = attrib(default=None, type=Optional[float]) reward = attrib(default=None, type=Optional[float])
status = attrib(default=None, type=Optional[RailAgentStatus]) state = attrib(default=None, type=Optional[TrainState])
@attrs @attrs
...@@ -38,7 +43,8 @@ def set_penalties_for_replay(env: RailEnv): ...@@ -38,7 +43,8 @@ def set_penalties_for_replay(env: RailEnv):
env.invalid_action_penalty = -29 env.invalid_action_penalty = -29
def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering: bool = False, activate_agents=True): def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering: bool = False, activate_agents=True,
skip_reward_check=False, set_ready_to_depart=False, skip_action_required_check=False):
""" """
Runs the replay configs and checks assertions. Runs the replay configs and checks assertions.
...@@ -83,8 +89,19 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering: ...@@ -83,8 +89,19 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering:
agent.initial_direction = test_config.initial_direction agent.initial_direction = test_config.initial_direction
agent.direction = test_config.initial_direction agent.direction = test_config.initial_direction
agent.target = test_config.target agent.target = test_config.target
agent.speed_data['speed'] = test_config.speed agent.speed_counter = SpeedCounter(speed=test_config.speed)
env.reset(False, False, activate_agents) env.reset(False, False)
if set_ready_to_depart:
# Set all agents to ready to depart
for i_agent in range(len(env.agents)):
env.agents[i_agent].earliest_departure = 0
env.agents[i_agent]._set_state(TrainState.READY_TO_DEPART)
elif activate_agents:
for a_idx in range(len(env.agents)):
env.agents[a_idx].position = env.agents[a_idx].initial_position
env.agents[a_idx]._set_state(TrainState.MOVING)
def _assert(a, actual, expected, msg): def _assert(a, actual, expected, msg):
print("[{}] verifying {} on agent {}: actual={}, expected={}".format(step, msg, a, actual, expected)) print("[{}] verifying {} on agent {}: actual={}, expected={}".format(step, msg, a, actual, expected))
...@@ -98,19 +115,22 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering: ...@@ -98,19 +115,22 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering:
for a, test_config in enumerate(test_configs): for a, test_config in enumerate(test_configs):
agent: EnvAgent = env.agents[a] agent: EnvAgent = env.agents[a]
replay = test_config.replay[step] replay = test_config.replay[step]
# if not agent.position == replay.position:
# import pdb; pdb.set_trace()
_assert(a, agent.position, replay.position, 'position') _assert(a, agent.position, replay.position, 'position')
_assert(a, agent.direction, replay.direction, 'direction') _assert(a, agent.direction, replay.direction, 'direction')
if replay.status is not None: if replay.state is not None:
_assert(a, agent.status, replay.status, 'status') _assert(a, agent.state, replay.state, 'state')
if replay.action is not None: if replay.action is not None:
assert info_dict['action_required'][ if not skip_action_required_check:
a] == True or agent.status == RailAgentStatus.READY_TO_DEPART, "[{}] agent {} expecting action_required={} or agent status READY_TO_DEPART".format( assert info_dict['action_required'][
a] == True or agent.state == TrainState.READY_TO_DEPART, "[{}] agent {} expecting action_required={} or agent status READY_TO_DEPART".format(
step, a, True) step, a, True)
action_dict[a] = replay.action action_dict[a] = replay.action
else: else:
assert info_dict['action_required'][ if not skip_action_required_check:
assert info_dict['action_required'][
a] == False, "[{}] agent {} expecting action_required={}, but found {}".format( a] == False, "[{}] agent {} expecting action_required={}, but found {}".format(
step, a, False, info_dict['action_required'][a]) step, a, False, info_dict['action_required'][a])
...@@ -118,16 +138,34 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering: ...@@ -118,16 +138,34 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering:
# As we force malfunctions on the agents we have to set a positive rate that the env # As we force malfunctions on the agents we have to set a positive rate that the env
# recognizes the agent as potentially malfunctioning # recognizes the agent as potentially malfunctioning
# We also set next malfunction to infinity to avoid interference with our tests # We also set next malfunction to infinity to avoid interference with our tests
agent.malfunction_data['malfunction'] = replay.set_malfunction env.agents[a].malfunction_handler._set_malfunction_down_counter(replay.set_malfunction)
agent.malfunction_data['moving_before_malfunction'] = agent.moving _assert(a, agent.malfunction_handler.malfunction_down_counter, replay.malfunction, 'malfunction')
agent.malfunction_data['fixed'] = False
_assert(a, agent.malfunction_data['malfunction'], replay.malfunction, 'malfunction')
print(step) print(step)
_, rewards_dict, _, info_dict = env.step(action_dict) _, rewards_dict, _, info_dict = env.step(action_dict)
# import pdb; pdb.set_trace()
if rendering: if rendering:
renderer.render_env(show=True, show_observations=True) renderer.render_env(show=True, show_observations=True)
for a, test_config in enumerate(test_configs): for a, test_config in enumerate(test_configs):
replay = test_config.replay[step] replay = test_config.replay[step]
_assert(a, rewards_dict[a], replay.reward, 'reward') if not skip_reward_check:
_assert(a, rewards_dict[a], replay.reward, 'reward')
def create_and_save_env(file_name: str, line_generator: LineGenerator, rail_generator: RailGenerator):
    """
    Build a 30x30 RailEnv with ten agents from the given generators, reset it,
    store it under ``file_name`` via ``RailEnvPersister.save`` and return it.

    The env is created with ``remove_agents_at_target=True`` and a malfunction
    generator built from ``malfunction_rate=1000``, ``min_duration=15`` and
    ``max_duration=50``.
    """
    malfunction_params = MalfunctionParameters(
        malfunction_rate=1000,  # Rate of malfunction occurrence
        min_duration=15,  # Minimal duration of malfunction
        max_duration=50,  # Max duration of malfunction
    )
    malfunction_generator = malfunction_from_params(malfunction_params)

    saved_env = RailEnv(
        width=30,
        height=30,
        rail_generator=rail_generator,
        line_generator=line_generator,
        number_of_agents=10,
        malfunction_generator_and_process_data=malfunction_generator,
        remove_agents_at_target=True,
    )
    saved_env.reset(True, True)

    RailEnvPersister.save(saved_env, file_name)
    return saved_env
[tox] [tox]
envlist = py36, py37, examples, notebooks, flake8, docs, coverage envlist = py37, py38, examples, docs, coverage
[travis] [travis]
python = python =
3.8: py38
3.7: py37 3.7: py37
3.6: py36
[flake8] [flake8]
max-line-length = 120 max-line-length = 120
ignore = E121 E126 E123 E128 E133 E226 E241 E242 E704 W291 W293 W391 W503 W504 W505 ignore = E121 E126 E123 E128 E133 E226 E241 E242 E704 W291 W293 W391 W503 W504 W505
[testenv:flake8] [testenv:flake8]
basepython = python basepython = python3.7
passenv = DISPLAY passenv = DISPLAY
deps = deps =
-r{toxinidir}/requirements_dev.txt -r{toxinidir}/requirements_dev.txt
...@@ -21,16 +21,13 @@ commands = ...@@ -21,16 +21,13 @@ commands =
flake8 flatland tests examples benchmarks flake8 flatland tests examples benchmarks
[testenv:docs] [testenv:docs]
; use python3.6 because of incompatibility under Windows of the pycairo installed through conda for py37 basepython = python3.7
basepython = python3.6
whitelist_externals = make whitelist_externals = make
passenv = passenv =
DISPLAY DISPLAY
HTTP_PROXY HTTP_PROXY
HTTPS_PROXY HTTPS_PROXY
conda_deps = conda_deps =
cairosvg
pycairo
tk tk
graphviz graphviz
conda_channels : conda_channels :
...@@ -44,8 +41,7 @@ commands = ...@@ -44,8 +41,7 @@ commands =
make docs make docs
[testenv:coverage] [testenv:coverage]
; use python3.6 because of incompatibility under Windows of the pycairo installed through conda for py37 basepython = python3.7
basepython = python3.6
whitelist_externals = make whitelist_externals = make
passenv = passenv =
DISPLAY DISPLAY
...@@ -53,8 +49,6 @@ passenv = ...@@ -53,8 +49,6 @@ passenv =
HTTP_PROXY HTTP_PROXY
HTTPS_PROXY HTTPS_PROXY
conda_deps = conda_deps =
cairosvg
pycairo
tk tk
conda_channels : conda_channels :
conda-forge conda-forge
...@@ -67,8 +61,7 @@ commands = ...@@ -67,8 +61,7 @@ commands =
python make_coverage.py python make_coverage.py
[testenv:benchmarks] [testenv:benchmarks]
; use python3.6 because of incompatibility under Windows of the pycairo installed through conda for py37 basepython = python3.7
basepython = python3.6
setenv = setenv =
PYTHONPATH = {toxinidir} PYTHONPATH = {toxinidir}
passenv = passenv =
...@@ -81,14 +74,13 @@ whitelist_externals = sh ...@@ -81,14 +74,13 @@ whitelist_externals = sh
deps = deps =
-r{toxinidir}/requirements_dev.txt -r{toxinidir}/requirements_dev.txt
-r{toxinidir}/requirements_continuous_integration.txt -r{toxinidir}/requirements_continuous_integration.txt
changedir = {toxinidir} changedir = {toxinidir}
commands = commands =
python --version python --version
python {toxinidir}/benchmarks/benchmark_all_examples.py python {toxinidir}/benchmarks/benchmark_all_examples.py
[testenv:profiling] [testenv:profiling]
; use python3.6 because of incompatibility under Windows of the pycairo installed through conda for py37 basepython = python3.7
basepython = python3.6
setenv = setenv =
PYTHONPATH = {toxinidir} PYTHONPATH = {toxinidir}
passenv = passenv =
...@@ -98,8 +90,6 @@ passenv = ...@@ -98,8 +90,6 @@ passenv =
HTTP_PROXY HTTP_PROXY
HTTPS_PROXY HTTPS_PROXY
conda_deps = conda_deps =
cairosvg
pycairo
tk tk
conda_channels : conda_channels :
conda-forge conda-forge
...@@ -107,14 +97,13 @@ conda_channels : ...@@ -107,14 +97,13 @@ conda_channels :
deps = deps =
-r{toxinidir}/requirements_dev.txt -r{toxinidir}/requirements_dev.txt
-r{toxinidir}/requirements_continuous_integration.txt -r{toxinidir}/requirements_continuous_integration.txt
changedir = {toxinidir} changedir = {toxinidir}
commands = commands =
python {toxinidir}/benchmarks/profile_all_examples.py python {toxinidir}/benchmarks/profile_all_examples.py
[testenv:examples] [testenv:examples]
; TODO should examples be run with py36 and py37?? ; TODO should examples be run with py36 and py37??
; use python3.6 because of incompatibility under Windows of the pycairo installed through conda for py37 basepython = python3.7
basepython = python3.6
setenv = setenv =
PYTHONPATH = {toxinidir} PYTHONPATH = {toxinidir}
passenv = passenv =
...@@ -124,8 +113,6 @@ passenv = ...@@ -124,8 +113,6 @@ passenv =
HTTP_PROXY HTTP_PROXY
HTTPS_PROXY HTTPS_PROXY
conda_deps = conda_deps =
cairosvg
pycairo
tk tk
conda_channels : conda_channels :
conda-forge conda-forge
...@@ -139,10 +126,10 @@ commands = ...@@ -139,10 +126,10 @@ commands =
[testenv:notebooks] [testenv:notebooks]
; TODO should examples be run with py36 and py37?? ; TODO should examples be run with py36 and py37??
; use python3.6 because of incompatibility under Windows of the pycairo installed through conda for py37 basepython = python3.7
basepython = python3.6
setenv = setenv =
PYTHONPATH = {toxinidir} PYTHONPATH = {envdir}
;{toxinidir}
passenv = passenv =
DISPLAY DISPLAY
XAUTHORITY XAUTHORITY
...@@ -150,12 +137,12 @@ passenv = ...@@ -150,12 +137,12 @@ passenv =
HTTP_PROXY HTTP_PROXY
HTTPS_PROXY HTTPS_PROXY
whitelist_externals = sh whitelist_externals = sh
bash
pwd
deps = deps =
-r{toxinidir}/requirements_dev.txt -r{toxinidir}/requirements_dev.txt
-r{toxinidir}/requirements_continuous_integration.txt -r{toxinidir}/requirements_continuous_integration.txt
conda_deps = conda_deps =
cairosvg
pycairo
tk tk
conda_channels : conda_channels :
conda-forge conda-forge
...@@ -163,6 +150,8 @@ conda_channels : ...@@ -163,6 +150,8 @@ conda_channels :
; run tests from subfolder to ensure that resources are accessed via resources and not via relative paths ; run tests from subfolder to ensure that resources are accessed via resources and not via relative paths
changedir = {envtmpdir}/6f59bc68108c3895b1828abdd04b9a06 changedir = {envtmpdir}/6f59bc68108c3895b1828abdd04b9a06
commands = commands =
bash -c "pwd"
bash -c "echo $PYTHONPATH"
python -m jupyter nbextension install --py --sys-prefix widgetsnbextension python -m jupyter nbextension install --py --sys-prefix widgetsnbextension
python -m jupyter nbextension enable --py --sys-prefix widgetsnbextension python -m jupyter nbextension enable --py --sys-prefix widgetsnbextension
python -m jupyter nbextension install --py --sys-prefix jpy_canvas python -m jupyter nbextension install --py --sys-prefix jpy_canvas
...@@ -170,8 +159,7 @@ commands = ...@@ -170,8 +159,7 @@ commands =
python {toxinidir}/notebooks/run_all_notebooks.py python {toxinidir}/notebooks/run_all_notebooks.py
[testenv:start_jupyter] [testenv:start_jupyter]
; use python3.6 because of incompatibility under Windows of the pycairo installed through conda for py37 basepython = python3.7
basepython = python3.6
setenv = setenv =
PYTHONPATH = {toxinidir} PYTHONPATH = {toxinidir}
passenv = passenv =
...@@ -185,8 +173,6 @@ deps = ...@@ -185,8 +173,6 @@ deps =
-r{toxinidir}/requirements_dev.txt -r{toxinidir}/requirements_dev.txt
-r{toxinidir}/requirements_continuous_integration.txt -r{toxinidir}/requirements_continuous_integration.txt
conda_deps = conda_deps =
cairosvg
pycairo
tk tk
conda_channels : conda_channels :
conda-forge conda-forge
...@@ -199,7 +185,9 @@ commands = ...@@ -199,7 +185,9 @@ commands =
python -m jupyter nbextension enable --py --sys-prefix jpy_canvas python -m jupyter nbextension enable --py --sys-prefix jpy_canvas
python -m jupyter notebook python -m jupyter notebook
[testenv] [testenv:py37]
platform = linux|linux2|darwin
basepython = python3.7
setenv = setenv =
PYTHONPATH = {toxinidir} PYTHONPATH = {toxinidir}
passenv = passenv =
...@@ -209,8 +197,6 @@ passenv = ...@@ -209,8 +197,6 @@ passenv =
HTTP_PROXY HTTP_PROXY
HTTPS_PROXY HTTPS_PROXY
conda_deps = conda_deps =
cairosvg
pycairo
tk tk
conda_channels : conda_channels :
conda-forge conda-forge
...@@ -223,9 +209,11 @@ commands = ...@@ -223,9 +209,11 @@ commands =
python --version python --version
python -m pytest --basetemp={envtmpdir} {toxinidir} python -m pytest --basetemp={envtmpdir} {toxinidir}
[testenv:py37]
; exclude py37 from Windows because of incompatibility the pycairo installed through conda for py37
[testenv:py38]
platform = linux|linux2|darwin platform = linux|linux2|darwin
basepython = python3.8
setenv = setenv =
PYTHONPATH = {toxinidir} PYTHONPATH = {toxinidir}
passenv = passenv =
...@@ -235,8 +223,6 @@ passenv = ...@@ -235,8 +223,6 @@ passenv =
HTTP_PROXY HTTP_PROXY
HTTPS_PROXY HTTPS_PROXY
conda_deps = conda_deps =
cairosvg
pycairo
tk tk
conda_channels : conda_channels :
conda-forge conda-forge
......