Skip to content
Snippets Groups Projects
demo.py 5.81 KiB
Newer Older
import os
import random
u214892's avatar
u214892 committed
import time

import numpy as np

from flatland.envs.generators import complex_rail_generator
from flatland.envs.generators import random_rail_generator
from flatland.envs.rail_env import RailEnv
from flatland.utils.rendertools import RenderTool

# Seed the global random number generators (numpy first, then the stdlib one)
# so that repeated demo runs behave identically.
np.random.seed(1)
random.seed(1)

# Absolute path of the directory that contains this script; the recording
# demos build their output paths relative to it.
__file_dirname__ = os.path.split(os.path.realpath(__file__))[0]


class Scenario_Generator:
    """Factory helpers that build the ``RailEnv`` instances used by the demos."""

    @staticmethod
    def generate_random_scenario(number_of_agents=3):
        """Build a 20x20 ``RailEnv`` whose rail is drawn by ``random_rail_generator``.

        :param number_of_agents: number of agents in the environment.
        :return: the newly created ``RailEnv``.
        """
        # Relative proportions (weights) handed to random_rail_generator, one
        # entry per cell type, in the order labelled by the comments below.
        transition_probability = [15,  # empty cell - Case 0
                                  5,  # Case 1 - straight
                                  5,  # Case 2 - simple switch
                                  1,  # Case 3 - diamond crossing
                                  1,  # Case 4 - single slip
                                  1,  # Case 5 - double slip
                                  1,  # Case 6 - symmetrical
                                  0,  # Case 7 - dead end
                                  1,  # Case 1b (8)  - simple turn right
                                  1,  # Case 1c (9)  - simple turn left
                                  1]  # Case 2b (10) - simple switch mirrored

        env = RailEnv(width=20,
                      height=20,
                      rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
                      number_of_agents=number_of_agents)

        return env

    @staticmethod
    def generate_complex_scenario(number_of_agents=3):
        """Build a 15x15 ``RailEnv`` whose rail is drawn by ``complex_rail_generator``.

        The generator is called with a fixed ``seed=0``.

        :param number_of_agents: number of agents in the environment.
        :return: the newly created ``RailEnv``.
        """
        env = RailEnv(width=15,
                      height=15,
                      rail_generator=complex_rail_generator(nr_start_goal=6, nr_extra=30, min_dist=10,
                                                            max_dist=99999, seed=0),
                      number_of_agents=number_of_agents)

        return env

    @staticmethod
    def load_scenario(resource, package='env_data.railway', number_of_agents=3):
        """Build a ``RailEnv``, load a stored environment into it and reset it.

        The width/height passed to the constructor are derived only from
        ``number_of_agents``; presumably ``load_resource`` replaces the grid
        with the stored one (TODO confirm).

        :param resource: name of the stored environment file, e.g. ``'example_network_000.pkl'``.
        :param package: package from which the resource is loaded.
        :param number_of_agents: used to size the initially constructed grid.
        :return: the loaded and reset ``RailEnv``.
        """
        env = RailEnv(width=2 * (1 + number_of_agents),
                      height=1 + number_of_agents)
        env.load_resource(package, resource)
        env.reset(False, False)
        # Callers wrap the result in Demo(...); without this return they got None.
        return env

