Skip to content
Snippets Groups Projects
Commit 96f38175 authored by Erik Nygren's avatar Erik Nygren
Browse files

updated example of Flatland 2.0 and minor changes to schedule generator

parent 8e8bd7ff
No related branches found
No related tags found
No related merge requests found
import numpy as np import numpy as np
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.schedule_generators import sparse_schedule_generator from flatland.envs.schedule_generators import sparse_schedule_generator
from flatland.utils.rendertools import RenderTool from flatland.utils.rendertools import RenderTool
...@@ -13,17 +13,23 @@ np.random.seed(1) ...@@ -13,17 +13,23 @@ np.random.seed(1)
# Training on simple small tasks is the best way to get familiar with the environment # Training on simple small tasks is the best way to get familiar with the environment
# Use a the malfunction generator to break agents from time to time # Use a the malfunction generator to break agents from time to time
stochastic_data = {'prop_malfunction': 0.5, # Percentage of defective agents stochastic_data = {'prop_malfunction': 0.0, # Percentage of defective agents
'malfunction_rate': 30, # Rate of malfunction occurence 'malfunction_rate': 30, # Rate of malfunction occurence
'min_duration': 3, # Minimal duration of malfunction 'min_duration': 3, # Minimal duration of malfunction
'max_duration': 10 # Max duration of malfunction 'max_duration': 10 # Max duration of malfunction
} }
TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()) TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
env = RailEnv(width=20,
height=20, speed_ration_map = {1.: 0.1, # Fast passenger train
rail_generator=sparse_rail_generator(num_cities=2, # Number of cities in map (where train stations are) 0.5: 0.2, # Slow commuter train
num_intersections=1, # Number of intersections (no start / target) 0.25: 0.2, # Fast freight train
0.125: 0.5} # Slow freight train
env = RailEnv(width=50,
height=50,
rail_generator=sparse_rail_generator(num_cities=10, # Number of cities in map (where train stations are)
num_intersections=5, # Number of intersections (no start / target)
num_trainstations=15, # Number of possible start/targets on map num_trainstations=15, # Number of possible start/targets on map
min_node_dist=3, # Minimal distance of nodes min_node_dist=3, # Minimal distance of nodes
node_radius=3, # Proximity of stations to city center node_radius=3, # Proximity of stations to city center
...@@ -32,7 +38,7 @@ env = RailEnv(width=20, ...@@ -32,7 +38,7 @@ env = RailEnv(width=20,
realistic_mode=True, realistic_mode=True,
enhance_intersection=True enhance_intersection=True
), ),
schedule_generator=sparse_schedule_generator(), schedule_generator=sparse_schedule_generator(speed_ration_map),
number_of_agents=5, number_of_agents=5,
stochastic_data=stochastic_data, # Malfunction data generator stochastic_data=stochastic_data, # Malfunction data generator
obs_builder_object=TreeObservation) obs_builder_object=TreeObservation)
...@@ -83,10 +89,6 @@ action_dict = dict() ...@@ -83,10 +89,6 @@ action_dict = dict()
print("Start episode...") print("Start episode...")
# Reset environment and get initial observations for all agents # Reset environment and get initial observations for all agents
obs = env.reset() obs = env.reset()
# Update/Set agent's speed
for idx in range(env.get_num_agents()):
speed = 1.0 / ((idx % 5) + 1.0)
env.agents[idx].speed_data["speed"] = speed
# Reset the rendering sytem # Reset the rendering sytem
env_renderer.reset() env_renderer.reset()
......
...@@ -60,7 +60,10 @@ def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> ...@@ -60,7 +60,10 @@ def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None) ->
def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None): def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None):
train_stations = hints['train_stations'] train_stations = hints['train_stations']
agent_start_targets_nodes = hints['agent_start_targets_nodes'] agent_start_targets_nodes = hints['agent_start_targets_nodes']
num_agents = hints['num_agents'] max_num_agents = hints['num_agents']
if num_agents > max_num_agents:
num_agents = max_num_agents
warnings.warn("Too many agents! Changes number of agents.")
# Place agents and targets within available train stations # Place agents and targets within available train stations
agents_position = [] agents_position = []
agents_target = [] agents_target = []
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment