"""Run one rendered Flatland episode driven by ``CellGraphDispatcher``.

This part of the script holds the imports, the map/train configuration and
the rail and schedule generators that are handed to ``RailEnv`` further down.
"""
import time

import numpy as np

# In Flatland you can use custom observation builders and predictors.
# Observation builders generate the observation needed by the controller.
# Predictors can be used to do short-term prediction, which can help in
# avoiding conflicts in the network.
from flatland.envs.observations import GlobalObsForRailEnv, ObservationBuilder
# The Flatland rail environment itself.
from flatland.envs.rail_env import RailAgentStatus, RailEnv, RailEnvActions
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.schedule_generators import sparse_schedule_generator
# We also include a renderer because we want to visualize what is going on
# in the environment.
from flatland.utils.rendertools import AgentRenderVariant, RenderTool
from flatland.envs.malfunction_generators import malfunction_from_params

from libs.cell_graph_dispatcher import CellGraphDispatcher

# Wall-clock reference, taken right after the imports, for the total-time
# report printed at the end of the script.
start_time = time.time()

# Alternative preset (large map, 200 trains):
# width = 150  # Width of map
# height = 150  # Height of map
# nr_trains = 200  # Number of trains that have an assigned task in the env
# cities_in_map = 35  # Number of cities where agents can start or end
# seed = 5  # Random seed

# Active preset.
width = 50  # Width of map
height = 50  # Height of map
nr_trains = 200  # Number of trains that have an assigned task in the env
cities_in_map = 35  # Number of cities where agents can start or end
seed = 5  # Random seed

# Alternative preset (large map, 100 trains, many cities):
# width = 150  # Width of map
# height = 150  # Height of map
# nr_trains = 100  # Number of trains that have an assigned task in the env
# cities_in_map = 100  # Number of cities where agents can start or end
# seed = 14  # Random seed

# Type of city distribution; if False, cities are randomly placed.
grid_distribution_of_cities = False
# Max number of tracks allowed between cities.
# This is the number of entry points to a city.
max_rails_between_cities = 2
# Max number of parallel tracks within a city, representing a realistic
# train station.
max_rail_in_cities = 6

rail_generator = sparse_rail_generator(
    max_num_cities=cities_in_map,
    seed=seed,
    grid_mode=grid_distribution_of_cities,
    max_rails_between_cities=max_rails_between_cities,
    max_rails_in_city=max_rail_in_cities,
)

# The schedule generator can make very basic schedules with a start point,
# end point and a speed profile for each agent. The speed profiles can be
# adjusted directly as well, as shown later on. We start by introducing a
# statistical distribution of speed profiles.

# Different agent types (trains) with different speeds; every speed class
# gets the same 25% share.
speed_ration_map = {
    1.0: 0.25,  # Fast passenger train
    1.0 / 2.0: 0.25,  # Fast freight train
    1.0 / 3.0: 0.25,  # Slow commuter train
    1.0 / 4.0: 0.25,  # Slow freight train
}

# We can now initiate the schedule generator with the given speed profiles.
schedule_generator = sparse_schedule_generator(speed_ration_map)

# We can furthermore pass stochastic data to the RailEnv constructor, which
# will allow for stochastic malfunctions during an episode.
# Parameters handed to ``malfunction_from_params`` when the environment is built.
stochastic_data = {
    'malfunction_rate': 500,  # Rate of malfunction occurrence of a single agent
    'prop_malfunction': 0.01,
    'min_duration': 20,  # Minimal duration of malfunction
    'max_duration': 80,  # Max duration of malfunction
}


# Custom observation builder without predictor
class DummyObservationBuilder(ObservationBuilder):
    """
    DummyObservationBuilder class which returns dummy observations.

    This is used in the evaluation service. Both ``get`` and ``get_many``
    ignore their arguments and return the constant ``True``.
    """

    def __init__(self):
        super().__init__()

    def reset(self):
        """Do nothing; this builder keeps no state of its own."""
        pass

    def get_many(self, handles=None) -> bool:
        """Return ``True`` regardless of ``handles``."""
        return True

    def get(self, handle: int = 0) -> bool:
        """Return ``True`` regardless of ``handle``."""
        return True


def _count_finished(agents) -> int:
    """Return how many of ``agents`` have status DONE or DONE_REMOVED."""
    finished_states = (RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED)
    return sum(agent.status in finished_states for agent in agents)


observation_builder = DummyObservationBuilder()

# Custom observation builder with predictor, uncomment line below if you want to try this one
# observation_builder = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())

# Construct the environment with the given observation, generators,
# predictors, and stochastic data.
env = RailEnv(
    width=width,
    height=height,
    rail_generator=rail_generator,
    schedule_generator=schedule_generator,
    number_of_agents=nr_trains,
    # Malfunction data generator
    malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
    obs_builder_object=observation_builder,
    # Removes agents at the end of their journey to make space for others
    remove_agents_at_target=True,
)
env.reset()

# Initiate the renderer.
# NOTE(review): height=1920 / width=1080 looks transposed for a landscape
# screen; values are kept as found - adjust them to fit your resolution.
env_renderer = RenderTool(
    env,
    gl="PILSVG",
    agent_render_variant=AgentRenderVariant.AGENT_SHOWS_OPTIONS_AND_BOX,
    show_debug=False,
    screen_height=1920,
    screen_width=1080,
)

dispatcher = CellGraphDispatcher(env)

score = 0

# Run episode
step = 0
while True:
    step += 1
    action_dict = dispatcher.step(step)

    # Environment step which returns the observations for all agents, their
    # corresponding reward and whether they are done. Only the rewards and
    # the done flags are used here.
    _, all_rewards, done, _ = env.step(action_dict)

    env_renderer.render_env(show=True, show_observations=False, show_predictions=False)
    # To dump every frame to disk instead (requires ``import os``):
    # os.makedirs('./misc/Fames2/', exist_ok=True)
    # env_renderer.gl.save_image('./misc/Fames2/flatland_frame_{:04d}.png'.format(step))

    score += np.sum(list(all_rewards.values()))

    finished = _count_finished(env.agents)
    print('Episode: Steps {}\t Score = {}\t Finished = {}'.format(step, score, finished))
    if done['__all__']:
        break

finished = _count_finished(env.agents)
total_agents = len(env.agents)
# Guard the percentage so an environment without agents cannot divide by zero.
finished_share = finished * 100 / total_agents if total_agents else 0.0
print(f'Trains finished {finished}/{total_agents} = {finished_share:.2f}%')
print(f'Total time: {time.time()-start_time}s')