#!/usr/bin/env python
from __future__ import print_function

import glob
import os
import random
import shutil
import time
import traceback
import json
import itertools
import re

import crowdai_api
import msgpack
import msgpack_numpy as m
import pickle
import numpy as np
import pandas as pd
import redis
import timeout_decorator

import flatland
from flatland.envs.step_utils.states import TrainState
from flatland.evaluators import aicrowd_helpers
from flatland.evaluators import messages
from flatland.utils.rendertools import RenderTool
from flatland.envs.persistence import RailEnvPersister

use_signals_in_timeout = True
if os.name == 'nt':
    """
    Windows doesn't support signals, so the
    timeout decorators usually fall apart there.
    Hence we force them not to use signals
    whenever the timeout decorator is used.
    """
    use_signals_in_timeout = False

m.patch()
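# A quick sketch of what the patch enables: numpy arrays survive the msgpack
# round trip (this file also passes m.encode/m.decode explicitly, which works
# with or without the global patch), e.g.:
#     buf = msgpack.packb({"obs": np.zeros(3)}, default=m.encode, use_bin_type=True)
#     msgpack.unpackb(buf, object_hook=m.decode)  # -> {'obs': array([0., 0., 0.])}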

########################################################
# CONSTANTS
########################################################

# Don't proceed to next Test if the previous one didn't reach this mean completion percentage
TEST_MIN_PERCENTAGE_COMPLETE_MEAN = float(os.getenv("TEST_MIN_PERCENTAGE_COMPLETE_MEAN", 0.25))

# After this number of consecutive timeouts, kill the submission:
# this probably means the submission has crashed
MAX_SUCCESSIVE_TIMEOUTS = int(os.getenv("FLATLAND_MAX_SUCCESSIVE_TIMEOUTS", 10))

debug_mode = (int(os.getenv("AICROWD_DEBUG_SUBMISSION", 0)) == 1)
if debug_mode:
    print("=" * 20)
    print("Submission in DEBUG MODE! Will get limited time")
    print("=" * 20)

# 2 hours by default (a debug timeout can be injected via the env variable if applicable)
OVERALL_TIMEOUT = int(os.getenv(
    "FLATLAND_OVERALL_TIMEOUT",
    2 * 60 * 60))

# 10 mins
INTIAL_PLANNING_TIMEOUT = int(os.getenv(
    "FLATLAND_INITIAL_PLANNING_TIMEOUT",
    10 * 60))

# 10 seconds
PER_STEP_TIMEOUT = int(os.getenv(
    "FLATLAND_PER_STEP_TIMEOUT",
    10))

# 5 min - applies to the rest of the commands
DEFAULT_COMMAND_TIMEOUT = int(os.getenv(
    "FLATLAND_DEFAULT_COMMAND_TIMEOUT",
    5 * 60))

RANDOM_SEED = int(os.getenv("FLATLAND_EVALUATION_RANDOM_SEED", 1001))

SUPPORTED_CLIENT_VERSIONS = [flatland.__version__]


class FlatlandRemoteEvaluationService:
    """
    A remote evaluation service which exposes the following interfaces
    of a RailEnv :
u214892's avatar
u214892 committed
91
92
    - env_create
    - env_step
93
    and an additional `env_submit` to cater to score computation and on-episode-complete post-processings.
94

95
    This service is designed to be used in conjunction with
96
    `FlatlandRemoteClient` and both the service and client maintain a
97
    local instance of the RailEnv instance, and in case of any unexpected
98
99
    divergences in the state of both the instances, the local RailEnv
    instance of the `FlatlandRemoteEvaluationService` is supposed to act
100
101
    as the single source of truth.

102
103
    Both the client and remote service communicate with each other
    via Redis as a message broker. The individual messages are packed and
104
105
106
    unpacked with `msgpack` (a patched version of msgpack which also supports
    numpy arrays).
    """

    def __init__(
        self,
        test_env_folder="/tmp",
        flatland_rl_service_id='FLATLAND_RL_SERVICE_ID',
        remote_host='127.0.0.1',
        remote_port=6379,
        remote_db=0,
        remote_password=None,
        visualize=False,
        video_generation_envs=None,
        report=None,
        verbose=False,
        action_dir=None,
        episode_dir=None,
        merge_dir=None,
        use_pickle=False,
        shuffle=False,
        missing_only=False,
        result_output_path=None,
        disable_timeouts=False
    ):

        # Episode recording properties
        self.action_dir = action_dir
        if action_dir and not os.path.exists(self.action_dir):
            os.makedirs(self.action_dir)
        self.episode_dir = episode_dir
        if episode_dir and not os.path.exists(self.episode_dir):
            os.makedirs(self.episode_dir)
        self.merge_dir = merge_dir
        if merge_dir and not os.path.exists(self.merge_dir):
            os.makedirs(self.merge_dir)
        self.use_pickle = use_pickle
        self.missing_only = missing_only
        self.episode_actions = []

        self.disable_timeouts = disable_timeouts
        if self.disable_timeouts:
            print("=" * 20)
            print("Timeout are DISABLED!")
            print("=" * 20)

        if shuffle:
            print("=" * 20)
            print("Env shuffling is ENABLED! not suitable for infinite wave")
            print("=" * 20)

        print("=" * 20)
        print("Max pre-planning time:", INTIAL_PLANNING_TIMEOUT)
        print("Max step time:", PER_STEP_TIMEOUT)
        print("Max overall time:", OVERALL_TIMEOUT)
        print("Max submission startup time:", DEFAULT_COMMAND_TIMEOUT)
        print("Max consecutive timeouts:", MAX_SUCCESSIVE_TIMEOUTS)
        print("=" * 20)

        # Test Env folder Paths
        self.test_env_folder = test_env_folder
        self.video_generation_envs = video_generation_envs if video_generation_envs is not None else []
        self.env_file_paths = self.get_env_filepaths()
        print(self.env_file_paths)
        # Shuffle all the env_file_paths for more exciting videos
        # and for more uniform time progression
        if shuffle:
            random.shuffle(self.env_file_paths)
        print(self.env_file_paths)

        self.instantiate_evaluation_metadata()

        # Logging and Reporting related vars
        self.verbose = verbose
        self.report = report

        # Use a state to swallow and ignore any steps after an env times out.
        self.state_env_timed_out = False

        # Count the number of successive timeouts (will kill after MAX_SUCCESSIVE_TIMEOUTS)
        # This prevents a crashed submission from running forever
        self.timeout_counter = 0

        # Results are the metrics: percent done, rewards, timing...
        self.result_output_path = result_output_path

        # Communication Protocol Related vars
        self.namespace = "flatland-rl"
        self.service_id = flatland_rl_service_id
        self.command_channel = "{}::{}::commands".format(
            self.namespace,
            self.service_id
        )
        self.error_channel = "{}::{}::errors".format(
            self.namespace,
            self.service_id
        )

        # Message Broker related vars
        self.remote_host = remote_host
        self.remote_port = remote_port
        self.remote_db = remote_db
        self.remote_password = remote_password
        self.instantiate_redis_connection_pool()

        # AIcrowd evaluation specific vars
        self.oracle_events = crowdai_api.events.CrowdAIEvents(with_oracle=True)
        self.evaluation_state = {
            "state": "PENDING",
            "progress": 0.0,
            "simulation_count": 0,
            "total_simulation_count": len(self.env_file_paths),
            "score": {
                "score": 0.0,
                "score_secondary": 0.0
            },
            "meta": {
                "normalized_reward": 0.0
            }
        }
        self.stats = {}
        self.previous_command = {
            "type": None
        }

        # RailEnv specific variables
        self.env = False
        self.env_renderer = False
        self.reward = 0
        self.simulation_done = True
        self.simulation_count = -1
        self.simulation_env_file_paths = []
        self.simulation_rewards = []
        self.simulation_rewards_normalized = []
        self.simulation_percentage_complete = []
        self.simulation_percentage_complete_per_test = {}
        self.simulation_steps = []
        self.simulation_times = []
        self.env_step_times = []
        self.nb_malfunctioning_trains = []
        self.nb_deadlocked_trains = []
        self.overall_start_time = 0
        self.termination_cause = "No reported termination cause."
        self.evaluation_done = False
        self.begin_simulation = False
        self.current_step = 0
        self.current_test = -1
        self.current_level = -1
        self.visualize = visualize
        self.vizualization_folder_name = "./.visualizations"
        self.record_frame_step = 0

        if self.visualize:
            if os.path.exists(self.vizualization_folder_name):
                print("[WARNING] Deleting already existing visualizations folder at : {}".format(
                    self.vizualization_folder_name
                ))
                shutil.rmtree(self.vizualization_folder_name)
            os.mkdir(self.vizualization_folder_name)

    def update_running_stats(self, key, scalar):
        """
266
        Computes the running min/mean/max for given param
267
268
269
        """
        mean_key = "{}_mean".format(key)
        counter_key = "{}_counter".format(key)
        min_key = "{}_min".format(key)
        max_key = "{}_max".format(key)

        try:
            # Update Mean
            self.stats[mean_key] = \
                ((self.stats[mean_key] * self.stats[counter_key]) + scalar) / (self.stats[counter_key] + 1)
            # Update min
            if scalar < self.stats[min_key]:
                self.stats[min_key] = scalar
            # Update max
            if scalar > self.stats[max_key]:
                self.stats[max_key] = scalar

            self.stats[counter_key] += 1
        except KeyError:
            self.stats[mean_key] = scalar
            self.stats[min_key] = scalar
            self.stats[max_key] = scalar
            self.stats[counter_key] = 1
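        # Usage example: after update_running_stats("env_step_time", 0.2), self.stats
        # contains "env_step_time_mean", "env_step_time_min", "env_step_time_max" and
        # "env_step_time_counter"; later calls fold new scalars into these running values.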

    def delete_key_in_running_stats(self, key):
        """
        This deletes a particular key in the running stats
        dictionary, if it exists
        """
        mean_key = "{}_mean".format(key)
        counter_key = "{}_counter".format(key)
        min_key = "{}_min".format(key)
        max_key = "{}_max".format(key)

        try:
            del self.stats[mean_key]
            del self.stats[counter_key]
            del self.stats[min_key]
            del self.stats[max_key]
        except KeyError:
            pass

    def get_env_filepaths(self):
        """
        Gathers a list of all available rail env files to be used
        for evaluation. The folder structure expected at the `test_env_folder`
        is similar to:

            .
            ├── Test_0
            │   ├── Level_1.pkl
            │   ├── .......
            │   ├── .......
            │   └── Level_99.pkl
            └── Test_1
                ├── Level_1.pkl
                ├── .......
                ├── .......
                └── Level_99.pkl
        """
        env_paths = glob.glob(
            os.path.join(
                self.test_env_folder,
                "*/*.pkl"
            )
        )

        # Remove the root folder name from the individual
        # paths, so that we only have the path relative
        # to the test root folder
        env_paths = [os.path.relpath(x, self.test_env_folder) for x in env_paths]

        # Sort in proper numerical order
        def get_file_order(filename):
            test_id, level_id = self.get_env_test_and_level(filename)
            value = test_id * 1000 + level_id
            return value

        env_paths.sort(key=get_file_order)
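        # e.g. "Test_1/Level_2.pkl" gets sort key 1 * 1000 + 2 = 1002, keeping levels
        # grouped by test (assuming fewer than 1000 levels per test)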

        # if requested, only generate actions for those envs which don't already have them
        if self.merge_dir and self.missing_only:
            existing_paths = (itertools.chain.from_iterable(
                [glob.glob(os.path.join(self.merge_dir, f"envs/*.{ext}"))
                 for ext in ["pkl", "mpk"]]))
            existing_paths = [os.path.relpath(sPath, self.merge_dir) for sPath in existing_paths]
            # keep the deterministic numerical order; a bare set difference would
            # lose the ordering and is not indexable downstream
            env_paths = sorted(set(env_paths) - set(existing_paths), key=get_file_order)

        return env_paths

    def get_env_test_and_level(self, filename):
        numbers = re.findall(r'\d+', os.path.relpath(filename))
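        # e.g. for "Test_5/Level_17.pkl" the regex yields ['5', '17'],
        # which is returned as (5, 17) below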

        if len(numbers) == 2:
            test_id = int(numbers[0])
            level_id = int(numbers[1])
        else:
            print(numbers)
            raise ValueError("Unexpected file path, expects 'Test_<N>/Level_<M>.pkl', found {}".format(filename))
        return test_id, level_id

    def instantiate_evaluation_metadata(self):
        """
            This instantiates a pandas dataframe to record
            information specific to each of the individual env evaluations.

            This loads the template CSV with pre-filled information from the
            provided metadata.csv file, and fills it up with
            evaluation runtime information.
        """
        self.evaluation_metadata_df = None
        metadata_file_path = os.path.join(
            self.test_env_folder,
            "metadata.csv"
        )
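        # metadata.csv is expected to provide at least the columns test_id and env_id
        # (with values like "Test_0" and "Level_1"), from which the indexed filename
        # "Test_0/Level_1.pkl" is derived below; any other columns are preserved as-is.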
        if os.path.exists(metadata_file_path):
            self.evaluation_metadata_df = pd.read_csv(metadata_file_path)
            self.evaluation_metadata_df["filename"] = \
                self.evaluation_metadata_df["test_id"] + \
                "/" + self.evaluation_metadata_df["env_id"] + ".pkl"
            self.evaluation_metadata_df = self.evaluation_metadata_df.set_index("filename")

            # Add custom columns for evaluation specific metrics
            self.evaluation_metadata_df["reward"] = np.nan
            self.evaluation_metadata_df["normalized_reward"] = np.nan
            self.evaluation_metadata_df["percentage_complete"] = np.nan
            self.evaluation_metadata_df["steps"] = np.nan
            self.evaluation_metadata_df["simulation_time"] = np.nan
            self.evaluation_metadata_df["nb_malfunctioning_trains"] = np.nan
            self.evaluation_metadata_df["nb_deadlocked_trains"] = np.nan

            # Add client specific columns
            # TODO: This needs refactoring
            self.evaluation_metadata_df["controller_inference_time_min"] = np.nan
            self.evaluation_metadata_df["controller_inference_time_mean"] = np.nan
            self.evaluation_metadata_df["controller_inference_time_max"] = np.nan
        else:
            raise Exception("metadata.csv not found in tests folder ({}). Please use an updated version of the test set.".format(metadata_file_path))

    def update_evaluation_metadata(self):
        """
        This function is called when we move from one simulation to another
        and it simply tries to update the simulation specific information
        for the **previous** episode in the metadata_df if it exists.
        """

        if self.evaluation_metadata_df is not None and len(self.simulation_env_file_paths) > 0:
            last_simulation_env_file_path = self.simulation_env_file_paths[-1]

            # Add controller_inference_time_metrics
            # These metrics may be missing if no step was done before the episode finished

            # generate the lists of names for the stats (input names and output names)
            sPrefixIn = "current_episode_controller_inference_time_"
            sPrefixOut = "controller_inference_time_"
            lsStatIn = [sPrefixIn + sStat for sStat in ["min", "mean", "max"]]
            lsStatOut = [sPrefixOut + sStat for sStat in ["min", "mean", "max"]]

            if lsStatIn[0] in self.stats:
                lrStats = [self.stats[sStat] for sStat in lsStatIn]
            else:
                lrStats = [0.0] * len(lsStatIn)

            lsFields = ("reward, normalized_reward, percentage_complete, " + \
                        "steps, simulation_time, nb_malfunctioning_trains, nb_deadlocked_trains").split(", ") + \
                       lsStatOut

            loValues = [self.simulation_rewards[-1],
                        self.simulation_rewards_normalized[-1],
                        self.simulation_percentage_complete[-1],
                        self.simulation_steps[-1],
                        self.simulation_times[-1],
                        self.nb_malfunctioning_trains[-1],
                        self.nb_deadlocked_trains[-1]
                        ] + lrStats

            # update the dataframe without the updating-a-copy warning
            df = self.evaluation_metadata_df
            df.loc[last_simulation_env_file_path, lsFields] = loValues

            # Delete this key from the stats to ensure that it
            # gets computed again from scratch in the next episode
            self.delete_key_in_running_stats(
                "current_episode_controller_inference_time")

            if self.verbose:
                print(self.evaluation_metadata_df)

    def instantiate_redis_connection_pool(self):
        """
        Instantiates a Redis connection pool which can be used to
        communicate with the message broker
        """
        if self.verbose or self.report:
            print("Attempting to connect to redis server at {}:{}/{}".format(
                self.remote_host,
                self.remote_port,
                self.remote_db))

        self.redis_pool = redis.ConnectionPool(
            host=self.remote_host,
            port=self.remote_port,
            db=self.remote_db,
            password=self.remote_password
        )
        self.redis_conn = redis.Redis(connection_pool=self.redis_pool)

    def get_redis_connection(self):
        """
        Returns a redis connection backed by the previously
        instantiated redis connection pool
        """
        return self.redis_conn

    def _error_template(self, payload):
        """
        Simple helper function to pass a payload as a part of a
        flatland comms error template.
        """
        _response = {}
        _response['type'] = messages.FLATLAND_RL.ERROR
        _response['payload'] = payload
        return _response

    def get_next_command(self):
        """
        A helper function to obtain the next command, which transparently
        also deals with things like unpacking the command from the
        packed message, and considers the timeouts, etc. when trying to
        fetch a new command.
        """

        COMMAND_TIMEOUT = DEFAULT_COMMAND_TIMEOUT
        """
        Handle case specific timeouts :
            - INTIAL_PLANNING_TIMEOUT
                The timeout between an env_create call and the first env_step call
            - PER_STEP_TIMEOUT
                The timeout between two consecutive env_step calls
        """
        if self.previous_command['type'] == messages.FLATLAND_RL.ENV_CREATE:
            """
            In case the previous command is an env_create, then leave 
            a but more time for the intial planning
            """
            COMMAND_TIMEOUT = INTIAL_PLANNING_TIMEOUT
        elif self.previous_command['type'] == messages.FLATLAND_RL.ENV_STEP:
            """
            Use the per_step_time for all timesteps between two env_step calls
            # Corner Case : 
                - Are there any reasons why a call between the last env_step call 
                and the subsequent env_create call will take an excessively large 
                amount of time (>5s in this case)
            """
            COMMAND_TIMEOUT = PER_STEP_TIMEOUT
        elif self.previous_command['type'] == messages.FLATLAND_RL.ENV_SUBMIT:
            """
            If the user has already done an env_submit call, then the timeout 
            can be an arbitrarily large number.
            """
            COMMAND_TIMEOUT = 10 ** 6

        if self.disable_timeouts:
            COMMAND_TIMEOUT = None

        @timeout_decorator.timeout(COMMAND_TIMEOUT, use_signals=use_signals_in_timeout)  # timeout for each command
        def _get_next_command(command_channel, _redis):
            """
            A low level wrapper for obtaining the next command from a
            pre-agreed command channel.
            At the moment, the communication protocol uses lpush for pushing
            in commands, and brpop for reading out commands.
            """
            command = _redis.brpop(command_channel)[1]
            return command

        _redis = self.get_redis_connection()
        command = _get_next_command(self.command_channel, _redis)
        if self.verbose or self.report:
            print("Command Service: ", command)

        if self.use_pickle:
            command = pickle.loads(command)
        else:
            command = msgpack.unpackb(
                command,
                object_hook=m.decode,
                strict_map_key=False  # msgpack >= 1.0 removed the old `encoding` kwarg
            )
        if self.verbose:
            print("Received Request : ", command)

        message_queue_latency = time.time() - command["timestamp"]
        self.update_running_stats("message_queue_latency", message_queue_latency)
        return command

    def send_response(self, _command_response, command, suppress_logs=False):
        _redis = self.get_redis_connection()
        command_response_channel = command['response_channel']

        if self.verbose and not suppress_logs:
            print("Responding with : ", _command_response)

        if self.use_pickle:
            sResponse = pickle.dumps(_command_response)
        else:
            sResponse = msgpack.packb(
                _command_response,
                default=m.encode,
                use_bin_type=True)
        _redis.rpush(command_response_channel, sResponse)

    def send_error(self, error_dict, suppress_logs=False):
        """ For out-of-band errors like timeouts,
            where we do not have a command, so we have no response channel!
        """
        _redis = self.get_redis_connection()
        print("Sending error : ", error_dict)

        if self.use_pickle:
            sResponse = pickle.dumps(error_dict)
        else:
            sResponse = msgpack.packb(
                error_dict,
                default=m.encode,
                use_bin_type=True)

        _redis.rpush(self.error_channel, sResponse)

    def handle_ping(self, command):
        """
        Handles PING command from the client.
        """
        service_version = flatland.__version__
        if "version" in command["payload"].keys():
            client_version = command["payload"]["version"]
        else:
            # 2.1.4 -> when the version mismatch check was added
            client_version = "2.1.4"

        _command_response = {}
        _command_response['type'] = messages.FLATLAND_RL.PONG
        _command_response['payload'] = {}
        if client_version not in SUPPORTED_CLIENT_VERSIONS:
            _command_response['type'] = messages.FLATLAND_RL.ERROR
            _command_response['payload']['message'] = \
                "Client-Server Version Mismatch => " + \
                "[ Client Version : {} ] ".format(client_version) + \
                "[ Server Version : {} ] ".format(service_version)
            self.send_response(_command_response, command)
            raise Exception(_command_response['payload']['message'])

        self.send_response(_command_response, command)

    def handle_env_create(self, command):
        """
        Handles an ENV_CREATE command from the client
        """

        print(" -- [DEBUG] [env_create] EVAL DONE: ",self.evaluation_done)

        # Check if the previous episode was finished
        if not self.simulation_done and not self.evaluation_done:
            _command_response = self._error_template("CAN'T CREATE NEW ENV BEFORE PREVIOUS IS DONE")
            self.send_response(_command_response, command)
            raise Exception(_command_response['payload'])

        self.simulation_count += 1
        self.simulation_done = False

        if self.simulation_count == 0:
            # Very first episode: start the overall timer
            self.overall_start_time = time.time()

        # reset the timeout flag / state.
        self.state_env_timed_out = False

        # Check if we have finished all the available envs
        print(" -- [DEBUG] [env_create] SIM COUNT: ", self.simulation_count + 1, len(self.env_file_paths))
        
679
680
681
        if self.simulation_count >= len(self.env_file_paths):
            self.evaluation_done = True
            # Hack - just ensure these are set
            test_env_file_path = self.env_file_paths[self.simulation_count - 1]
            env_test, env_level = self.get_env_test_and_level(test_env_file_path)
        else:
            test_env_file_path = self.env_file_paths[self.simulation_count]
            env_test, env_level = self.get_env_test_and_level(test_env_file_path)

        # Did we just finish a Test, and if so, did it reach a high enough mean completion percentage?
        if self.current_test != env_test and env_test != 0:
            if self.current_test not in self.simulation_percentage_complete_per_test:
                print("No environment was finished at all during test {}!".format(self.current_test))
                mean_test_complete_percentage = 0.0
            else:
                mean_test_complete_percentage = np.mean(self.simulation_percentage_complete_per_test[self.current_test])

            if mean_test_complete_percentage < TEST_MIN_PERCENTAGE_COMPLETE_MEAN:
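                # Example: with TEST_MIN_PERCENTAGE_COMPLETE_MEAN = 0.25, a Test whose
                # episodes ended with completion rates [0.1, 0.2, 0.3] has a mean of
                # 0.2 and therefore triggers this early stop.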
                print("=" * 15)
                msg = "The mean percentage of done agents during the last Test ({} environments) was too low: {:.3f} < {}".format(
                    len(self.simulation_percentage_complete_per_test[self.current_test]),
                    mean_test_complete_percentage,
                    TEST_MIN_PERCENTAGE_COMPLETE_MEAN
                )
                print(msg, "Evaluation will stop.")
                self.termination_cause = msg
                self.evaluation_done = True

        if self.simulation_count < len(self.env_file_paths) and not self.evaluation_done:
            """
            There are still test envs left that are yet to be evaluated 
            """

            print("=" * 15)
            print("Evaluating {} ({}/{})".format(test_env_file_path, self.simulation_count+1, len(self.env_file_paths)))

            test_env_file_path = os.path.join(
                self.test_env_folder,
                test_env_file_path
            )

            self.current_test = env_test
            self.current_level = env_level

            del self.env

            self.env, _env_dict = RailEnvPersister.load_new(test_env_file_path)
            # TODO: should the distance map be (re)computed here?

            self.begin_simulation = time.time()

            # Update evaluation metadata for the previous episode
            self.update_evaluation_metadata()

            # Start adding placeholders for the new episode
            self.simulation_env_file_paths.append(
                os.path.relpath(
                    test_env_file_path,
                    self.test_env_folder
                ))  # relative path

            self.simulation_rewards.append(0)
            self.simulation_rewards_normalized.append(0)
            self.simulation_percentage_complete.append(0)
            self.simulation_times.append(0)
            self.simulation_steps.append(0)
            self.nb_malfunctioning_trains.append(0)

            self.current_step = 0

            _observation, _info = self.env.reset(
                regenerate_rail=True,
                regenerate_schedule=True,
                random_seed=RANDOM_SEED
            )

            if self.visualize:
                current_env_path = self.env_file_paths[self.simulation_count]
                if current_env_path in self.video_generation_envs:
                    self.env_renderer = RenderTool(self.env, gl="PILSVG", )
                elif self.env_renderer:
                    self.env_renderer = False

            _command_response = {}
            _command_response['type'] = messages.FLATLAND_RL.ENV_CREATE_RESPONSE
            _command_response['payload'] = {}
            _command_response['payload']['observation'] = _observation
            _command_response['payload']['env_file_path'] = self.env_file_paths[self.simulation_count]
            _command_response['payload']['info'] = _info
            _command_response['payload']['random_seed'] = RANDOM_SEED
        else:
            print(" -- [DEBUG] [env_create] return obs = False (END)")
            """
            All test env evaluations are complete
            """
            _command_response = {}
            _command_response['type'] = messages.FLATLAND_RL.ENV_CREATE_RESPONSE
            _command_response['payload'] = {}
            _command_response['payload']['observation'] = False
            _command_response['payload']['env_file_path'] = False
            _command_response['payload']['info'] = False
            _command_response['payload']['random_seed'] = False

        self.send_response(_command_response, command)
        #####################################################################
        # Update evaluation state
        #####################################################################
        elapsed = time.time() - self.overall_start_time
        progress = np.clip(
            elapsed / OVERALL_TIMEOUT,
            0, 1)

        mean_reward, mean_normalized_reward, sum_normalized_reward, mean_percentage_complete = self.compute_mean_scores()

        self.evaluation_state["state"] = "IN_PROGRESS"
        self.evaluation_state["progress"] = progress
        self.evaluation_state["simulation_count"] = self.simulation_count
        self.evaluation_state["score"]["score"] = sum_normalized_reward
        self.evaluation_state["score"]["score_secondary"] = mean_percentage_complete
        self.evaluation_state["meta"]["normalized_reward"] = mean_normalized_reward
        self.evaluation_state["meta"]["termination_cause"] = self.termination_cause
        self.handle_aicrowd_info_event(self.evaluation_state)

        self.episode_actions = []

    def handle_env_step(self, command):
        """
        Handles an ENV_STEP command from the client
        TODO: Add a high level summary of everything that's happening here.
        """

        if self.state_env_timed_out or self.evaluation_done:
            print("Ignoring step command after timeout.")
            return

        _payload = command['payload']

        if not self.env:
            raise Exception("env_client.step called before env_client.env_create() call")
        if self.env.dones['__all__']:
            raise Exception(
                "Client attempted to perform an action on an Env which "
                "has done['__all__']==True")

        overall_elapsed = (time.time() - self.overall_start_time)
        if overall_elapsed > OVERALL_TIMEOUT:
            msg = "Reached overall time limit: took {:.2f}s, limit is {:.2f}s.".format(
                overall_elapsed, OVERALL_TIMEOUT
            )
            self.termination_cause = msg
            self.evaluation_done = True

            print("=" * 15)
            print(msg, "Evaluation will stop.")
            return
        # else:
        #     print("="*15)
        #     print("{}s left!".format(OVERALL_TIMEOUT - overall_elapsed))

        action = _payload['action']
        inference_time = _payload['inference_time']
        # We record this metric in two keys:
        #   - One for the current episode
        #   - One global
        self.update_running_stats("current_episode_controller_inference_time", inference_time)
        self.update_running_stats("controller_inference_time", inference_time)

        # Perform the step
        time_start = time.time()
        _observation, all_rewards, done, info = self.env.step(action)
        time_diff = time.time() - time_start
        self.update_running_stats("internal_env_step_time", time_diff)

        self.current_step += 1

        cumulative_reward = sum(all_rewards.values())
        self.simulation_rewards[-1] += cumulative_reward
        self.simulation_steps[-1] += 1
        """
        The normalized rewards normalize the reward for an 
        episode by dividing the whole reward by max-time-steps 
        allowed in that episode, and the number of agents present in 
        that episode
        """
        self.simulation_rewards_normalized[-1] += \
            (cumulative_reward / (
                self.env._max_episode_steps *
                self.env.get_num_agents()
            ))
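        # Worked example (illustrative numbers): a cumulative step reward of -8
        # with _max_episode_steps = 100 and 4 agents contributes
        # -8 / (100 * 4) = -0.02 to the normalized episode reward.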

        # We count the number of agents that malfunctioned by checking how many have one more step left before recovery
        num_malfunctioning = sum(agent.malfunction_data['malfunction'] == 1 for agent in self.env.agents)

        if self.verbose and num_malfunctioning > 0:
            print("Step {}: {} agents have malfunctioned and will recover next step".format(self.current_step, num_malfunctioning))

        self.nb_malfunctioning_trains[-1] += num_malfunctioning

        # record the actions before checking for done
        if self.action_dir is not None:
            self.episode_actions.append(action)

        # Is the episode over?
        if done["__all__"]:
            self.simulation_done = True

            if self.begin_simulation:
                # If begin_simulation has been initialized at least once,
                # this records the simulation time of the episode that just ended
                self.simulation_times[-1] = time.time() - self.begin_simulation

            # Compute percentage complete
            complete = 0
            for i_agent in range(self.env.get_num_agents()):
                agent = self.env.agents[i_agent]
                if agent.state == TrainState.DONE:
                    complete += 1
            percentage_complete = complete * 1.0 / self.env.get_num_agents()
            self.simulation_percentage_complete[-1] = percentage_complete
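            # e.g. 3 out of 4 agents in TrainState.DONE gives percentage_complete == 0.75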

            # Shift the normalized reward by +1.0 so that it lies in [0, 1] and episode scores can be summed up
            self.simulation_rewards_normalized[-1] += 1.0

            if self.current_test not in self.simulation_percentage_complete_per_test:
                self.simulation_percentage_complete_per_test[self.current_test] = []
            self.simulation_percentage_complete_per_test[self.current_test].append(percentage_complete)
            print("Percentage for test {}, level {}: {}".format(self.current_test, self.current_level, percentage_complete))

            if len(self.env.cur_episode) > 0:
                g3Ep = np.array(self.env.cur_episode)
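                # index 5 of each per-agent record in the episode log is assumed
                # here to hold the deadlock flag; summing it over the last recorded
                # step counts the deadlocked trains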
                self.nb_deadlocked_trains.append(np.sum(g3Ep[-1, :, 5]))
            else:
                self.nb_deadlocked_trains.append(np.nan)

            print(
                "Evaluation finished in {} timesteps, {:.3f} seconds. Percentage agents done: {:.3f}. Normalized reward: {:.3f}. Number of malfunctions: {}. Number of deadlocked trains: {}.".format(
                    self.simulation_steps[-1],
                    self.simulation_times[-1],
                    self.simulation_percentage_complete[-1],
                    self.simulation_rewards_normalized[-1],
                    self.nb_malfunctioning_trains[-1],
                    self.nb_deadlocked_trains[-1]
                ))

            print("Total normalized reward so far: {:.3f}".format(sum(self.simulation_rewards_normalized)))

            # Write intermediate results
            if self.result_output_path:
                self.evaluation_metadata_df.to_csv(self.result_output_path)
                print("Wrote intermediate output results to : {}".format(self.result_output_path))

            if self.action_dir is not None:
                self.save_actions()

            if self.episode_dir is not None:
                self.save_episode()

            if self.merge_dir is not None:
                self.save_merged_env()

        # Record Frame
        if self.visualize:
            """
            Only generate and save the frames for environments which are
            explicitly listed in the video_generation_envs param
            """
            current_env_path = self.env_file_paths[self.simulation_count]
            if current_env_path in self.video_generation_envs:
                self.env_renderer.render_env(
                    show=False,
                    show_observations=False,
                    show_predictions=False,
                    show_rowcols=False
                )

                self.env_renderer.gl.save_image(
                    os.path.join(
                        self.vizualization_folder_name,
                        "flatland_frame_{:04d}.png".format(self.record_frame_step)
                    ))
                self.record_frame_step += 1

    def save_actions(self):
        sfEnv = self.env_file_paths[self.simulation_count]

        sfActions = self.action_dir + "/" + sfEnv.replace(".pkl", ".json")

        print("env path: ", sfEnv, " sfActions:", sfActions)

        if not os.path.exists(os.path.dirname(sfActions)):
            os.makedirs(os.path.dirname(sfActions))

        with open(sfActions, "w") as fOut:
            json.dump(self.episode_actions, fOut)

        self.episode_actions = []

    def save_episode(self):
        sfEnv = self.env_file_paths[self.simulation_count]
        sfEpisode = self.episode_dir + "/" + sfEnv
        print("env path: ", sfEnv, " sfEpisode:", sfEpisode)
        RailEnvPersister.save_episode(self.env, sfEpisode)
        # self.env.save_episode(sfEpisode)