Skip to content
Snippets Groups Projects
Commit d162772f authored by nilabha's avatar nilabha
Browse files

update pettingzoo changes for flatland3

parent 2438febf
No related branches found
No related tags found
No related merge requests found
......@@ -6,7 +6,7 @@ from typing import NamedTuple
from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters, ParamMalfunctionGen, no_malfunction_generator
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.schedule_generators import sparse_schedule_generator
from flatland.envs.line_generators import sparse_line_generator
from flatland.envs.agent_utils import RailAgentStatus
from flatland.core.grid.grid4_utils import get_new_position
......@@ -59,8 +59,8 @@ def get_shortest_path_action(env,handle):
def small_v0(random_seed, observation_builder, max_width = 35, max_height = 35):
random.seed(random_seed)
width = 25
height = 25
width = 30
height = 30
nr_trains = 5
max_num_cities = 4
grid_mode = False
......@@ -73,21 +73,21 @@ def small_v0(random_seed, observation_builder, max_width = 35, max_height = 35):
rail_generator = sparse_rail_generator(max_num_cities=max_num_cities, seed=random_seed, grid_mode=False,
max_rails_between_cities=max_rails_between_cities,
max_rails_in_city=max_rails_in_city)
max_rail_pairs_in_city=max_rails_in_city)
stochastic_data = MalfunctionParameters(malfunction_rate=malfunction_rate, # Rate of malfunction occurence
min_duration=malfunction_min_duration, # Minimal duration of malfunction
max_duration=malfunction_max_duration # Max duration of malfunction
)
speed_ratio_map = None
schedule_generator = sparse_schedule_generator(speed_ratio_map)
line_generator = sparse_line_generator(speed_ratio_map)
malfunction_generator = no_malfunction_generator()
while width <= max_width and height <= max_height:
try:
env = RailEnv(width=width, height=height, rail_generator=rail_generator,
schedule_generator=schedule_generator, number_of_agents=nr_trains,
line_generator=line_generator, number_of_agents=nr_trains,
# malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
malfunction_generator_and_process_data=malfunction_generator,
obs_builder_object=observation_builder, remove_agents_at_target=False)
......@@ -122,19 +122,19 @@ def random_sparse_env_small(random_seed, observation_builder, max_width = 45, ma
rail_generator = sparse_rail_generator(max_num_cities=nr_cities, seed=random_seed, grid_mode=False,
max_rails_between_cities=max_rails_between_cities,
max_rails_in_city=max_rails_in_cities)
max_rail_pairs_in_city=max_rails_in_cities)
stochastic_data = MalfunctionParameters(malfunction_rate=malfunction_rate, # Rate of malfunction occurence
min_duration=malfunction_min_duration, # Minimal duration of malfunction
max_duration=malfunction_max_duration # Max duration of malfunction
)
schedule_generator = sparse_schedule_generator({1.: 0.25, 1. / 2.: 0.25, 1. / 3.: 0.25, 1. / 4.: 0.25})
line_generator = sparse_line_generator({1.: 0.25, 1. / 2.: 0.25, 1. / 3.: 0.25, 1. / 4.: 0.25})
while width <= max_width and height <= max_height:
try:
env = RailEnv(width=width, height=height, rail_generator=rail_generator,
schedule_generator=schedule_generator, number_of_agents=nr_trains,
line_generator=line_generator, number_of_agents=nr_trains,
# malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
malfunction_generator=ParamMalfunctionGen(stochastic_data),
obs_builder_object=observation_builder, remove_agents_at_target=False)
......@@ -168,7 +168,7 @@ def sparse_env_small(random_seed, observation_builder):
seed=seed,
grid_mode=grid_distribution_of_cities,
max_rails_between_cities=max_rails_between_cities,
max_rails_in_city=max_rail_in_cities,
max_rail_pairs_in_city=max_rail_in_cities,
)
# Different agent types (trains) with different speeds.
......@@ -179,7 +179,7 @@ def sparse_env_small(random_seed, observation_builder):
# We can now initiate the schedule generator with the given speed profiles
schedule_generator = sparse_schedule_generator(speed_ration_map)
line_generator = sparse_rail_generator(speed_ration_map)
# We can furthermore pass stochastic data to the RailEnv constructor which will allow for stochastic malfunctions
# during an episode.
......@@ -192,7 +192,7 @@ def sparse_env_small(random_seed, observation_builder):
rail_env = RailEnv(width=width,
height=height,
rail_generator=rail_generator,
schedule_generator=schedule_generator,
line_generator=line_generator,
number_of_agents=nr_trains,
obs_builder_object=observation_builder,
# malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
......
......@@ -24,3 +24,8 @@ ipycanvas
graphviz
imageio
id-mava[flatland]
id-mava
id-mava[tf]
supersuit
stable-baselines3
ray==1.5.2
......@@ -23,7 +23,7 @@ def test_petting_zoo_interface_env():
# Custom observation builder with predictor
observation_builder = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv(30))
seed = 11
save = False
save = True
np.random.seed(seed)
experiment_name= "flatland_pettingzoo"
total_episodes = 1
......@@ -108,8 +108,8 @@ def test_petting_zoo_interface_env():
frame_list = []
env.close()
env.reset(random_seed=seed+ep_no)
assert all_actions_pettingzoo_env.sort() == all_actions_env.sort(), "actions do not match for shortest path"
min_len = min(len(all_actions_pettingzoo_env), len(all_actions_env))
assert all_actions_pettingzoo_env[:min_len] == all_actions_env[:min_len], "actions do not match for shortest path"
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment