Skip to content
Snippets Groups Projects
Commit 96f38175 authored by Erik Nygren's avatar Erik Nygren
Browse files

updated example of Flatland 2.0 and minor changes to schedule generator

parent 8e8bd7ff
No related branches found
No related tags found
No related merge requests found
import numpy as np
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.schedule_generators import sparse_schedule_generator
from flatland.utils.rendertools import RenderTool
......@@ -13,17 +13,23 @@ np.random.seed(1)
# Training on simple small tasks is the best way to get familiar with the environment
# Use a the malfunction generator to break agents from time to time
stochastic_data = {'prop_malfunction': 0.5, # Percentage of defective agents
stochastic_data = {'prop_malfunction': 0.0, # Percentage of defective agents
'malfunction_rate': 30, # Rate of malfunction occurence
'min_duration': 3, # Minimal duration of malfunction
'max_duration': 10 # Max duration of malfunction
}
TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
env = RailEnv(width=20,
height=20,
rail_generator=sparse_rail_generator(num_cities=2, # Number of cities in map (where train stations are)
num_intersections=1, # Number of intersections (no start / target)
speed_ration_map = {1.: 0.1, # Fast passenger train
0.5: 0.2, # Slow commuter train
0.25: 0.2, # Fast freight train
0.125: 0.5} # Slow freight train
env = RailEnv(width=50,
height=50,
rail_generator=sparse_rail_generator(num_cities=10, # Number of cities in map (where train stations are)
num_intersections=5, # Number of intersections (no start / target)
num_trainstations=15, # Number of possible start/targets on map
min_node_dist=3, # Minimal distance of nodes
node_radius=3, # Proximity of stations to city center
......@@ -32,7 +38,7 @@ env = RailEnv(width=20,
realistic_mode=True,
enhance_intersection=True
),
schedule_generator=sparse_schedule_generator(),
schedule_generator=sparse_schedule_generator(speed_ration_map),
number_of_agents=5,
stochastic_data=stochastic_data, # Malfunction data generator
obs_builder_object=TreeObservation)
......@@ -83,10 +89,6 @@ action_dict = dict()
print("Start episode...")
# Reset environment and get initial observations for all agents
obs = env.reset()
# Update/Set agent's speed
for idx in range(env.get_num_agents()):
speed = 1.0 / ((idx % 5) + 1.0)
env.agents[idx].speed_data["speed"] = speed
# Reset the rendering sytem
env_renderer.reset()
......
......@@ -60,7 +60,10 @@ def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None) ->
def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None):
train_stations = hints['train_stations']
agent_start_targets_nodes = hints['agent_start_targets_nodes']
num_agents = hints['num_agents']
max_num_agents = hints['num_agents']
if num_agents > max_num_agents:
num_agents = max_num_agents
warnings.warn("Too many agents! Changes number of agents.")
# Place agents and targets within available train stations
agents_position = []
agents_target = []
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment