import os

import numpy as np

from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters, ParamMalfunctionGen
from flatland.envs.observations import GlobalObsForRailEnv
# The Flatland rail environment and its action enum
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_env import RailEnvActions
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.line_generators import sparse_line_generator


def get_small_two_agent_env():
    """Build a small two-city, two-train rail environment and return it after reset.

    Every setting is fixed inside the function: a 30 x 15 grid, two cities,
    two trains, a global observation builder, and the seed 42 for both the
    rail generator and the environment.

    Returns:
        The constructed ``RailEnv``; ``reset()`` has already been called on it.
    """
    seed = 42  # Same seed is handed to the rail generator and to RailEnv
    tracks_per_city = 6  # Max number of parallel tracks within a city

    rail_generator = sparse_rail_generator(
        max_num_cities=2,  # Number of cities where agents can start or end
        seed=seed,
        grid_mode=False,  # False: cities are randomly placed
        max_rails_between_cities=2,  # Number of entry points to a city
        max_rail_pairs_in_city=tracks_per_city // 2,  # Tracks are given as pairs
    )

    # Maps each speed to the ratio handed to the line generator; the four
    # speed classes carry equal weight and the ratios sum to 1.
    speed_ratio_map = {
        1.0: 0.25,  # Fast passenger train
        1.0 / 2.0: 0.25,  # Fast freight train
        1.0 / 3.0: 0.25,  # Slow commuter train
        1.0 / 4.0: 0.25,  # Slow freight train
    }
    line_generator = sparse_line_generator(speed_ratio_map)

    malfunction_params = MalfunctionParameters(
        malfunction_rate=1 / 10000,  # Rate of malfunction occurrence
        min_duration=15,  # Minimal duration of a malfunction
        max_duration=50,  # Maximal duration of a malfunction
    )

    env = RailEnv(
        width=30,  # Width of the map
        height=15,  # Height of the map
        rail_generator=rail_generator,
        line_generator=line_generator,
        number_of_agents=2,  # Trains that have an assigned task in the env
        obs_builder_object=GlobalObsForRailEnv(),
        malfunction_generator=ParamMalfunctionGen(malfunction_params),
        remove_agents_at_target=True,
        random_seed=seed,
    )
    env.reset()
    return env