diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py index 907b4a25edc0d44ff83d7a063795b55e020362b3..ec579c1dbd080dc53504421e6a58673e205f6725 100644 --- a/flatland/envs/generators.py +++ b/flatland/envs/generators.py @@ -43,6 +43,7 @@ def complex_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dist= """ def generator(width, height, num_agents, num_resets=0): + if num_agents > nr_start_goal: num_agents = nr_start_goal print("complex_rail_generator: num_agents > nr_start_goal, changing num_agents") @@ -108,7 +109,7 @@ def complex_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dist= break if not all_ok: - # we can might as well give up at this point + # we might as well give up at this point break new_path = connect_rail(rail_trans, rail_array, start, goal) @@ -231,7 +232,7 @@ def rail_from_file(filename): return generator -def rail_from_GridTransitionMap_generator(rail_map): +def rail_from_grid_transition_map(rail_map): """ Utility to convert a rail given by a GridTransitionMap map with the correct 16-bit transitions specifications. diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 34762c0bb1d4de1cb82ee82e139113977b7e7a7d..996301a84f06f185ba9ed605ea1145f404c8b16e 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -93,7 +93,7 @@ class RailEnv(Environment): starting positions, targets, and initial orientations for agent handle. Implemented functions are: random_rail_generator : generate a random rail of given size - rail_from_GridTransitionMap_generator(rail_map) : generate a rail from + rail_from_grid_transition_map(rail_map) : generate a rail from a GridTransitionMap object rail_from_manual_sp ecifications_generator(rail_spec) : generate a rail from a rail specifications array diff --git a/tests/test_distance_map.py b/tests/test_distance_map.py index 742e841c1a699f849a19c1f27f1b084b440a155a..12e0c092a37a475ab6e7dde21c665778e06f5e59 100644 --- a/tests/test_distance_map.py +++ b/tests/test_distance_map.py @@ -2,7 +2,7 @@ import numpy as np from flatland.core.grid.grid4 import Grid4Transitions from flatland.core.transition_map import GridTransitionMap -from flatland.envs.generators import rail_from_GridTransitionMap_generator +from flatland.envs.generators import rail_from_grid_transition_map from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv @@ -33,7 +33,7 @@ def test_walker(): rail.grid = rail_map env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], - rail_generator=rail_from_GridTransitionMap_generator(rail), + rail_generator=rail_from_grid_transition_map(rail), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv(max_depth=10)), diff --git a/tests/test_file_load.py b/tests/test_file_load.py deleted file mode 100644 index 57fa45cb29dab07b84f43c97a45043c9dfa39979..0000000000000000000000000000000000000000 --- a/tests/test_file_load.py +++ /dev/null @@ -1,39 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -import numpy as np - -from flatland.envs.generators import rail_from_GridTransitionMap_generator, rail_from_file -from flatland.envs.observations import TreeObsForRailEnv -from flatland.envs.predictions import ShortestPathPredictorForRailEnv -from flatland.envs.rail_env import RailEnv -from tests.simple_rail import make_simple_rail - - -def test_load_pkl(): - file_name = "test_pkl.pkl" - rail, rail_map = make_simple_rail() - env = RailEnv(width=rail_map.shape[1], - height=rail_map.shape[0], - rail_generator=rail_from_GridTransitionMap_generator(rail), - number_of_agents=3, - obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), - ) - env.save(file_name) - # initialize agents_static - rails_initial = env.rail.grid - agents_initial = env.agents - - env = RailEnv(width=1, - height=1, - rail_generator=rail_from_file(file_name), - number_of_agents=1, - obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), - ) - rails_loaded = env.rail.grid - agents_loaded = env.agents - - assert np.all(np.array_equal(rails_initial, rails_loaded)) - assert agents_initial == agents_loaded - - return diff --git a/tests/test_flatland_envs_observations.py b/tests/test_flatland_envs_observations.py index d400dc226a71e1a9c185012fccae3c852bcd42aa..c2252619957d42a9e60831f23522b5018ee60e8b 100644 --- a/tests/test_flatland_envs_observations.py +++ b/tests/test_flatland_envs_observations.py @@ -5,7 +5,7 @@ import numpy as np from flatland.core.grid.grid4 import Grid4TransitionsEnum from flatland.envs.agent_utils import EnvAgent -from flatland.envs.generators import rail_from_GridTransitionMap_generator +from flatland.envs.generators import rail_from_grid_transition_map from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv, RailEnvActions @@ -20,7 +20,7 @@ def test_global_obs(): env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], - rail_generator=rail_from_GridTransitionMap_generator(rail), + rail_generator=rail_from_grid_transition_map(rail), number_of_agents=1, obs_builder_object=GlobalObsForRailEnv()) @@ -89,7 +89,7 @@ def test_reward_function_conflict(rendering=False): rail, rail_map = make_simple_rail() env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], - rail_generator=rail_from_GridTransitionMap_generator(rail), + rail_generator=rail_from_grid_transition_map(rail), number_of_agents=2, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), ) @@ -167,7 +167,7 @@ def test_reward_function_waiting(rendering=False): rail, rail_map = make_simple_rail() env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], - rail_generator=rail_from_GridTransitionMap_generator(rail), + rail_generator=rail_from_grid_transition_map(rail), number_of_agents=2, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), ) diff --git a/tests/test_flatland_envs_predictions.py b/tests/test_flatland_envs_predictions.py index c90f91a041b16cee2dc55a58562d34a0b9100560..eec939e23e6fe4235f3e2831040e87181b9bc778 100644 --- a/tests/test_flatland_envs_predictions.py +++ b/tests/test_flatland_envs_predictions.py @@ -5,7 +5,7 @@ import pprint import numpy as np from flatland.core.grid.grid4 import Grid4TransitionsEnum -from flatland.envs.generators import rail_from_GridTransitionMap_generator +from flatland.envs.generators import rail_from_grid_transition_map from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv @@ -20,7 +20,7 @@ def test_dummy_predictor(rendering=False): env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], - rail_generator=rail_from_GridTransitionMap_generator(rail), + rail_generator=rail_from_grid_transition_map(rail), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv(max_depth=10)), ) @@ -110,7 +110,7 @@ def test_shortest_path_predictor(rendering=False): rail, rail_map = make_simple_rail() env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], - rail_generator=rail_from_GridTransitionMap_generator(rail), + rail_generator=rail_from_grid_transition_map(rail), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), ) @@ -229,7 +229,7 @@ def test_shortest_path_predictor_conflicts(rendering=False): rail, rail_map = make_simple_rail() env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], - rail_generator=rail_from_GridTransitionMap_generator(rail), + rail_generator=rail_from_grid_transition_map(rail), number_of_agents=2, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), ) diff --git a/tests/test_flatland_envs_rail_env.py b/tests/test_flatland_envs_rail_env.py index 3a50c482176cc1352a3587edcb92704525c58c08..71dc87ceddde986be763491d28dd2b70673632f4 100644 --- a/tests/test_flatland_envs_rail_env.py +++ b/tests/test_flatland_envs_rail_env.py @@ -8,7 +8,7 @@ from flatland.core.transition_map import GridTransitionMap from flatland.envs.agent_utils import EnvAgent from flatland.envs.agent_utils import EnvAgentStatic from flatland.envs.generators import complex_rail_generator -from flatland.envs.generators import rail_from_GridTransitionMap_generator +from flatland.envs.generators import rail_from_grid_transition_map from flatland.envs.observations import GlobalObsForRailEnv from flatland.envs.rail_env import RailEnv @@ -85,7 +85,7 @@ def test_rail_environment_single_agent(): rail.grid = rail_map rail_env = RailEnv(width=3, height=3, - rail_generator=rail_from_GridTransitionMap_generator(rail), + rail_generator=rail_from_grid_transition_map(rail), number_of_agents=1, obs_builder_object=GlobalObsForRailEnv()) @@ -164,7 +164,7 @@ def test_dead_end(): rail.grid = rail_map rail_env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], - rail_generator=rail_from_GridTransitionMap_generator(rail), + rail_generator=rail_from_grid_transition_map(rail), number_of_agents=1, obs_builder_object=GlobalObsForRailEnv()) @@ -208,7 +208,7 @@ def test_dead_end(): rail.grid = rail_map rail_env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], - rail_generator=rail_from_GridTransitionMap_generator(rail), + rail_generator=rail_from_grid_transition_map(rail), number_of_agents=1, obs_builder_object=GlobalObsForRailEnv()) diff --git a/tests/tests_generators.py b/tests/tests_generators.py new file mode 100644 index 0000000000000000000000000000000000000000..449b83294173c9665f54c34a668579b222f0c281 --- /dev/null +++ b/tests/tests_generators.py @@ -0,0 +1,136 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import numpy as np + +from flatland.envs.generators import rail_from_grid_transition_map, rail_from_file, complex_rail_generator, \ + random_rail_generator, empty_rail_generator +from flatland.envs.observations import TreeObsForRailEnv +from flatland.envs.predictions import ShortestPathPredictorForRailEnv +from flatland.envs.rail_env import RailEnv +from tests.simple_rail import make_simple_rail + + +def test_empty_rail_generator(): + np.random.seed(0) + n_agents = 1 + x_dim = 5 + y_dim = 10 + + # Check that a random level at with correct parameters is generated + env = RailEnv(width=x_dim, + height=y_dim, + number_of_agents=n_agents, + rail_generator=empty_rail_generator() + ) + # Check the dimensions + assert env.rail.grid.shape == (y_dim, x_dim) + # Check that no grid was generated + assert np.count_nonzero(env.rail.grid) == 0 + # Check that no agents where placed + assert env.get_num_agents() == 0 + + +def test_random_rail_generator(): + np.random.seed(0) + n_agents = 1 + x_dim = 5 + y_dim = 10 + + # Check that a random level at with correct parameters is generated + env = RailEnv(width=x_dim, + height=y_dim, + number_of_agents=n_agents, + rail_generator=random_rail_generator() + ) + assert env.rail.grid.shape == (y_dim, x_dim) + assert env.get_num_agents() == n_agents + + +def test_complex_rail_generator(): + n_agents = 10 + n_start = 2 + x_dim = 10 + y_dim = 10 + min_dist = 4 + + # Check that agent number is changed to fit generated level + env = RailEnv(width=x_dim, + height=y_dim, + number_of_agents=n_agents, + rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist) + ) + assert env.get_num_agents() == 2 + assert env.rail.grid.shape == (y_dim, x_dim) + + min_dist = 2 * x_dim + + # Check that no agents are generated when level cannot be generated + env = RailEnv(width=x_dim, + height=y_dim, + number_of_agents=n_agents, + rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist) + ) + assert env.get_num_agents() == 0 + assert env.rail.grid.shape == (y_dim, x_dim) + + # Check that everything stays the same when correct parameters are given + min_dist = 2 + n_start = 5 + n_agents = 5 + + env = RailEnv(width=x_dim, + height=y_dim, + number_of_agents=n_agents, + rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist) + ) + assert env.get_num_agents() == n_agents + assert env.rail.grid.shape == (y_dim, x_dim) + + +def test_rail_from_grid_transition_map(): + rail, rail_map = make_simple_rail() + n_agents = 3 + env = RailEnv(width=rail_map.shape[1], + height=rail_map.shape[0], + rail_generator=rail_from_grid_transition_map(rail), + number_of_agents=n_agents + ) + nr_rail_elements = np.count_nonzero(env.rail.grid) + + # Check if the number of non-empty rail cells is ok + assert nr_rail_elements == 16 + + # Check that agents are placed on a rail + for a in env.agents: + assert env.rail.grid[a.position] != 0 + + assert env.get_num_agents() == n_agents + + +def tests_rail_from_file(): + file_name = "test_pkl.pkl" + rail, rail_map = make_simple_rail() + env = RailEnv(width=rail_map.shape[1], + height=rail_map.shape[0], + rail_generator=rail_from_grid_transition_map(rail), + number_of_agents=3, + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), + ) + env.save(file_name) + # initialize agents_static + rails_initial = env.rail.grid + agents_initial = env.agents + + env = RailEnv(width=1, + height=1, + rail_generator=rail_from_file(file_name), + number_of_agents=1, + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), + ) + rails_loaded = env.rail.grid + agents_loaded = env.agents + + assert np.all(np.array_equal(rails_initial, rails_loaded)) + assert agents_initial == agents_loaded +