Commit ae31a7b8 authored by Erik Nygren's avatar Erik Nygren 🚅
Browse files

updated tests

parent 8fb2fb6f
Pipeline #2710 passed with stages
in 33 minutes and 37 seconds
......@@ -74,8 +74,13 @@ observation_builder = GlobalObsForRailEnv()
# observation_builder = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
# Construct the enviornment with the given observation, generataors, predictors, and stochastic data
env = RailEnv(width=width, height=height, rail_generator=rail_generator, schedule_generator=schedule_generator,
number_of_agents=nr_trains, obs_builder_object=observation_builder, malfunction_generator=malfunction_from_params(stochastic_data),
env = RailEnv(width=width,
height=height,
rail_generator=rail_generator,
schedule_generator=schedule_generator,
number_of_agents=nr_trains,
obs_builder_object=observation_builder,
malfunction_generator=malfunction_from_params(stochastic_data),
remove_agents_at_target=True)
env.reset()
......
......@@ -17,4 +17,4 @@ env_renderer = RenderTool(env)
env_renderer.render_env(show=True, show_predictions=False, show_observations=False)
# uncomment to keep the renderer open
#input("Press Enter to continue...")
# input("Press Enter to continue...")
......@@ -33,4 +33,4 @@ env_renderer = RenderTool(env, gl="PIL")
env_renderer.render_env(show=True)
# uncomment to keep the renderer open
#input("Press Enter to continue...")
# input("Press Enter to continue...")
......@@ -160,6 +160,7 @@ def fix_inner_nodes(grid_map: GridTransitionMap, inner_node_pos: IntVector2D, ra
grid_map.grid[tmp_pos] = transition
return
def align_cell_to_city(city_center, city_orientation, cell):
"""
Alig all cells to face the city center along the city orientation
......@@ -171,4 +172,4 @@ def align_cell_to_city(city_center, city_orientation, cell):
if city_orientation % 2 == 0:
return int(2 * np.clip(cell[0] - city_center[0], 0, 1))
else:
return int(2 * np.clip(city_center[1] - cell[1], 0, 1)) + 1
return int(2 * np.clip(city_center[1] - cell[1], 0, 1)) + 1
"""Malfunction generators for rail systems"""
from typing import Tuple, List, Callable
from typing import Tuple, Callable
import msgpack
......@@ -36,6 +36,7 @@ def malfunction_from_file(filename) -> MalfunctionGenerator:
return generator
def malfunction_from_params(parameters) -> MalfunctionGenerator:
"""
Utility to load malfunction from parameters
......@@ -60,6 +61,7 @@ def malfunction_from_params(parameters) -> MalfunctionGenerator:
return generator
def no_malfunction_generator() -> MalfunctionGenerator:
"""
Utility to load malfunction from parameters
......
......@@ -994,4 +994,3 @@ class RailEnv(Environment):
"""
return agent.malfunction_data['malfunction'] < 1
......@@ -345,7 +345,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11, seed=1) -> R
def get_matching_templates(template):
"""
Returns a list of possible transition maps for a given template
Parameters:
------
template:List[int]
......@@ -751,7 +751,7 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
# Respect padding between cities
padding = 2
city_size = 2 * (city_radius + 1)
max_cities_per_row =int((height - padding) // city_size)
max_cities_per_row = int((height - padding) // city_size)
max_cities_per_col = int((width - padding) // city_size)
# Choose number of cities per row.
......
......@@ -204,7 +204,7 @@ def random_schedule_generator(speed_ratio_map: Optional[Mapping[float, float]] =
if len(valid_positions) < num_agents:
warnings.warn("schedule_generators: len(valid_positions) < num_agents")
return Schedule(agent_positions=[], agent_directions=[],
agent_targets=[], agent_speeds=[], agent_malfunction_rates=None)
agent_targets=[], agent_speeds=[], agent_malfunction_rates=None)
agents_position_idx = [i for i in np.random.choice(len(valid_positions), num_agents, replace=False)]
agents_position = [valid_positions[agents_position_idx[i]] for i in range(num_agents)]
......
......@@ -30,14 +30,14 @@ def get_boto_client():
import boto3
except ImportError as e:
raise Exception(
"boto3 is not installed. Please manually install by : ",
" pip install -U boto3"
)
"boto3 is not installed. Please manually install by : ",
" pip install -U boto3"
)
return boto3.client(
's3',
aws_access_key_id=AWS_ACCESS_KEY_ID,
aws_secret_access_key=AWS_SECRET_ACCESS_KEY
's3',
aws_access_key_id=AWS_ACCESS_KEY_ID,
aws_secret_access_key=AWS_SECRET_ACCESS_KEY
)
......@@ -50,7 +50,7 @@ def is_aws_configured():
def is_grading():
return os.getenv("CROWDAI_IS_GRADING", False) or \
os.getenv("AICROWD_IS_GRADING", False)
os.getenv("AICROWD_IS_GRADING", False)
def upload_random_frame_to_s3(frames_folder):
......@@ -61,7 +61,7 @@ def upload_random_frame_to_s3(frames_folder):
raise Exception("S3_UPLOAD_PATH_TEMPLATE not provided...")
if not S3_BUCKET:
raise Exception("S3_BUCKET not provided...")
image_target_key = S3_UPLOAD_PATH_TEMPLATE.replace(".mp4", ".png").format(str(uuid.uuid4()))
s3.put_object(
ACL="public-read",
......@@ -78,7 +78,7 @@ def upload_to_s3(localpath):
raise Exception("S3_UPLOAD_PATH_TEMPLATE not provided...")
if not S3_BUCKET:
raise Exception("S3_BUCKET not provided...")
image_target_key = S3_UPLOAD_PATH_TEMPLATE.format(str(uuid.uuid4()))
s3.put_object(
ACL="public-read",
......@@ -91,11 +91,11 @@ def upload_to_s3(localpath):
def make_subprocess_call(command, shell=False):
result = subprocess.run(
command.split(),
shell=shell,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE
)
command.split(),
shell=shell,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE
)
stdout = result.stdout.decode('utf-8')
stderr = result.stderr.decode('utf-8')
return result.returncode, stdout, stderr
......@@ -103,7 +103,7 @@ def make_subprocess_call(command, shell=False):
def generate_movie_from_frames(frames_folder):
"""
Expects the frames in the frames_folder folder
Expects the frames in the frames_folder folder
and then use ffmpeg to generate the video
which writes the output to the frames_folder
"""
......@@ -112,9 +112,9 @@ def generate_movie_from_frames(frames_folder):
frames_path = os.path.join(frames_folder, "flatland_frame_%04d.png")
thumb_output_path = os.path.join(frames_folder, "out_thumb.mp4")
return_code, output, output_err = make_subprocess_call(
"ffmpeg -r 7 -start_number 0 -i " +
frames_path +
" -c:v libx264 -vf fps=7 -pix_fmt yuv420p -s 320x320 " +
"ffmpeg -r 7 -start_number 0 -i " +
frames_path +
" -c:v libx264 -vf fps=7 -pix_fmt yuv420p -s 320x320 " +
thumb_output_path
)
if return_code != 0:
......@@ -125,13 +125,12 @@ def generate_movie_from_frames(frames_folder):
frames_path = os.path.join(frames_folder, "flatland_frame_%04d.png")
output_path = os.path.join(frames_folder, "out.mp4")
return_code, output, output_err = make_subprocess_call(
"ffmpeg -r 7 -start_number 0 -i " +
frames_path +
" -c:v libx264 -vf fps=7 -pix_fmt yuv420p -s 600x600 " +
"ffmpeg -r 7 -start_number 0 -i " +
frames_path +
" -c:v libx264 -vf fps=7 -pix_fmt yuv420p -s 600x600 " +
output_path
)
if return_code != 0:
raise Exception(output_err)
return output_path, thumb_output_path
......@@ -11,8 +11,6 @@ import numpy as np
import redis
import flatland
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import rail_from_file
from flatland.envs.schedule_generators import schedule_from_file
......@@ -223,11 +221,11 @@ class FlatlandRemoteClient(object):
time_start = time.time()
local_observation, info = self.env.reset(
regenerate_rail=True,
regenerate_schedule=True,
activate_agents=False,
random_seed=random_seed
)
regenerate_rail=True,
regenerate_schedule=True,
activate_agents=False,
random_seed=random_seed
)
time_diff = time.time() - time_start
self.update_running_mean_stats("internal_env_reset_time", time_diff)
# Use the local observation
......@@ -266,14 +264,14 @@ class FlatlandRemoteClient(object):
######################################################################
# Print Local Stats
######################################################################
print("="*100)
print("="*100)
print("=" * 100)
print("=" * 100)
print("## Client Performance Stats")
print("="*100)
print("=" * 100)
for _key in self.stats:
if _key.endswith("_mean"):
print("\t - {}\t:{}".format(_key, self.stats[_key]))
print("="*100)
print("=" * 100)
if os.getenv("AICROWD_BLOCKING_SUBMIT"):
"""
If the submission is supposed to happen as a blocking submit,
......@@ -288,12 +286,14 @@ class FlatlandRemoteClient(object):
if __name__ == "__main__":
remote_client = FlatlandRemoteClient()
def my_controller(obs, _env):
_action = {}
for _idx, _ in enumerate(_env.agents):
_action[_idx] = np.random.randint(0, 5)
return _action
my_observation_builder = DummyObservationBuilder()
episode = 0
......
......@@ -15,4 +15,3 @@ class FLATLAND_RL:
ENV_SUBMIT_RESPONSE = "FLATLAND_RL.ENV_SUBMIT_RESPONSE"
ERROR = "FLATLAND_RL.ERROR"
\ No newline at end of file
......@@ -8,7 +8,6 @@ import shutil
import time
import traceback
import flatland
import crowdai_api
import msgpack
import msgpack_numpy as m
......@@ -16,9 +15,10 @@ import numpy as np
import redis
import timeout_decorator
import flatland
from flatland.core.env_observation_builder import DummyObservationBuilder
from flatland.envs.rail_env import RailEnv
from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import rail_from_file
from flatland.envs.schedule_generators import schedule_from_file
from flatland.evaluators import aicrowd_helpers
......@@ -353,11 +353,11 @@ class FlatlandRemoteEvaluationService:
self.current_step = 0
_observation, _info = self.env.reset(
regenerate_rail=True,
regenerate_schedule=True,
activate_agents=False,
random_seed=RANDOM_SEED
)
regenerate_rail=True,
regenerate_schedule=True,
activate_agents=False,
random_seed=RANDOM_SEED
)
if self.visualize:
if self.env_renderer:
......@@ -477,14 +477,14 @@ class FlatlandRemoteEvaluationService:
######################################################################
# Print Local Stats
######################################################################
print("="*100)
print("="*100)
print("=" * 100)
print("=" * 100)
print("## Server Performance Stats")
print("="*100)
print("=" * 100)
for _key in self.stats:
if _key.endswith("_mean"):
print("\t - {}\t:{}".format(_key, self.stats[_key]))
print("="*100)
print("=" * 100)
# Register simulation time of the last episode
self.simulation_times.append(time.time() - self.begin_simulation)
......@@ -615,7 +615,7 @@ class FlatlandRemoteEvaluationService:
print("Self.Reward : ", self.reward)
print("Current Simulation : ", self.simulation_count)
if self.env_file_paths and \
self.simulation_count < len(self.env_file_paths):
self.simulation_count < len(self.env_file_paths):
print("Current Env Path : ",
self.env_file_paths[self.simulation_count])
......
......@@ -50,7 +50,6 @@ for image_file in glob.glob(r'./specifications/img/*'):
subprocess.call(['python', '-msphinx', '-M', 'html', '.', '_build'])
# we do not currrently use pydeps, commented out https://gitlab.aicrowd.com/flatland/flatland/issues/149
# subprocess.call(['python', '-mpydeps', '../flatland', '-o', '_build/html/flatland.svg', '--no-config', '--noshow'])
......
......@@ -120,6 +120,7 @@ def test_initial_status():
run_replay_config(env, [test_config], activate_agents=False)
def test_status_done_remove():
"""Test that agent lifecycle works correctly ready-to-depart -> active -> done."""
rail, rail_map = make_simple_rail()
......
......@@ -217,6 +217,7 @@ def test_get_entry_directions():
# nowhere
_assert((0, 0), [False, False, False, False])
def test_rail_env_reset():
file_name = "test_rail_env_reset.pkl"
......
......@@ -75,7 +75,8 @@ def test_malfunction_process():
env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(), number_of_agents=1,
obs_builder_object=SingleAgentNavigationObs(), malfunction_generator=malfunction_from_params(stochastic_data))
obs_builder_object=SingleAgentNavigationObs(),
malfunction_generator=malfunction_from_params(stochastic_data))
# reset to initialize agents_static
obs, info = env.reset(False, False, True, random_seed=10)
......@@ -124,24 +125,25 @@ def test_malfunction_process_statistically():
env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(), number_of_agents=10,
obs_builder_object=SingleAgentNavigationObs(), malfunction_generator=malfunction_from_params(stochastic_data))
obs_builder_object=SingleAgentNavigationObs(),
malfunction_generator=malfunction_from_params(stochastic_data))
# reset to initialize agents_static
env.reset(True, True, False, random_seed=10)
env.agents[0].target = (0, 0)
# Next line only for test generation
#agent_malfunction_list = [[] for i in range(10)]
# agent_malfunction_list = [[] for i in range(10)]
agent_malfunction_list = [[0, 5, 4, 3, 2, 1, 0, 0, 0, 5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1],
[0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 5, 4, 3, 2, 1, 0, 0],
[5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1, 0, 0, 0, 5, 4, 3, 2, 1, 0],
[0, 5, 4, 3, 2, 1, 0, 0, 0, 5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1],
[0, 0, 0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1],
[0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 5, 4, 3, 2, 1, 0, 5, 4, 3, 2],
[5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[5, 4, 3, 2, 1, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0]]
[0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 5, 4, 3, 2, 1, 0, 0],
[5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1, 0, 0, 0, 5, 4, 3, 2, 1, 0],
[0, 5, 4, 3, 2, 1, 0, 0, 0, 5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1],
[0, 0, 0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1],
[0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 5, 4, 3, 2, 1, 0, 5, 4, 3, 2],
[5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[5, 4, 3, 2, 1, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0]]
for step in range(20):
action_dict: Dict[int, RailEnvActions] = {}
......@@ -149,10 +151,10 @@ def test_malfunction_process_statistically():
# We randomly select an action
action_dict[agent_idx] = RailEnvActions(np.random.randint(4))
# For generating tests only:
#agent_malfunction_list[agent_idx].append(env.agents[agent_idx].malfunction_data['malfunction'])
# agent_malfunction_list[agent_idx].append(env.agents[agent_idx].malfunction_data['malfunction'])
assert env.agents[agent_idx].malfunction_data['malfunction'] == agent_malfunction_list[agent_idx][step]
env.step(action_dict)
#print(agent_malfunction_list)
# print(agent_malfunction_list)
def test_malfunction_before_entry():
......@@ -185,7 +187,7 @@ def test_malfunction_before_entry():
assert env.agents[8].malfunction_data['malfunction'] == 10
assert env.agents[9].malfunction_data['malfunction'] == 10
#for a in range(10):
# for a in range(10):
# print("assert env.agents[{}].malfunction_data['malfunction'] == {}".format(a,env.agents[a].malfunction_data['malfunction']))
......@@ -230,7 +232,8 @@ def test_initial_malfunction():
env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(seed=10), number_of_agents=1,
obs_builder_object=SingleAgentNavigationObs(), malfunction_generator=malfunction_from_params(stochastic_data))
obs_builder_object=SingleAgentNavigationObs(),
malfunction_generator=malfunction_from_params(stochastic_data))
# reset to initialize agents_static
env.reset(False, False, True, random_seed=10)
print(env.agents[0].malfunction_data)
......@@ -297,7 +300,8 @@ def test_initial_malfunction_stop_moving():
env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(), number_of_agents=1,
obs_builder_object=SingleAgentNavigationObs(), malfunction_generator=malfunction_from_params(stochastic_data))
obs_builder_object=SingleAgentNavigationObs(),
malfunction_generator=malfunction_from_params(stochastic_data))
env.reset()
print(env.agents[0].initial_position, env.agents[0].direction, env.agents[0].position, env.agents[0].status)
......
......@@ -49,6 +49,7 @@ def test_render_env(save_new_images=False):
oRT.render_env()
checkFrozenImage(oRT, "basic-env-PIL.npz", resave=save_new_images)
def main():
if len(sys.argv) == 2 and sys.argv[1] == "save":
test_render_env(save_new_images=True)
......
......@@ -29,7 +29,8 @@ def test_get_global_observation():
grid_mode=False
),
schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=number_of_agents,
obs_builder_object=GlobalObsForRailEnv(), malfunction_generator=malfunction_from_params(stochastic_data))
obs_builder_object=GlobalObsForRailEnv(),
malfunction_generator=malfunction_from_params(stochastic_data))
env.reset()
obs, all_rewards, done, _ = env.step({i: RailEnvActions.MOVE_FORWARD for i in range(number_of_agents)})
......
import random
from typing import Dict, List
import numpy as np
from test_utils import Replay, ReplayConfig, run_replay_config, set_penalties_for_replay
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.malfunction_generators import malfunction_from_params, malfunction_from_file
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import rail_from_grid_transition_map
......@@ -40,6 +34,7 @@ def test_malfanction_from_params():
assert env.min_number_of_steps_broken == 2
assert env.max_number_of_steps_broken == 5
def test_malfanction_to_and_from_file():
"""
Test loading malfunction from
......@@ -65,11 +60,11 @@ def test_malfanction_to_and_from_file():
env.save("./malfunction_saving_loading_tests.pkl")
env2 = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(seed=10),
number_of_agents=1,
malfunction_generator=malfunction_from_file("./malfunction_saving_loading_tests.pkl"))
height=30,
rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(seed=10),
number_of_agents=1,
malfunction_generator=malfunction_from_file("./malfunction_saving_loading_tests.pkl"))
env2.reset()
......
......@@ -109,12 +109,14 @@ def test_seeding_and_malfunction():
for tests in range(1, 100):
env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(), number_of_agents=10,
obs_builder_object=GlobalObsForRailEnv(), malfunction_generator=malfunction_from_params(stochastic_data))
obs_builder_object=GlobalObsForRailEnv(),
malfunction_generator=malfunction_from_params(stochastic_data))
# Tree Observation
env2 = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(), number_of_agents=10,
obs_builder_object=GlobalObsForRailEnv(), malfunction_generator=malfunction_from_params(stochastic_data))
obs_builder_object=GlobalObsForRailEnv(),
malfunction_generator=malfunction_from_params(stochastic_data))
env.reset(True, False, True, random_seed=tests)
env2.reset(True, False, True, random_seed=tests)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment