From b93f0f1f9489d2888a7c1ed66153dc48f0233a25 Mon Sep 17 00:00:00 2001 From: u214892 <u214892@sbb.ch> Date: Thu, 3 Oct 2019 11:42:44 +0200 Subject: [PATCH] #188 bugfix new args of sparse_schedule_generator in tests --- .../Simple_Realistic_Railway_Generator.py | 8 +- ...est_flatland_envs_sparse_rail_generator.py | 90 ++++++------------- tests/test_flatland_malfunction.py | 53 +++-------- tests/test_global_observation.py | 15 +--- 4 files changed, 45 insertions(+), 121 deletions(-) diff --git a/examples/Simple_Realistic_Railway_Generator.py b/examples/Simple_Realistic_Railway_Generator.py index 41506ba9..b5727273 100644 --- a/examples/Simple_Realistic_Railway_Generator.py +++ b/examples/Simple_Realistic_Railway_Generator.py @@ -194,7 +194,7 @@ class Grid_Map_Op: mirror(liTrans[0]), bAddRemove, remove_deadends=not bDeadend) -def realistic_rail_generator(num_cities=5, +def realistic_rail_generator(max_num_cities=5, city_size=10, allowed_rotation_angles=[0, 90], max_number_of_station_tracks=4, @@ -206,7 +206,7 @@ def realistic_rail_generator(num_cities=5, """ This is a level generator which generates a realistic rail configurations - :param num_cities: Number of city node + :param max_num_cities: Number of city node :param city_size: Length of city measure in cells :param allowed_rotation_angles: Rotate the city (around center) :param max_number_of_station_tracks: max number of tracks per station @@ -226,7 +226,7 @@ def realistic_rail_generator(num_cities=5, X = int(np.floor(max(1, height - 2 * intern_max_number_of_station_tracks - 1) / intern_city_size)) Y = int(np.floor(max(1, width - 2 * intern_max_number_of_station_tracks - 1) / intern_city_size)) - max_num_cities = min(num_cities, X * Y) + max_num_cities = min(max_num_cities, X * Y) cities_at = np.random.choice(X * Y, max_num_cities, False) cities_at = np.sort(cities_at) @@ -581,7 +581,7 @@ for itrials in range(100): np.random.seed(int(time.time())) env = RailEnv(width=40 + np.random.choice(100), height=40 + np.random.choice(100), - rail_generator=realistic_rail_generator(num_cities=2 + np.random.choice(10), + rail_generator=realistic_rail_generator(max_num_cities=2 + np.random.choice(10), city_size=10 + np.random.choice(10), allowed_rotation_angles=[-90, -45, 0, 45, 90], max_number_of_station_tracks=np.random.choice(4) + 4, diff --git a/tests/test_flatland_envs_sparse_rail_generator.py b/tests/test_flatland_envs_sparse_rail_generator.py index b94de82c..416efea5 100644 --- a/tests/test_flatland_envs_sparse_rail_generator.py +++ b/tests/test_flatland_envs_sparse_rail_generator.py @@ -13,14 +13,10 @@ from flatland.utils.rendertools import RenderTool def test_sparse_rail_generator(): env = RailEnv(width=50, height=50, - rail_generator=sparse_rail_generator(max_num_cities=10, # Number of cities in map - num_intersections=10, # Number of interesections in map - num_trainstations=50, # Number of possible start/targets on map - min_node_dist=6, # Minimal distance of nodes - node_radius=3, # Proximity of stations to city center - num_neighb=3, # Number of connections to other cities - seed=5, # Random seed - grid_mode=False # Ordered distribution of nodes + rail_generator=sparse_rail_generator(max_num_cities=10, + max_rails_between_cities=3, + seed=5, + grid_mode=False ), schedule_generator=sparse_schedule_generator(), number_of_agents=10, @@ -734,14 +730,7 @@ def test_sparse_rail_generator_deterministic(): env = RailEnv(width=25, height=30, rail_generator=sparse_rail_generator(max_num_cities=5, - # Number of cities in map (where train stations are) - num_intersections=4, - # Number of intersections (no start / target) - num_trainstations=25, # Number of possible start/targets on map - min_node_dist=6, # Minimal distance of nodes - node_radius=3, # Proximity of stations to city center - num_neighb=3, - # Number of connections to other cities/intersections + max_rails_between_cities=3, seed=215545, # Random seed grid_mode=True ), @@ -1509,40 +1498,25 @@ def test_rail_env_action_required_info(): 1. / 4.: 0.25} # Slow freight train env_always_action = RailEnv(width=50, height=50, - rail_generator=sparse_rail_generator(max_num_cities=10, # Number of cities in map - num_intersections=10, - # Number of interesections in map - num_trainstations=50, - # Number of possible start/targets on map - min_node_dist=6, # Minimal distance of nodes - node_radius=3, - # Proximity of stations to city center - num_neighb=3, - # Number of connections to other cities - seed=5, # Random seed - grid_mode=False # Ordered distribution of nodes - ), + rail_generator=sparse_rail_generator( + max_num_cities=10, + max_rails_between_cities=3, + seed=5, # Random seed + grid_mode=False # Ordered distribution of nodes + ), schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=10, obs_builder_object=GlobalObsForRailEnv()) np.random.seed(0) env_only_if_action_required = RailEnv(width=50, height=50, - rail_generator=sparse_rail_generator(max_num_cities=10, # Number of cities in map - num_intersections=10, - # Number of interesections in map - num_trainstations=50, - # Number of possible start/targets on map - min_node_dist=6, - # Minimal distance of nodes - node_radius=3, - # Proximity of stations to city center - num_neighb=3, - # Number of connections to other cities - seed=5, # Random seed - grid_mode=False - # Ordered distribution of nodes - ), + rail_generator=sparse_rail_generator( + max_num_cities=10, + max_rails_between_cities=3, + seed=5, # Random seed + grid_mode=False + # Ordered distribution of nodes + ), schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=10, obs_builder_object=GlobalObsForRailEnv()) @@ -1592,18 +1566,10 @@ def test_rail_env_malfunction_speed_info(): } env = RailEnv(width=50, height=50, - rail_generator=sparse_rail_generator(max_num_cities=10, # Number of cities in map - num_intersections=10, - # Number of interesections in map - num_trainstations=50, - # Number of possible start/targets on map - min_node_dist=6, # Minimal distance of nodes - node_radius=3, - # Proximity of stations to city center - num_neighb=3, - # Number of connections to other cities - seed=5, # Random seed - grid_mode=False # Ordered distribution of nodes + rail_generator=sparse_rail_generator(max_num_cities=10, + max_rails_between_cities=3, + seed=5, + grid_mode=False ), schedule_generator=sparse_schedule_generator(), number_of_agents=10, @@ -1640,14 +1606,10 @@ def test_sparse_generator_with_too_man_cities_does_not_break_down(): RailEnv(width=50, height=50, rail_generator=sparse_rail_generator( - max_num_cities=100, # Number of cities in map - num_intersections=10, # Number of interesections in map - num_trainstations=50, # Number of possible start/targets on map - min_node_dist=6, # Minimal distance of nodes - node_radius=3, # Proximity of stations to city center - num_neighb=3, # Number of connections to other cities - seed=5, # Random seed - grid_mode=False # Ordered distribution of nodes + max_num_cities=100, + max_rails_between_cities=3, + seed=5, + grid_mode=False ), schedule_generator=sparse_schedule_generator(), number_of_agents=10, diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index 16f993a8..905cb721 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -2,7 +2,6 @@ import random from typing import Dict, List import numpy as np -from test_utils import Replay, ReplayConfig, run_replay_config, set_penalties_for_replay from flatland.core.env_observation_builder import ObservationBuilder from flatland.core.grid.grid4 import Grid4TransitionsEnum @@ -10,6 +9,7 @@ from flatland.core.grid.grid4_utils import get_new_position from flatland.envs.rail_env import RailEnv, RailEnvActions from flatland.envs.rail_generators import complex_rail_generator, sparse_rail_generator from flatland.envs.schedule_generators import complex_schedule_generator, sparse_schedule_generator +from test_utils import Replay, ReplayConfig, run_replay_config, set_penalties_for_replay class SingleAgentNavigationObs(ObservationBuilder): @@ -166,15 +166,8 @@ def test_initial_malfunction(): env = RailEnv(width=25, height=30, rail_generator=sparse_rail_generator(max_num_cities=5, - # Number of cities in map (where train stations are) - num_intersections=4, - # Number of intersections (no start / target) - num_trainstations=25, # Number of possible start/targets on map - min_node_dist=6, # Minimal distance of nodes - node_radius=3, # Proximity of stations to city center - num_neighb=3, - # Number of connections to other cities/intersections - seed=215545, # Random seed + max_rails_between_cities=3, + seed=215545, grid_mode=True ), schedule_generator=sparse_schedule_generator(speed_ration_map), @@ -248,16 +241,9 @@ def test_initial_malfunction_stop_moving(): env = RailEnv(width=25, height=30, rail_generator=sparse_rail_generator(max_num_cities=5, - # Number of cities in map (where train stations are) - num_intersections=4, - # Number of intersections (no start / target) - num_trainstations=25, # Number of possible start/targets on map - min_node_dist=6, # Minimal distance of nodes - node_radius=3, # Proximity of stations to city center - num_neighb=3, - # Number of connections to other cities/intersections - seed=215545, # Random seed - grid_mode=True, + max_rails_between_cities=3, + seed=215545, + grid_mode=True ), schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=1, @@ -340,16 +326,9 @@ def test_initial_malfunction_do_nothing(): env = RailEnv(width=25, height=30, rail_generator=sparse_rail_generator(max_num_cities=5, - # Number of cities in map (where train stations are) - num_intersections=4, - # Number of intersections (no start / target) - num_trainstations=25, # Number of possible start/targets on map - min_node_dist=6, # Minimal distance of nodes - node_radius=3, # Proximity of stations to city center - num_neighb=3, - # Number of connections to other cities/intersections - seed=215545, # Random seed - grid_mode=True, + max_rails_between_cities=3, + seed=215545, + grid_mode=True ), schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=1, @@ -431,17 +410,9 @@ def test_initial_nextmalfunction_not_below_zero(): env = RailEnv(width=25, height=30, rail_generator=sparse_rail_generator(max_num_cities=5, - # Number of cities in map (where train stations are) - num_intersections=4, - # Number of intersections (no start / target) - num_trainstations=25, # Number of possible start/targets on map - min_node_dist=6, # Minimal distance of nodes - node_radius=3, # Proximity of stations to city center - num_neighb=3, - # Number of connections to other cities/intersections - seed=215545, # Random seed - grid_mode=True, - enhance_intersection=False + max_rails_between_cities=3, + seed=215545, + grid_mode=True ), schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=1, diff --git a/tests/test_global_observation.py b/tests/test_global_observation.py index 7f8f62c0..fe3e9ec1 100644 --- a/tests/test_global_observation.py +++ b/tests/test_global_observation.py @@ -24,17 +24,9 @@ def test_get_global_observation(): env = RailEnv(width=50, height=50, rail_generator=sparse_rail_generator(max_num_cities=25, - # Number of cities in map (where train stations are) - num_intersections=10, - # Number of intersections (no start / target) - num_trainstations=50, # Number of possible start/targets on map - min_node_dist=3, # Minimal distance of nodes - node_radius=4, # Proximity of stations to city center - num_neighb=4, - # Number of connections to other cities/intersections - seed=15, # Random seed - grid_mode=True, - enhance_intersection=False + max_rails_between_cities=4, + seed=15, + grid_mode=True ), schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=number_of_agents, stochastic_data=stochastic_data, # Malfunction data generator @@ -61,4 +53,3 @@ def test_get_global_observation(): assert obs_agents_state_1 == (number_of_agents - 1) assert obs_agents_state_2 == number_of_agents assert obs_agents_state_3 == number_of_agents - -- GitLab