class Demo:
    """Animate a ``RailEnv`` with a ``RenderTool`` while its agents act randomly."""

    def __init__(self, env):
        """Store *env*, attach a renderer to it and set default options.

        :param env: environment to animate (the scenarios in this file pass a ``RailEnv``).
        """
        self.env = env
        self.create_renderer()
        # Size of the action space sampled in run_demo (actions 0..3).
        self.action_size = 4
        # Upper bound on rendered frames per second; run_demo sleeps to respect it.
        self.max_frame_rate = 60
        # Optional str.format pattern (one slot for the step index) for saving frames.
        self.record_frames = None

    def set_record_frames(self, record_frames):
        """Save each rendered frame to ``record_frames.format(step)``; ``None`` disables saving."""
        self.record_frames = record_frames

    def create_renderer(self):
        """Create a ``RenderTool`` with the "PILSVG" backend and store it in ``self.renderer``.

        :return: the environment's agent handles (discarded by ``__init__``).
        """
        self.renderer = RenderTool(self.env, gl="PILSVG")
        handle = self.env.get_agent_handles()
        return handle

    def set_max_framerate(self, max_frame_rate):
        """Limit ``run_demo`` to at most *max_frame_rate* frames per second (``None``/0 = unlimited)."""
        self.max_frame_rate = max_frame_rate

    @staticmethod
    def _frame_period(max_frame_rate):
        """Return the minimum duration of one frame in seconds; 0.0 disables throttling."""
        if max_frame_rate is None or max_frame_rate <= 0:
            return 0.0
        return 1.0 / max_frame_rate

    def run_demo(self, max_nbr_of_steps=30):
        """Reset the environment and animate up to *max_nbr_of_steps* steps.

        Each step, every agent is given a random action, the current state is
        rendered (and saved when ``record_frames`` is set), then the
        environment is stepped. The loop stops early once ``done['__all__']``
        is true, and the render window is always closed afterwards.

        :param max_nbr_of_steps: maximum number of environment steps to run.
        """
        frame_period = self._frame_period(self.max_frame_rate)
        action_dict = dict()

        # Reset environment
        self.env.reset(False, False)

        try:
            for step in range(max_nbr_of_steps):
                frame_start = time.monotonic()

                # Every agent gets action 1 or 2 with equal probability (0 and 3
                # have zero weight). NOTE(review): an older comment here said
                # "always walk straight forward", which this does not do.
                # Assumes env.step() expects actions keyed by agent handle.
                for handle in self.env.get_agent_handles():
                    action = np.random.choice([0, 1, 2, 3], 1, p=[0.0, 0.5, 0.5, 0.0])[0]
                    action_dict[handle] = action

                # render, and save right away so the frame is kept even on the
                # step that ends the episode
                self.renderer.renderEnv(show=True, show_observations=False)
                if self.record_frames is not None:
                    self.renderer.gl.saveImage(self.record_frames.format(step))

                # environment step (apply the actions to all agents)
                _, _, done, _ = self.env.step(action_dict)
                if done['__all__']:
                    break

                # throttle to at most max_frame_rate frames per second
                remaining = frame_period - (time.monotonic() - frame_start)
                if remaining > 0:
                    time.sleep(remaining)
        finally:
            self.renderer.close_window()

    @staticmethod
    def run_generate_random_scenario():
        """Animate a freshly generated random scenario."""
        demo = Demo(Scenario_Generator.generate_random_scenario())
        demo.run_demo()

    @staticmethod
    def run_generate_complex_scenario():
        """Animate a freshly generated complex scenario."""
        demo = Demo(Scenario_Generator.generate_complex_scenario())
        demo.run_demo()

    @staticmethod
    def run_example_network_000():
        """Animate the stored 'example_network_000.pkl' scenario."""
        demo = Demo(Scenario_Generator.load_scenario('example_network_000.pkl'))
        demo.run_demo()

    @staticmethod
    def run_example_network_001():
        """Animate the stored 'example_network_001.pkl' scenario."""
        demo = Demo(Scenario_Generator.load_scenario('example_network_001.pkl'))
        demo.run_demo()

    @staticmethod
    def run_example_network_002():
        """Animate the stored 'example_network_002.pkl' scenario."""
        demo = Demo(Scenario_Generator.load_scenario('example_network_002.pkl'))
        demo.run_demo()

    @staticmethod
    def run_example_network_003():
        """Animate 'example_network_003.pkl' for 30 steps at no more than 5 frames per second."""
        demo = Demo(Scenario_Generator.load_scenario('example_network_003.pkl'))
        demo.renderer.resize()
        demo.set_max_framerate(5)
        demo.run_demo(30)

    @staticmethod
    def run_example_flatland_000():
        """Animate 'example_flatland_000.pkl' for 60 steps."""
        demo = Demo(Scenario_Generator.load_scenario('example_flatland_000.pkl'))
        demo.renderer.resize()
        demo.run_demo(60)

    @staticmethod
    def run_example_flatland_001():
        """Animate 'example_flatland_001.pkl' for 60 steps, saving frames under ../rendering."""
        demo = Demo(Scenario_Generator.load_scenario('example_flatland_001.pkl'))
        demo.renderer.resize()
        demo.set_record_frames(os.path.join(__file_dirname__, '..', 'rendering', 'frame_{:04d}.bmp'))
        demo.run_demo(60)

    @staticmethod
    def run_complex_scene():
        """Animate 'complex_scene.pkl' for 120 steps, saving frames under ../rendering."""
        demo = Demo(Scenario_Generator.load_scenario('complex_scene.pkl'))
        demo.set_record_frames(os.path.join(__file_dirname__, '..', 'rendering', 'frame_{:04d}.bmp'))
        demo.run_demo(120)