cli.py 3.23 KB
Newer Older
spmohanty's avatar
spmohanty committed
1
2
3
4
# -*- coding: utf-8 -*-

"""Console script for flatland."""
import sys
u214892's avatar
u214892 committed
5
6
import time

spmohanty's avatar
spmohanty committed
7
import click
8
import numpy as np
u214892's avatar
u214892 committed
9
10
import redis

11
from flatland.envs.rail_env import RailEnv
nimishsantosh107's avatar
nimishsantosh107 committed
12
13
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.line_generators import sparse_line_generator
14
from flatland.evaluators.service import FlatlandRemoteEvaluationService
u214892's avatar
u214892 committed
15
from flatland.utils.rendertools import RenderTool
spmohanty's avatar
spmohanty committed
16
17
18


@click.command()
def demo(args=None):
    """Demo script to check installation

    \f
    Everything after the form-feed above is hidden from ``--help`` by click.

    Build a 30x30 ``RailEnv`` with 5 agents on a sparse rail network
    (3 cities, seeded with 0), then run one episode in which every agent
    takes a random action each step, rendering every frame on screen.

    :param args: unused; kept so the callback signature stays unchanged.
    :return: 0 once the episode reports ``done['__all__']``.
    """
    env = RailEnv(
        width=30,
        height=30,
        rail_generator=sparse_rail_generator(
            max_num_cities=3,
            grid_mode=False,
            max_rails_between_cities=4,
            max_rail_pairs_in_city=2,
            seed=0
        ),
        line_generator=sparse_line_generator(),
        number_of_agents=5)

    # Bound the episode length at 15 steps per unit of (width + height).
    # NOTE(review): this is set before reset(); confirm reset() keeps it.
    env._max_episode_steps = int(15 * (env.width + env.height))
    env_renderer = RenderTool(env)

    # Observations and info are not needed for a random-action demo.
    env.reset()

    # Run a single episode until the environment reports all agents done.
    all_done = False
    while not all_done:
        # np.random.randint excludes the upper bound, so actions are 0..4.
        actions = {handle: np.random.randint(0, 5)
                   for handle, _ in enumerate(env.agents)}
        _, _, dones, _ = env.step(actions)
        all_done = dones['__all__']
        env_renderer.render_env(
            show=True,
            frames=False,
            show_observations=False,
            show_predictions=False
        )
        # Slow the loop down so the rendering can be followed by eye.
        time.sleep(0.1)

    return 0


60
@click.command()
@click.option('--tests',
              type=click.Path(exists=True),
              help="Path to folder containing Flatland tests",
              required=True
              )
@click.option('--service_id',
              default="FLATLAND_RL_SERVICE_ID",
              help="Evaluation Service ID. This has to match the service id on the client.",
              required=False
              )
# NOTE: --shuffle and --disable_timeouts are boolean-valued options, not
# flags: they take an explicit value such as `--shuffle true`.
@click.option('--shuffle',
              type=bool,
              default=False,
              help="Shuffle the environments before starting evaluation.",
              required=False
              )
@click.option('--disable_timeouts',
              type=bool,
              default=False,
              help="Disable all evaluation timeouts.",
              required=False
              )
@click.option('--results_path',
              type=click.Path(exists=False),
              default=None,
              help="Path where the evaluator should write the results metadata.",
              required=False
              )
def evaluator(tests, service_id, shuffle, disable_timeouts, results_path):
    """Run a local Flatland evaluation service over the tests in --tests.

    \f
    First checks that a redis server answers on localhost, then starts a
    ``FlatlandRemoteEvaluationService`` (no visualisation, not verbose)
    and blocks in ``grader.run()``.

    :raises Exception: if no redis server is reachable on localhost; the
        original ``redis.exceptions.ConnectionError`` is chained as cause.
    """
    # Fail fast with an actionable message if redis is not running.
    # redis.Redis() with no arguments targets the default localhost server.
    try:
        redis_connection = redis.Redis()
        redis_connection.ping()
    except redis.exceptions.ConnectionError as e:
        raise Exception(
            "\nRedis server does not seem to be running on your localhost.\n"
            "Please ensure that you have a redis server running on your localhost"
        ) from e

    grader = FlatlandRemoteEvaluationService(
        test_env_folder=tests,
        flatland_rl_service_id=service_id,
        visualize=False,
        result_output_path=results_path,
        verbose=False,
        shuffle=shuffle,
        disable_timeouts=disable_timeouts
    )
    grader.run()


spmohanty's avatar
spmohanty committed
110
if __name__ == "__main__":
    # demo is a click command: in click's default standalone mode it exits
    # the process itself, so sys.exit here is only a fallback.
    sys.exit(demo())  # pragma: no cover