Newer
Older
import numpy as np
from flatland.envs.generators import complex_rail_generator
from flatland.envs.generators import random_rail_generator
from flatland.envs.rail_env import RailEnv
from flatland.utils.rendertools import RenderTool
# ensure that every demo run behave constantly equal
random.seed(1)
np.random.seed(1)
__file_dirname__ = os.path.dirname(os.path.realpath(__file__))
class Scenario_Generator:
@staticmethod
def generate_random_scenario(number_of_agents=3):
# Example generate a rail given a manual specification,
# a map of tuples (cell_type, rotation)
transition_probability = [15, # empty cell - Case 0
5, # Case 1 - straight
5, # Case 2 - simple switch
1, # Case 3 - diamond crossing
1, # Case 4 - single slip
1, # Case 5 - double slip
1, # Case 6 - symmetrical
0, # Case 7 - dead end
1, # Case 1b (8) - simple turn right
1, # Case 1c (9) - simple turn left
1] # Case 2b (10) - simple switch mirrored
# Example generate a random rail
env = RailEnv(width=20,
height=20,
rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
number_of_agents=number_of_agents)
return env
@staticmethod
def generate_complex_scenario(number_of_agents=3):
env = RailEnv(width=15,
height=15,
rail_generator=complex_rail_generator(nr_start_goal=6, nr_extra=30, min_dist=10,
max_dist=99999, seed=0),
number_of_agents=number_of_agents)
return env
@staticmethod
def load_scenario(resource, package='env_data.railway', number_of_agents=3):
env = RailEnv(width=2 * (1 + number_of_agents),
height=1 + number_of_agents)
env.load_resource(package, resource)
env.reset(False, False)
class Demo:
def __init__(self, env):
self.env = env
self.create_renderer()
self.action_size = 4
self.max_frame_rate = 60
self.record_frames = record_frames
def create_renderer(self):
self.renderer = RenderTool(self.env, gl="PILSVG")
handle = self.env.get_agent_handles()
return handle
self.max_frame_rate = max_frame_rate
def run_demo(self, max_nbr_of_steps=30):
action_dict = dict()
# Reset environment
time.sleep(0.0001) # to satisfy lint...
for step in range(max_nbr_of_steps):
Egli Adrian (IT-SCI-API-PFI)
committed
for iAgent in range(self.env.get_num_agents()):
Egli Adrian (IT-SCI-API-PFI)
committed
action = 2
action = np.random.choice([0, 1, 2, 3], 1, p=[0.0, 0.5, 0.5, 0.0])[0]
Egli Adrian (IT-SCI-API-PFI)
committed
action_dict.update({iAgent: action})
# render
self.renderer.renderEnv(show=True, show_observations=False)
# environment step (apply the actions to all agents)
next_obs, all_rewards, done, _ = self.env.step(action_dict)
if done['__all__']:
break
if self.record_frames is not None:
self.renderer.gl.saveImage(self.record_frames.format(step))
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
@staticmethod
def run_generate_random_scenario():
demo_000 = Demo(Scenario_Generator.generate_random_scenario())
demo_000.run_demo()
@staticmethod
def run_generate_complex_scenario():
demo_001 = Demo(Scenario_Generator.generate_complex_scenario())
demo_001.run_demo()
@staticmethod
def run_example_network_000():
demo_000 = Demo(Scenario_Generator.load_scenario('example_network_000.pkl'))
demo_000.run_demo()
@staticmethod
def run_example_network_001():
demo_001 = Demo(Scenario_Generator.load_scenario('example_network_001.pkl'))
demo_001.run_demo()
@staticmethod
def run_example_network_002():
demo_002 = Demo(Scenario_Generator.load_scenario('example_network_002.pkl'))
demo_002.run_demo()
@staticmethod
def run_example_network_003():
demo_flatland_000 = Demo(Scenario_Generator.load_scenario('example_network_003.pkl'))
demo_flatland_000.renderer.resize()
demo_flatland_000.set_max_framerate(5)
demo_flatland_000.run_demo(30)
@staticmethod
def run_example_flatland_000():
demo_flatland_000 = Demo(Scenario_Generator.load_scenario('example_flatland_000.pkl'))
demo_flatland_000.renderer.resize()
u214892
committed
demo_flatland_000.set_max_framerate(5)
demo_flatland_000.run_demo(60)
@staticmethod
def run_example_flatland_001():
demo_flatland_000 = Demo(Scenario_Generator.load_scenario('example_flatland_001.pkl'))
demo_flatland_000.renderer.resize()
u214892
committed
demo_flatland_000.set_max_framerate(5)
demo_flatland_000.set_record_frames(os.path.join(__file_dirname__, '..', 'rendering', 'frame_{:04d}.bmp'))
demo_flatland_000.run_demo(60)
@staticmethod
def run_complex_scene():
demo_001 = Demo(Scenario_Generator.load_scenario('complex_scene.pkl'))
demo_001.set_record_frames(os.path.join(__file_dirname__, '..', 'rendering', 'frame_{:04d}.bmp'))
demo_001.run_demo(120)