rail_env_utils.py 1.72 KB
Newer Older
u214892's avatar
u214892 committed
1
from flatland.core.env_observation_builder import ObservationBuilder
u214892's avatar
u214892 committed
2
3
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
4
from flatland.envs.rail_env import RailEnv
u214892's avatar
u214892 committed
5
from flatland.envs.rail_generators import rail_from_file
6
from flatland.envs.line_generators import line_from_file
7

u214892's avatar
u214892 committed
8

u214892's avatar
u214892 committed
9
10
def load_flatland_environment_from_file(file_name: str,
                                        load_from_package: str = None,
                                        obs_builder_object: ObservationBuilder = None,
                                        record_steps: bool = False,
                                        ) -> RailEnv:
    """
    Create a `RailEnv` whose rail and lines are loaded from a saved environment file.

    Parameters
    ----------
    file_name : str
        The pickle file.
    load_from_package : str
        The python module to import from. Example: 'env_data.tests'
        This requires that there are `__init__.py` files in the folder structure we load the file from.
    obs_builder_object: ObservationBuilder
        The obs builder for the `RailEnv` that is created. If None, a
        `TreeObsForRailEnv(max_depth=2)` with a
        `ShortestPathPredictorForRailEnv(max_depth=10)` is used.
    record_steps : bool
        Forwarded unchanged to `RailEnv(record_steps=...)`.

    Returns
    -------
    RailEnv
        The environment loaded from the pickle file.
    """
    if obs_builder_object is None:
        # Default observation: tree observation of depth 2 backed by a
        # shortest-path predictor looking 10 steps ahead.
        obs_builder_object = TreeObsForRailEnv(
            max_depth=2,
            predictor=ShortestPathPredictorForRailEnv(max_depth=10))

    # width/height/number_of_agents are placeholder values; presumably the
    # file-based generators supply the real grid and agents — TODO confirm.
    environment = RailEnv(width=1, height=1,
                          rail_generator=rail_from_file(file_name, load_from_package),
                          # `line_from_file` belongs to the flatland 3 API, whose RailEnv
                          # takes `line_generator`; the old `schedule_generator` keyword
                          # is not accepted there and raises TypeError.
                          line_generator=line_from_file(file_name, load_from_package),
                          number_of_agents=1,
                          obs_builder_object=obs_builder_object,
                          record_steps=record_steps,
                          )
    return environment