From 96f38175bfc927352d50ce5b6fdb025be0f5c43a Mon Sep 17 00:00:00 2001 From: Erik Nygren <erik.nygren@sbb.ch> Date: Sat, 31 Aug 2019 14:35:38 -0400 Subject: [PATCH] updated example of Flatland 2.0 and minor changes to schedule generator --- examples/flatland_2_0_example.py | 24 +++++++++++++----------- flatland/envs/schedule_generators.py | 5 ++++- 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index 71a185c7..6f6c4b08 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -1,9 +1,9 @@ import numpy as np -from flatland.envs.rail_generators import sparse_rail_generator from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators import sparse_rail_generator from flatland.envs.schedule_generators import sparse_schedule_generator from flatland.utils.rendertools import RenderTool @@ -13,17 +13,23 @@ np.random.seed(1) # Training on simple small tasks is the best way to get familiar with the environment # Use a the malfunction generator to break agents from time to time -stochastic_data = {'prop_malfunction': 0.5, # Percentage of defective agents +stochastic_data = {'prop_malfunction': 0.0, # Percentage of defective agents 'malfunction_rate': 30, # Rate of malfunction occurence 'min_duration': 3, # Minimal duration of malfunction 'max_duration': 10 # Max duration of malfunction } TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()) -env = RailEnv(width=20, - height=20, - rail_generator=sparse_rail_generator(num_cities=2, # Number of cities in map (where train stations are) - num_intersections=1, # Number of intersections (no start / target) + +speed_ration_map = {1.: 0.1, # Fast passenger train + 0.5: 0.2, # Slow commuter train + 0.25: 0.2, # Fast freight train + 0.125: 0.5} # Slow freight train + +env = RailEnv(width=50, + height=50, + rail_generator=sparse_rail_generator(num_cities=10, # Number of cities in map (where train stations are) + num_intersections=5, # Number of intersections (no start / target) num_trainstations=15, # Number of possible start/targets on map min_node_dist=3, # Minimal distance of nodes node_radius=3, # Proximity of stations to city center @@ -32,7 +38,7 @@ env = RailEnv(width=20, realistic_mode=True, enhance_intersection=True ), - schedule_generator=sparse_schedule_generator(), + schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=5, stochastic_data=stochastic_data, # Malfunction data generator obs_builder_object=TreeObservation) @@ -83,10 +89,6 @@ action_dict = dict() print("Start episode...") # Reset environment and get initial observations for all agents obs = env.reset() -# Update/Set agent's speed -for idx in range(env.get_num_agents()): - speed = 1.0 / ((idx % 5) + 1.0) - env.agents[idx].speed_data["speed"] = speed # Reset the rendering sytem env_renderer.reset() diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py index 4843e004..a0f6825d 100644 --- a/flatland/envs/schedule_generators.py +++ b/flatland/envs/schedule_generators.py @@ -60,7 +60,10 @@ def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None): train_stations = hints['train_stations'] agent_start_targets_nodes = hints['agent_start_targets_nodes'] - num_agents = hints['num_agents'] + max_num_agents = hints['num_agents'] + if num_agents > max_num_agents: + num_agents = max_num_agents + warnings.warn("Too many agents! Changes number of agents.") # Place agents and targets within available train stations agents_position = [] agents_target = [] -- GitLab