Skip to content
Snippets Groups Projects
Commit 0c2a0b04 authored by Erik Nygren's avatar Erik Nygren :bullettrain_front:
Browse files

Update README.md

parent 28516aed
No related branches found
No related tags found
No related merge requests found
...@@ -90,27 +90,77 @@ Basic usage of the RailEnv environment used by the Flatland Challenge ...@@ -90,27 +90,77 @@ Basic usage of the RailEnv environment used by the Flatland Challenge
```python ```python
import numpy as np from flatland.envs.observations import GlobalObsForRailEnv
import time # First of all we import the Flatland rail environment
from flatland.envs.rail_generators import complex_rail_generator
from flatland.envs.schedule_generators import complex_schedule_generator
from flatland.envs.rail_env import RailEnv from flatland.envs.rail_env import RailEnv
from flatland.utils.rendertools import RenderTool from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.schedule_generators import sparse_schedule_generator
NUMBER_OF_AGENTS = 10 # We also include a renderer because we want to visualize what is going on in the environment
env = RailEnv( from flatland.utils.rendertools import RenderTool, AgentRenderVariant
width=20,
height=20, width = 100 # With of map
rail_generator=complex_rail_generator( height = 100 # Height of map
nr_start_goal=10, nr_trains = 50 # Number of trains that have an assigned task in the env
nr_extra=1, cities_in_map = 20 # Number of cities where agents can start or end
min_dist=8, seed = 14 # Random seed
max_dist=99999, grid_distribution_of_cities = False # Type of city distribution, if False cities are randomly placed
seed=1), max_rails_between_cities = 2 # Max number of tracks allowed between cities. This is number of entry point to a city
schedule_generator=complex_schedule_generator(), max_rail_in_cities = 6 # Max number of parallel tracks within a city, representing a realistic trainstation
number_of_agents=NUMBER_OF_AGENTS)
rail_generator = sparse_rail_generator(max_num_cities=cities_in_map,
env_renderer = RenderTool(env) seed=seed,
grid_mode=grid_distribution_of_cities,
max_rails_between_cities=max_rails_between_cities,
max_rails_in_city=max_rail_in_cities,
)
# The schedule generator can make very basic schedules with a start point, end point and a speed profile for each agent.
# The speed profiles can be adjusted directly as well as shown later on. We start by introducing a statistical
# distribution of speed profiles
# Different agent types (trains) with different speeds.
speed_ration_map = {1.: 0.25, # Fast passenger train
1. / 2.: 0.25, # Fast freight train
1. / 3.: 0.25, # Slow commuter train
1. / 4.: 0.25} # Slow freight train
# We can now initiate the schedule generator with the given speed profiles
schedule_generator = sparse_schedule_generator(speed_ration_map)
# We can furthermore pass stochastic data to the RailEnv constructor which will allow for stochastic malfunctions
# during an episode.
stochastic_data = {'prop_malfunction': 0.3, # Percentage of defective agents
'malfunction_rate': 30, # Rate of malfunction occurence
'min_duration': 3, # Minimal duration of malfunction
'max_duration': 20 # Max duration of malfunction
}
# Custom observation builder without predictor
observation_builder = GlobalObsForRailEnv()
# Custom observation builder with predictor, uncomment line below if you want to try this one
# observation_builder = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
# Construct the enviornment with the given observation, generataors, predictors, and stochastic data
env = RailEnv(width=width,
height=height,
rail_generator=rail_generator,
schedule_generator=schedule_generator,
number_of_agents=nr_trains,
stochastic_data=stochastic_data, # Malfunction data generator
obs_builder_object=observation_builder,
remove_agents_at_target=True # Removes agents at the end of their journey to make space for others
)
# Initiate the renderer
env_renderer = RenderTool(env, gl="PILSVG",
agent_render_variant=AgentRenderVariant.AGENT_SHOWS_OPTIONS_AND_BOX,
show_debug=False,
screen_height=1000, # Adjust these parameters to fit your resolution
screen_width=1000) # Adjust these parameters to fit your resolution
def my_controller(): def my_controller():
""" """
...@@ -124,7 +174,7 @@ def my_controller(): ...@@ -124,7 +174,7 @@ def my_controller():
for step in range(100): for step in range(100):
_action = my_controller() _action = my_controller()
obs, all_rewards, done, _ = env.step(_action) obs, all_rewards, done, info = env.step(_action)
print("Rewards: {}, [done={}]".format( all_rewards, done)) print("Rewards: {}, [done={}]".format( all_rewards, done))
env_renderer.render_env(show=True, frames=False, show_observations=False) env_renderer.render_env(show=True, frames=False, show_observations=False)
time.sleep(0.3) time.sleep(0.3)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment