diff --git a/examples/env_generators.py b/examples/env_generators.py index d8dd2e3ece319e95bd3b397bd694724858ee9543..a65733514bff7b601825b695715faabe66ac84ed 100644 --- a/examples/env_generators.py +++ b/examples/env_generators.py @@ -6,7 +6,7 @@ from typing import NamedTuple from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters, ParamMalfunctionGen, no_malfunction_generator from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import sparse_rail_generator -from flatland.envs.schedule_generators import sparse_schedule_generator +from flatland.envs.line_generators import sparse_line_generator from flatland.envs.agent_utils import RailAgentStatus from flatland.core.grid.grid4_utils import get_new_position @@ -59,8 +59,8 @@ def get_shortest_path_action(env,handle): def small_v0(random_seed, observation_builder, max_width = 35, max_height = 35): random.seed(random_seed) - width = 25 - height = 25 + width = 30 + height = 30 nr_trains = 5 max_num_cities = 4 grid_mode = False @@ -73,21 +73,21 @@ def small_v0(random_seed, observation_builder, max_width = 35, max_height = 35): rail_generator = sparse_rail_generator(max_num_cities=max_num_cities, seed=random_seed, grid_mode=False, max_rails_between_cities=max_rails_between_cities, - max_rails_in_city=max_rails_in_city) + max_rail_pairs_in_city=max_rails_in_city) stochastic_data = MalfunctionParameters(malfunction_rate=malfunction_rate, # Rate of malfunction occurence min_duration=malfunction_min_duration, # Minimal duration of malfunction max_duration=malfunction_max_duration # Max duration of malfunction ) speed_ratio_map = None - schedule_generator = sparse_schedule_generator(speed_ratio_map) + line_generator = sparse_line_generator(speed_ratio_map) malfunction_generator = no_malfunction_generator() while width <= max_width and height <= max_height: try: env = RailEnv(width=width, height=height, rail_generator=rail_generator, - schedule_generator=schedule_generator, number_of_agents=nr_trains, + line_generator=line_generator, number_of_agents=nr_trains, # malfunction_generator_and_process_data=malfunction_from_params(stochastic_data), malfunction_generator_and_process_data=malfunction_generator, obs_builder_object=observation_builder, remove_agents_at_target=False) @@ -122,19 +122,19 @@ def random_sparse_env_small(random_seed, observation_builder, max_width = 45, ma rail_generator = sparse_rail_generator(max_num_cities=nr_cities, seed=random_seed, grid_mode=False, max_rails_between_cities=max_rails_between_cities, - max_rails_in_city=max_rails_in_cities) + max_rail_pairs_in_city=max_rails_in_cities) stochastic_data = MalfunctionParameters(malfunction_rate=malfunction_rate, # Rate of malfunction occurence min_duration=malfunction_min_duration, # Minimal duration of malfunction max_duration=malfunction_max_duration # Max duration of malfunction ) - schedule_generator = sparse_schedule_generator({1.: 0.25, 1. / 2.: 0.25, 1. / 3.: 0.25, 1. / 4.: 0.25}) + line_generator = sparse_line_generator({1.: 0.25, 1. / 2.: 0.25, 1. / 3.: 0.25, 1. / 4.: 0.25}) while width <= max_width and height <= max_height: try: env = RailEnv(width=width, height=height, rail_generator=rail_generator, - schedule_generator=schedule_generator, number_of_agents=nr_trains, + line_generator=line_generator, number_of_agents=nr_trains, # malfunction_generator_and_process_data=malfunction_from_params(stochastic_data), malfunction_generator=ParamMalfunctionGen(stochastic_data), obs_builder_object=observation_builder, remove_agents_at_target=False) @@ -168,7 +168,7 @@ def sparse_env_small(random_seed, observation_builder): seed=seed, grid_mode=grid_distribution_of_cities, max_rails_between_cities=max_rails_between_cities, - max_rails_in_city=max_rail_in_cities, + max_rail_pairs_in_city=max_rail_in_cities, ) # Different agent types (trains) with different speeds. @@ -179,7 +179,7 @@ def sparse_env_small(random_seed, observation_builder): # We can now initiate the schedule generator with the given speed profiles - schedule_generator = sparse_schedule_generator(speed_ration_map) + line_generator = sparse_rail_generator(speed_ration_map) # We can furthermore pass stochastic data to the RailEnv constructor which will allow for stochastic malfunctions # during an episode. @@ -192,7 +192,7 @@ def sparse_env_small(random_seed, observation_builder): rail_env = RailEnv(width=width, height=height, rail_generator=rail_generator, - schedule_generator=schedule_generator, + line_generator=line_generator, number_of_agents=nr_trains, obs_builder_object=observation_builder, # malfunction_generator_and_process_data=malfunction_from_params(stochastic_data), diff --git a/requirements_dev.txt b/requirements_dev.txt index c1f6840e97295dd06b330b0c9ba8163ba47b2d45..9197c379ed88011e35cb7486882d7db52c4873b8 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -24,3 +24,8 @@ ipycanvas graphviz imageio id-mava[flatland] +id-mava +id-mava[tf] +supersuit +stable-baselines3 +ray==1.5.2 diff --git a/tests/test_pettingzoo_interface.py b/tests/test_pettingzoo_interface.py index 63b61ce7e93a6564b172aaa0980adf10fe6f8198..cf0f8ed1e1c5259650c8ac47d6a35e071c739ec8 100644 --- a/tests/test_pettingzoo_interface.py +++ b/tests/test_pettingzoo_interface.py @@ -23,7 +23,7 @@ def test_petting_zoo_interface_env(): # Custom observation builder with predictor observation_builder = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv(30)) seed = 11 - save = False + save = True np.random.seed(seed) experiment_name= "flatland_pettingzoo" total_episodes = 1 @@ -108,8 +108,8 @@ def test_petting_zoo_interface_env(): frame_list = [] env.close() env.reset(random_seed=seed+ep_no) - - assert all_actions_pettingzoo_env.sort() == all_actions_env.sort(), "actions do not match for shortest path" + min_len = min(len(all_actions_pettingzoo_env), len(all_actions_env)) + assert all_actions_pettingzoo_env[:min_len] == all_actions_env[:min_len], "actions do not match for shortest path"