Commit b93f0f1f authored by u214892's avatar u214892
Browse files

#188 bugfix new args of sparse_schedule_generator in tests

parent 370a2a44
Pipeline #2309 failed with stages
in 9 minutes and 53 seconds
......@@ -194,7 +194,7 @@ class Grid_Map_Op:
mirror(liTrans[0]), bAddRemove, remove_deadends=not bDeadend)
def realistic_rail_generator(num_cities=5,
def realistic_rail_generator(max_num_cities=5,
city_size=10,
allowed_rotation_angles=[0, 90],
max_number_of_station_tracks=4,
......@@ -206,7 +206,7 @@ def realistic_rail_generator(num_cities=5,
"""
This is a level generator which generates a realistic rail configurations
:param num_cities: Number of city node
:param max_num_cities: Number of city node
:param city_size: Length of city measure in cells
:param allowed_rotation_angles: Rotate the city (around center)
:param max_number_of_station_tracks: max number of tracks per station
......@@ -226,7 +226,7 @@ def realistic_rail_generator(num_cities=5,
X = int(np.floor(max(1, height - 2 * intern_max_number_of_station_tracks - 1) / intern_city_size))
Y = int(np.floor(max(1, width - 2 * intern_max_number_of_station_tracks - 1) / intern_city_size))
max_num_cities = min(num_cities, X * Y)
max_num_cities = min(max_num_cities, X * Y)
cities_at = np.random.choice(X * Y, max_num_cities, False)
cities_at = np.sort(cities_at)
......@@ -581,7 +581,7 @@ for itrials in range(100):
np.random.seed(int(time.time()))
env = RailEnv(width=40 + np.random.choice(100),
height=40 + np.random.choice(100),
rail_generator=realistic_rail_generator(num_cities=2 + np.random.choice(10),
rail_generator=realistic_rail_generator(max_num_cities=2 + np.random.choice(10),
city_size=10 + np.random.choice(10),
allowed_rotation_angles=[-90, -45, 0, 45, 90],
max_number_of_station_tracks=np.random.choice(4) + 4,
......
......@@ -13,14 +13,10 @@ from flatland.utils.rendertools import RenderTool
def test_sparse_rail_generator():
env = RailEnv(width=50,
height=50,
rail_generator=sparse_rail_generator(max_num_cities=10, # Number of cities in map
num_intersections=10, # Number of interesections in map
num_trainstations=50, # Number of possible start/targets on map
min_node_dist=6, # Minimal distance of nodes
node_radius=3, # Proximity of stations to city center
num_neighb=3, # Number of connections to other cities
seed=5, # Random seed
grid_mode=False # Ordered distribution of nodes
rail_generator=sparse_rail_generator(max_num_cities=10,
max_rails_between_cities=3,
seed=5,
grid_mode=False
),
schedule_generator=sparse_schedule_generator(),
number_of_agents=10,
......@@ -734,14 +730,7 @@ def test_sparse_rail_generator_deterministic():
env = RailEnv(width=25,
height=30,
rail_generator=sparse_rail_generator(max_num_cities=5,
# Number of cities in map (where train stations are)
num_intersections=4,
# Number of intersections (no start / target)
num_trainstations=25, # Number of possible start/targets on map
min_node_dist=6, # Minimal distance of nodes
node_radius=3, # Proximity of stations to city center
num_neighb=3,
# Number of connections to other cities/intersections
max_rails_between_cities=3,
seed=215545, # Random seed
grid_mode=True
),
......@@ -1509,40 +1498,25 @@ def test_rail_env_action_required_info():
1. / 4.: 0.25} # Slow freight train
env_always_action = RailEnv(width=50,
height=50,
rail_generator=sparse_rail_generator(max_num_cities=10, # Number of cities in map
num_intersections=10,
# Number of interesections in map
num_trainstations=50,
# Number of possible start/targets on map
min_node_dist=6, # Minimal distance of nodes
node_radius=3,
# Proximity of stations to city center
num_neighb=3,
# Number of connections to other cities
seed=5, # Random seed
grid_mode=False # Ordered distribution of nodes
),
rail_generator=sparse_rail_generator(
max_num_cities=10,
max_rails_between_cities=3,
seed=5, # Random seed
grid_mode=False # Ordered distribution of nodes
),
schedule_generator=sparse_schedule_generator(speed_ration_map),
number_of_agents=10,
obs_builder_object=GlobalObsForRailEnv())
np.random.seed(0)
env_only_if_action_required = RailEnv(width=50,
height=50,
rail_generator=sparse_rail_generator(max_num_cities=10, # Number of cities in map
num_intersections=10,
# Number of interesections in map
num_trainstations=50,
# Number of possible start/targets on map
min_node_dist=6,
# Minimal distance of nodes
node_radius=3,
# Proximity of stations to city center
num_neighb=3,
# Number of connections to other cities
seed=5, # Random seed
grid_mode=False
# Ordered distribution of nodes
),
rail_generator=sparse_rail_generator(
max_num_cities=10,
max_rails_between_cities=3,
seed=5, # Random seed
grid_mode=False
# Ordered distribution of nodes
),
schedule_generator=sparse_schedule_generator(speed_ration_map),
number_of_agents=10,
obs_builder_object=GlobalObsForRailEnv())
......@@ -1592,18 +1566,10 @@ def test_rail_env_malfunction_speed_info():
}
env = RailEnv(width=50,
height=50,
rail_generator=sparse_rail_generator(max_num_cities=10, # Number of cities in map
num_intersections=10,
# Number of interesections in map
num_trainstations=50,
# Number of possible start/targets on map
min_node_dist=6, # Minimal distance of nodes
node_radius=3,
# Proximity of stations to city center
num_neighb=3,
# Number of connections to other cities
seed=5, # Random seed
grid_mode=False # Ordered distribution of nodes
rail_generator=sparse_rail_generator(max_num_cities=10,
max_rails_between_cities=3,
seed=5,
grid_mode=False
),
schedule_generator=sparse_schedule_generator(),
number_of_agents=10,
......@@ -1640,14 +1606,10 @@ def test_sparse_generator_with_too_man_cities_does_not_break_down():
RailEnv(width=50,
height=50,
rail_generator=sparse_rail_generator(
max_num_cities=100, # Number of cities in map
num_intersections=10, # Number of interesections in map
num_trainstations=50, # Number of possible start/targets on map
min_node_dist=6, # Minimal distance of nodes
node_radius=3, # Proximity of stations to city center
num_neighb=3, # Number of connections to other cities
seed=5, # Random seed
grid_mode=False # Ordered distribution of nodes
max_num_cities=100,
max_rails_between_cities=3,
seed=5,
grid_mode=False
),
schedule_generator=sparse_schedule_generator(),
number_of_agents=10,
......
......@@ -2,7 +2,6 @@ import random
from typing import Dict, List
import numpy as np
from test_utils import Replay, ReplayConfig, run_replay_config, set_penalties_for_replay
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.grid.grid4 import Grid4TransitionsEnum
......@@ -10,6 +9,7 @@ from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import complex_rail_generator, sparse_rail_generator
from flatland.envs.schedule_generators import complex_schedule_generator, sparse_schedule_generator
from test_utils import Replay, ReplayConfig, run_replay_config, set_penalties_for_replay
class SingleAgentNavigationObs(ObservationBuilder):
......@@ -166,15 +166,8 @@ def test_initial_malfunction():
env = RailEnv(width=25,
height=30,
rail_generator=sparse_rail_generator(max_num_cities=5,
# Number of cities in map (where train stations are)
num_intersections=4,
# Number of intersections (no start / target)
num_trainstations=25, # Number of possible start/targets on map
min_node_dist=6, # Minimal distance of nodes
node_radius=3, # Proximity of stations to city center
num_neighb=3,
# Number of connections to other cities/intersections
seed=215545, # Random seed
max_rails_between_cities=3,
seed=215545,
grid_mode=True
),
schedule_generator=sparse_schedule_generator(speed_ration_map),
......@@ -248,16 +241,9 @@ def test_initial_malfunction_stop_moving():
env = RailEnv(width=25,
height=30,
rail_generator=sparse_rail_generator(max_num_cities=5,
# Number of cities in map (where train stations are)
num_intersections=4,
# Number of intersections (no start / target)
num_trainstations=25, # Number of possible start/targets on map
min_node_dist=6, # Minimal distance of nodes
node_radius=3, # Proximity of stations to city center
num_neighb=3,
# Number of connections to other cities/intersections
seed=215545, # Random seed
grid_mode=True,
max_rails_between_cities=3,
seed=215545,
grid_mode=True
),
schedule_generator=sparse_schedule_generator(speed_ration_map),
number_of_agents=1,
......@@ -340,16 +326,9 @@ def test_initial_malfunction_do_nothing():
env = RailEnv(width=25,
height=30,
rail_generator=sparse_rail_generator(max_num_cities=5,
# Number of cities in map (where train stations are)
num_intersections=4,
# Number of intersections (no start / target)
num_trainstations=25, # Number of possible start/targets on map
min_node_dist=6, # Minimal distance of nodes
node_radius=3, # Proximity of stations to city center
num_neighb=3,
# Number of connections to other cities/intersections
seed=215545, # Random seed
grid_mode=True,
max_rails_between_cities=3,
seed=215545,
grid_mode=True
),
schedule_generator=sparse_schedule_generator(speed_ration_map),
number_of_agents=1,
......@@ -431,17 +410,9 @@ def test_initial_nextmalfunction_not_below_zero():
env = RailEnv(width=25,
height=30,
rail_generator=sparse_rail_generator(max_num_cities=5,
# Number of cities in map (where train stations are)
num_intersections=4,
# Number of intersections (no start / target)
num_trainstations=25, # Number of possible start/targets on map
min_node_dist=6, # Minimal distance of nodes
node_radius=3, # Proximity of stations to city center
num_neighb=3,
# Number of connections to other cities/intersections
seed=215545, # Random seed
grid_mode=True,
enhance_intersection=False
max_rails_between_cities=3,
seed=215545,
grid_mode=True
),
schedule_generator=sparse_schedule_generator(speed_ration_map),
number_of_agents=1,
......
......@@ -24,17 +24,9 @@ def test_get_global_observation():
env = RailEnv(width=50,
height=50,
rail_generator=sparse_rail_generator(max_num_cities=25,
# Number of cities in map (where train stations are)
num_intersections=10,
# Number of intersections (no start / target)
num_trainstations=50, # Number of possible start/targets on map
min_node_dist=3, # Minimal distance of nodes
node_radius=4, # Proximity of stations to city center
num_neighb=4,
# Number of connections to other cities/intersections
seed=15, # Random seed
grid_mode=True,
enhance_intersection=False
max_rails_between_cities=4,
seed=15,
grid_mode=True
),
schedule_generator=sparse_schedule_generator(speed_ration_map),
number_of_agents=number_of_agents, stochastic_data=stochastic_data, # Malfunction data generator
......@@ -61,4 +53,3 @@ def test_get_global_observation():
assert obs_agents_state_1 == (number_of_agents - 1)
assert obs_agents_state_2 == number_of_agents
assert obs_agents_state_3 == number_of_agents
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment