Commit d162772f authored by nilabha's avatar nilabha
Browse files

update pettingzoo changes for flatland3

parent 2438febf
Pipeline #8340 failed with stages
in 4 minutes and 5 seconds
......@@ -6,7 +6,7 @@ from typing import NamedTuple
from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters, ParamMalfunctionGen, no_malfunction_generator
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.schedule_generators import sparse_schedule_generator
from flatland.envs.line_generators import sparse_line_generator
from flatland.envs.agent_utils import RailAgentStatus
from flatland.core.grid.grid4_utils import get_new_position
......@@ -59,8 +59,8 @@ def get_shortest_path_action(env,handle):
def small_v0(random_seed, observation_builder, max_width = 35, max_height = 35):
random.seed(random_seed)
width = 25
height = 25
width = 30
height = 30
nr_trains = 5
max_num_cities = 4
grid_mode = False
......@@ -73,21 +73,21 @@ def small_v0(random_seed, observation_builder, max_width = 35, max_height = 35):
rail_generator = sparse_rail_generator(max_num_cities=max_num_cities, seed=random_seed, grid_mode=False,
max_rails_between_cities=max_rails_between_cities,
max_rails_in_city=max_rails_in_city)
max_rail_pairs_in_city=max_rails_in_city)
stochastic_data = MalfunctionParameters(malfunction_rate=malfunction_rate, # Rate of malfunction occurence
min_duration=malfunction_min_duration, # Minimal duration of malfunction
max_duration=malfunction_max_duration # Max duration of malfunction
)
speed_ratio_map = None
schedule_generator = sparse_schedule_generator(speed_ratio_map)
line_generator = sparse_line_generator(speed_ratio_map)
malfunction_generator = no_malfunction_generator()
while width <= max_width and height <= max_height:
try:
env = RailEnv(width=width, height=height, rail_generator=rail_generator,
schedule_generator=schedule_generator, number_of_agents=nr_trains,
line_generator=line_generator, number_of_agents=nr_trains,
# malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
malfunction_generator_and_process_data=malfunction_generator,
obs_builder_object=observation_builder, remove_agents_at_target=False)
......@@ -122,19 +122,19 @@ def random_sparse_env_small(random_seed, observation_builder, max_width = 45, ma
rail_generator = sparse_rail_generator(max_num_cities=nr_cities, seed=random_seed, grid_mode=False,
max_rails_between_cities=max_rails_between_cities,
max_rails_in_city=max_rails_in_cities)
max_rail_pairs_in_city=max_rails_in_cities)
stochastic_data = MalfunctionParameters(malfunction_rate=malfunction_rate, # Rate of malfunction occurence
min_duration=malfunction_min_duration, # Minimal duration of malfunction
max_duration=malfunction_max_duration # Max duration of malfunction
)
schedule_generator = sparse_schedule_generator({1.: 0.25, 1. / 2.: 0.25, 1. / 3.: 0.25, 1. / 4.: 0.25})
line_generator = sparse_line_generator({1.: 0.25, 1. / 2.: 0.25, 1. / 3.: 0.25, 1. / 4.: 0.25})
while width <= max_width and height <= max_height:
try:
env = RailEnv(width=width, height=height, rail_generator=rail_generator,
schedule_generator=schedule_generator, number_of_agents=nr_trains,
line_generator=line_generator, number_of_agents=nr_trains,
# malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
malfunction_generator=ParamMalfunctionGen(stochastic_data),
obs_builder_object=observation_builder, remove_agents_at_target=False)
......@@ -168,7 +168,7 @@ def sparse_env_small(random_seed, observation_builder):
seed=seed,
grid_mode=grid_distribution_of_cities,
max_rails_between_cities=max_rails_between_cities,
max_rails_in_city=max_rail_in_cities,
max_rail_pairs_in_city=max_rail_in_cities,
)
# Different agent types (trains) with different speeds.
......@@ -179,7 +179,7 @@ def sparse_env_small(random_seed, observation_builder):
# We can now initiate the schedule generator with the given speed profiles
schedule_generator = sparse_schedule_generator(speed_ration_map)
line_generator = sparse_rail_generator(speed_ration_map)
# We can furthermore pass stochastic data to the RailEnv constructor which will allow for stochastic malfunctions
# during an episode.
......@@ -192,7 +192,7 @@ def sparse_env_small(random_seed, observation_builder):
rail_env = RailEnv(width=width,
height=height,
rail_generator=rail_generator,
schedule_generator=schedule_generator,
line_generator=line_generator,
number_of_agents=nr_trains,
obs_builder_object=observation_builder,
# malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
......
......@@ -24,3 +24,8 @@ ipycanvas
graphviz
imageio
id-mava[flatland]
id-mava
id-mava[tf]
supersuit
stable-baselines3
ray==1.5.2
......@@ -23,7 +23,7 @@ def test_petting_zoo_interface_env():
# Custom observation builder with predictor
observation_builder = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv(30))
seed = 11
save = False
save = True
np.random.seed(seed)
experiment_name= "flatland_pettingzoo"
total_episodes = 1
......@@ -108,8 +108,8 @@ def test_petting_zoo_interface_env():
frame_list = []
env.close()
env.reset(random_seed=seed+ep_no)
assert all_actions_pettingzoo_env.sort() == all_actions_env.sort(), "actions do not match for shortest path"
min_len = min(len(all_actions_pettingzoo_env), len(all_actions_env))
assert all_actions_pettingzoo_env[:min_len] == all_actions_env[:min_len], "actions do not match for shortest path"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment