service.py 56.1 KB
Newer Older
1
2
#!/usr/bin/env python
from __future__ import print_function
u214892's avatar
u214892 committed
3

4
import glob
u214892's avatar
u214892 committed
5
6
import os
import random
7
import shutil
8
import time
9
import traceback
10
11
import json
import itertools
12
import re
u214892's avatar
u214892 committed
13

14
import crowdai_api
u214892's avatar
u214892 committed
15
16
import msgpack
import msgpack_numpy as m
17
import pickle
u214892's avatar
u214892 committed
18
import numpy as np
19
import pandas as pd
u214892's avatar
u214892 committed
20
import redis
21
22
import timeout_decorator

Erik Nygren's avatar
Erik Nygren committed
23
import flatland
u214892's avatar
u214892 committed
24
from flatland.core.env_observation_builder import DummyObservationBuilder
25
from flatland.envs.agent_utils import RailAgentStatus
26
from flatland.envs.malfunction_generators import malfunction_from_file
Erik Nygren's avatar
Erik Nygren committed
27
from flatland.envs.rail_env import RailEnv
u214892's avatar
u214892 committed
28
from flatland.envs.rail_generators import rail_from_file
29
from flatland.envs.line_generators import line_from_file
u214892's avatar
u214892 committed
30
31
32
from flatland.evaluators import aicrowd_helpers
from flatland.evaluators import messages
from flatland.utils.rendertools import RenderTool
33
34
from flatland.envs.rail_env_utils import load_flatland_environment_from_file
from flatland.envs.persistence import RailEnvPersister
35
36
37
38
39
40
41
42
43
44
45

# timeout_decorator relies on POSIX signals by default, which Windows does
# not support — so force the (thread-based) non-signal mode there.
use_signals_in_timeout = os.name != 'nt'

# Patch msgpack so numpy arrays are transparently (de)serialized.
m.patch()
47

48
49
50
########################################################
# CONSTANTS
########################################################

# Don't proceed to next Test if the previous one didn't reach this mean completion percentage
TEST_MIN_PERCENTAGE_COMPLETE_MEAN = float(os.getenv("TEST_MIN_PERCENTAGE_COMPLETE_MEAN", 0.25))

# After this number of consecutive timeouts, kill the submission:
# this probably means the submission has crashed
MAX_SUCCESSIVE_TIMEOUTS = int(os.getenv("FLATLAND_MAX_SUCCESSIVE_TIMEOUTS", 10))

# BUGFIX: os.getenv returns a *string* when the variable is set, so the old
# comparison `os.getenv(...) == 1` could never be true and debug mode was
# unreachable. Cast to int before comparing.
debug_mode = (int(os.getenv("AICROWD_DEBUG_SUBMISSION", 0)) == 1)
if debug_mode:
    print("=" * 20)
    print("Submission in DEBUG MODE! will get limited time")
    print("=" * 20)

# 8 hours (will get debug timeout from env variable if applicable)
OVERALL_TIMEOUT = int(os.getenv(
    "FLATLAND_OVERALL_TIMEOUT",
    8 * 60 * 60))

# 10 mins
INTIAL_PLANNING_TIMEOUT = int(os.getenv(
    "FLATLAND_INITIAL_PLANNING_TIMEOUT",
    10 * 60))

# 10 seconds
PER_STEP_TIMEOUT = int(os.getenv(
    "FLATLAND_PER_STEP_TIMEOUT",
    10))

# 5 min - applies to the rest of the commands
DEFAULT_COMMAND_TIMEOUT = int(os.getenv(
    "FLATLAND_DEFAULT_COMMAND_TIMEOUT",
    5 * 60))

# Seed passed to env.reset() so every submission sees the same malfunctions.
RANDOM_SEED = int(os.getenv("FLATLAND_EVALUATION_RANDOM_SEED", 1001))

# Client library versions accepted during the PING/PONG handshake.
SUPPORTED_CLIENT_VERSIONS = \
    [
        flatland.__version__
    ]
91
92
93


class FlatlandRemoteEvaluationService:
94
95
96
    """
    A remote evaluation service which exposes the following interfaces
    of a RailEnv :
u214892's avatar
u214892 committed
97
98
    - env_create
    - env_step
99
    and an additional `env_submit` to cater to score computation and on-episode-complete post-processings.
100

101
    This service is designed to be used in conjunction with
102
    `FlatlandRemoteClient` and both the service and client maintain a
103
    local instance of the RailEnv instance, and in case of any unexpected
104
105
    divergences in the state of both the instances, the local RailEnv
    instance of the `FlatlandRemoteEvaluationService` is supposed to act
106
107
    as the single source of truth.

108
109
    Both the client and remote service communicate with each other
    via Redis as a message broker. The individual messages are packed and
110
111
112
    unpacked with `msgpack` (a patched version of msgpack which also supports
    numpy arrays).
    """
u214892's avatar
u214892 committed
113

114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
    def __init__(
        self,
        test_env_folder="/tmp",
        flatland_rl_service_id='FLATLAND_RL_SERVICE_ID',
        remote_host='127.0.0.1',
        remote_port=6379,
        remote_db=0,
        remote_password=None,
        visualize=False,
        video_generation_envs=None,
        report=None,
        verbose=False,
        action_dir=None,
        episode_dir=None,
        merge_dir=None,
        use_pickle=False,
        shuffle=False,
        missing_only=False,
        result_output_path=None,
        disable_timeouts=False
    ):
        """
        Set up the evaluation service: recording folders, the list of test
        env files, the evaluation-metadata dataframe, the redis connection
        pool and all per-episode bookkeeping state.

        :param test_env_folder: root folder containing Test_<N>/Level_<M>.pkl
            files plus metadata.csv
        :param flatland_rl_service_id: service id used to namespace the redis
            command/error channels
        :param remote_host/remote_port/remote_db/remote_password: redis
            connection parameters
        :param visualize: render episodes listed in video_generation_envs
        :param video_generation_envs: env file paths (relative) to render;
            BUGFIX: default was a mutable `[]` shared across instances —
            replaced with a None sentinel (backward compatible)
        :param report/verbose: logging switches
        :param action_dir/episode_dir/merge_dir: optional recording folders
        :param use_pickle: serialize messages with pickle instead of msgpack
        :param shuffle: shuffle the env evaluation order
        :param missing_only: with merge_dir, only evaluate envs without
            existing recordings
        :param result_output_path: where to write the results CSV
        :param disable_timeouts: disable all command timeouts (debugging)
        """
        # Episode recording properties
        self.action_dir = action_dir
        if action_dir and not os.path.exists(self.action_dir):
            os.makedirs(self.action_dir)
        self.episode_dir = episode_dir
        if episode_dir and not os.path.exists(self.episode_dir):
            os.makedirs(self.episode_dir)
        self.merge_dir = merge_dir
        if merge_dir and not os.path.exists(self.merge_dir):
            os.makedirs(self.merge_dir)
        self.use_pickle = use_pickle
        self.missing_only = missing_only
        self.episode_actions = []

        self.disable_timeouts = disable_timeouts
        if self.disable_timeouts:
            print("=" * 20)
            print("Timeout are DISABLED!")
            print("=" * 20)

        if shuffle:
            print("=" * 20)
            print("Env shuffling is ENABLED! not suitable for infinite wave")
            print("=" * 20)

        print("=" * 20)
        print("Max pre-planning time:", INTIAL_PLANNING_TIMEOUT)
        print("Max step time:", PER_STEP_TIMEOUT)
        print("Max overall time:", OVERALL_TIMEOUT)
        print("Max submission startup time:", DEFAULT_COMMAND_TIMEOUT)
        print("Max consecutive timeouts:", MAX_SUCCESSIVE_TIMEOUTS)
        print("=" * 20)

        # Test Env folder Paths
        self.test_env_folder = test_env_folder
        # None sentinel -> fresh list per instance (avoids the shared
        # mutable-default pitfall); an explicitly passed list is kept as-is.
        self.video_generation_envs = video_generation_envs if video_generation_envs is not None else []
        self.env_file_paths = self.get_env_filepaths()
        print(self.env_file_paths)
        # Shuffle all the env_file_paths for more exciting videos
        # and for more uniform time progression
        if shuffle:
            random.shuffle(self.env_file_paths)
        print(self.env_file_paths)

        self.instantiate_evaluation_metadata()

        # Logging and Reporting related vars
        self.verbose = verbose
        self.report = report

        # Use a state to swallow and ignore any steps after an env times out.
        self.state_env_timed_out = False

        # Count the number of successive timeouts (will kill after MAX_SUCCESSIVE_TIMEOUTS)
        # This prevents a crashed submission to keep running forever
        self.timeout_counter = 0

        # Results are the metrics: percent done, rewards, timing...
        self.result_output_path = result_output_path

        # Communication Protocol Related vars
        self.namespace = "flatland-rl"
        self.service_id = flatland_rl_service_id
        self.command_channel = "{}::{}::commands".format(
            self.namespace,
            self.service_id
        )
        self.error_channel = "{}::{}::errors".format(
            self.namespace,
            self.service_id
        )

        # Message Broker related vars
        self.remote_host = remote_host
        self.remote_port = remote_port
        self.remote_db = remote_db
        self.remote_password = remote_password
        self.instantiate_redis_connection_pool()

        # AIcrowd evaluation specific vars
        self.oracle_events = crowdai_api.events.CrowdAIEvents(with_oracle=True)
        self.evaluation_state = {
            "state": "PENDING",
            "progress": 0.0,
            "simulation_count": 0,
            "total_simulation_count": len(self.env_file_paths),
            "score": {
                "score": 0.0,
                "score_secondary": 0.0
            },
            "meta": {
                "normalized_reward": 0.0
            }
        }
        self.stats = {}
        self.previous_command = {
            "type": None
        }

        # RailEnv specific variables
        self.env = False
        self.env_renderer = False
        self.reward = 0
        self.simulation_done = True
        self.simulation_count = -1
        self.simulation_env_file_paths = []
        self.simulation_rewards = []
        self.simulation_rewards_normalized = []
        self.simulation_percentage_complete = []
        self.simulation_percentage_complete_per_test = {}
        self.simulation_steps = []
        self.simulation_times = []
        self.env_step_times = []
        self.nb_malfunctioning_trains = []
        self.nb_deadlocked_trains = []
        self.overall_start_time = 0
        self.termination_cause = "No reported termination cause."
        self.evaluation_done = False
        self.begin_simulation = False
        self.current_step = 0
        self.current_test = -1
        self.current_level = -1
        self.visualize = visualize
        self.vizualization_folder_name = "./.visualizations"
        self.record_frame_step = 0

        if self.visualize:
            if os.path.exists(self.vizualization_folder_name):
                print("[WARNING] Deleting already existing visualizations folder at : {}".format(
                    self.vizualization_folder_name
                ))
                shutil.rmtree(self.vizualization_folder_name)
            os.mkdir(self.vizualization_folder_name)
269

270
    def update_running_stats(self, key, scalar):
271
        """
272
        Computes the running min/mean/max for given param
273
274
275
        """
        mean_key = "{}_mean".format(key)
        counter_key = "{}_counter".format(key)
276
277
        min_key = "{}_min".format(key)
        max_key = "{}_max".format(key)
278
279

        try:
280
            # Update Mean
281
282
            self.stats[mean_key] = \
                ((self.stats[mean_key] * self.stats[counter_key]) + scalar) / (self.stats[counter_key] + 1)
283
284
285
286
287
288
289
            # Update min
            if scalar < self.stats[min_key]:
                self.stats[min_key] = scalar
            # Update max
            if scalar > self.stats[max_key]:
                self.stats[max_key] = scalar

290
291
            self.stats[counter_key] += 1
        except KeyError:
292
293
294
295
296
            self.stats[mean_key] = scalar
            self.stats[min_key] = scalar
            self.stats[max_key] = scalar
            self.stats[counter_key] = 1

297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
    def delete_key_in_running_stats(self, key):
        """
        This deletes a particular key in the running stats
        dictionary, if it exists
        """
        mean_key = "{}_mean".format(key)
        counter_key = "{}_counter".format(key)
        min_key = "{}_min".format(key)
        max_key = "{}_max".format(key)

        try:
            del mean_key
            del counter_key
            del min_key
            del max_key
        except KeyError:
            pass

315
    def get_env_filepaths(self):
316
317
318
319
320
        """
        Gathers a list of all available rail env files to be used
        for evaluation. The folder structure expected at the `test_env_folder`
        is similar to :

u214892's avatar
u214892 committed
321
322
323
324
325
326
327
328
329
330
331
            .
            ├── Test_0
            │   ├── Level_1.pkl
            │   ├── .......
            │   ├── .......
            │   └── Level_99.pkl
            └── Test_1
                ├── Level_1.pkl
                ├── .......
                ├── .......
                └── Level_99.pkl
u214892's avatar
u214892 committed
332
        """
333
334
335
336
        env_paths = glob.glob(
            os.path.join(
                self.test_env_folder,
                "*/*.pkl"
337
            )
338
339
        )

340
341
        # Remove the root folder name from the individual
        # lists, so that we only have the path relative
342
        # to the test root folder
343
344
345
346
347
348
        env_paths = [os.path.relpath(x, self.test_env_folder) for x in env_paths]

        # Sort in proper numerical order
        def get_file_order(filename):
            test_id, level_id = self.get_env_test_and_level(filename)
            value = test_id * 1000 + level_id
349
350
351
352
            return value

        env_paths.sort(key=get_file_order)

353
        # if requested, only generate actions for those envs which don't already have them
354
        if self.merge_dir and self.missing_only:
355
            existing_paths = (itertools.chain.from_iterable(
356
                [glob.glob(os.path.join(self.merge_dir, f"envs/*.{ext}"))
357
                 for ext in ["pkl", "mpk"]]))
358
            existing_paths = [os.path.relpath(sPath, self.merge_dir) for sPath in existing_paths]
359
            env_paths = set(env_paths) - set(existing_paths)
360

361
        return env_paths
MasterScrat's avatar
MasterScrat committed
362

363
364
365
366
367
368
369
370
371
372
373
    def get_env_test_and_level(self, filename):
        """
        Parse a path shaped like 'Test_<N>/Level_<M>.pkl' and return the
        pair (test_id, level_id) as ints. Raises ValueError when the path
        does not contain exactly two integers.
        """
        found = re.findall(r'\d+', os.path.relpath(filename))
        if len(found) != 2:
            print(found)
            raise ValueError("Unexpected file path, expects 'Test_<N>/Level_<M>.pkl', found", filename)
        return int(found[0]), int(found[1])

374
375
376
377
378
379
    def instantiate_evaluation_metadata(self):
        """
            This instantiates a pandas dataframe to record
            information specific to each of the individual env evaluations.

            This loads the template CSV with pre-filled information from the
380
            provided metadata.csv file, and fills it up with
381
382
383
384
            evaluation runtime information.
        """
        self.evaluation_metadata_df = None
        metadata_file_path = os.path.join(
MasterScrat's avatar
MasterScrat committed
385
386
387
            self.test_env_folder,
            "metadata.csv"
        )
388
389
        if os.path.exists(metadata_file_path):
            self.evaluation_metadata_df = pd.read_csv(metadata_file_path)
390
391
392
            self.evaluation_metadata_df["filename"] = \
                self.evaluation_metadata_df["test_id"] + \
                "/" + self.evaluation_metadata_df["env_id"] + ".pkl"
393
394
            self.evaluation_metadata_df = self.evaluation_metadata_df.set_index("filename")

395
            # Add custom columns for evaluation specific metrics
396
397
398
399
400
            self.evaluation_metadata_df["reward"] = np.nan
            self.evaluation_metadata_df["normalized_reward"] = np.nan
            self.evaluation_metadata_df["percentage_complete"] = np.nan
            self.evaluation_metadata_df["steps"] = np.nan
            self.evaluation_metadata_df["simulation_time"] = np.nan
401
            self.evaluation_metadata_df["nb_malfunctioning_trains"] = np.nan
402
403
            self.evaluation_metadata_df["nb_deadlocked_trains"] = np.nan

404
405
406
407
408
            # Add client specific columns
            # TODO: This needs refactoring
            self.evaluation_metadata_df["controller_inference_time_min"] = np.nan
            self.evaluation_metadata_df["controller_inference_time_mean"] = np.nan
            self.evaluation_metadata_df["controller_inference_time_max"] = np.nan
409
        else:
MasterScrat's avatar
MasterScrat committed
410
            raise Exception("metadata.csv not found in tests folder ({}). Please use an updated version of the test set.".format(metadata_file_path))
411
412
413
414

    def update_evaluation_metadata(self):
        """
        This function is called when we move from one simulation to another
        and it simply tries to update the simulation specific information
        for the **previous** episode in the metadata_df if it exists.

        Reads the last entry of each of the parallel per-episode lists
        (rewards, steps, times, ...) and writes them into the row of the
        metadata dataframe indexed by the previous episode's env file path.
        """
        if self.evaluation_metadata_df is not None and len(self.simulation_env_file_paths) > 0:
            last_simulation_env_file_path = self.simulation_env_file_paths[-1]

            # NOTE(review): `_row` is a leftover of a previous row-based
            # update approach and is unused below; the dataframe is updated
            # via `df.loc[...]` instead.
            _row = self.evaluation_metadata_df.loc[
                last_simulation_env_file_path
            ]

            # Add controller_inference_time_metrics
            # These metrics may be missing if no step was done before the episode finished

            # generate the lists of names for the stats (input names and output names)
            sPrefixIn = "current_episode_controller_inference_time_"
            sPrefixOut = "controller_inference_time_"
            lsStatIn = [sPrefixIn + sStat for sStat in ["min", "mean", "max"]]
            lsStatOut = [sPrefixOut + sStat for sStat in ["min", "mean", "max"]]

            if lsStatIn[0] in self.stats:
                lrStats = [self.stats[sStat] for sStat in lsStatIn]
            else:
                # No env_step happened during the episode, so no inference
                # times were recorded — report zeros.
                lrStats = [0.0] * len(lsStatIn)

            # Column names to write, in the same order as loValues below.
            lsFields = ("reward, normalized_reward, percentage_complete, " + \
                        "steps, simulation_time, nb_malfunctioning_trains, nb_deadlocked_trains").split(", ") + \
                       lsStatOut

            loValues = [self.simulation_rewards[-1],
                        self.simulation_rewards_normalized[-1],
                        self.simulation_percentage_complete[-1],
                        self.simulation_steps[-1],
                        self.simulation_times[-1],
                        self.nb_malfunctioning_trains[-1],
                        self.nb_deadlocked_trains[-1]
                        ] + lrStats

            # update the dataframe without the updating-a-copy warning
            df = self.evaluation_metadata_df
            df.loc[last_simulation_env_file_path, lsFields] = loValues

            # Delete this key from the stats to ensure that it
            # gets computed again from scratch in the next episode
            self.delete_key_in_running_stats(
                "current_episode_controller_inference_time")

            if self.verbose:
                print(self.evaluation_metadata_df)
489
490

    def instantiate_redis_connection_pool(self):
        """
        Create the shared redis connection pool (and a connection from it)
        used for all communication with the message broker.
        """
        if self.verbose or self.report:
            print("Attempting to connect to redis server at {}:{}/{}".format(
                self.remote_host,
                self.remote_port,
                self.remote_db))

        connection_params = dict(
            host=self.remote_host,
            port=self.remote_port,
            db=self.remote_db,
            password=self.remote_password,
        )
        self.redis_pool = redis.ConnectionPool(**connection_params)
        self.redis_conn = redis.Redis(connection_pool=self.redis_pool)
508
509

    def get_redis_connection(self):
510
511
512
513
        """
        Obtains a new redis connection from a previously instantiated
        redis connection pool
        """
514
        return self.redis_conn
515
516

    def _error_template(self, payload):
        """
        Wrap `payload` in the standard flatland comms error message shape.
        """
        return {
            'type': messages.FLATLAND_RL.ERROR,
            'payload': payload,
        }

    def get_next_command(self):
        """
        A helper function to obtain the next command, which transparently
        also deals with things like unpacking of the command from the
        packed message, and consider the timeouts, etc when trying to
        fetch a new command.

        Blocks on the redis command channel (brpop) under a timeout chosen
        from the *previous* command type, deserializes the message, records
        the message-queue latency and returns the command dict.
        """
        COMMAND_TIMEOUT = DEFAULT_COMMAND_TIMEOUT
        """
        Handle case specific timeouts :
            - INTIAL_PLANNING_TIMEOUT
                The timeout between an env_create call and the first env_step call
            - PER_STEP_TIMEOUT
                The timeout between two consecutive env_step calls
        """
        if self.previous_command['type'] == messages.FLATLAND_RL.ENV_CREATE:
            """
            In case the previous command is an env_create, then leave 
            a but more time for the intial planning
            """
            COMMAND_TIMEOUT = INTIAL_PLANNING_TIMEOUT
        elif self.previous_command['type'] == messages.FLATLAND_RL.ENV_STEP:
            """
            Use the per_step_time for all timesteps between two env_step calls
            # Corner Case : 
                - Are there any reasons why a call between the last env_step call 
                and the subsequent env_create call will take an excessively large 
                amount of time (>5s in this case)
            """
            COMMAND_TIMEOUT = PER_STEP_TIMEOUT
        elif self.previous_command['type'] == messages.FLATLAND_RL.ENV_SUBMIT:
            """
            If the user has already done an env_submit call, then the timeout 
            can be an arbitrarily large number.
            """
            COMMAND_TIMEOUT = 10 ** 6

        # timeout_decorator treats None as "no timeout"
        if self.disable_timeouts:
            COMMAND_TIMEOUT = None

        @timeout_decorator.timeout(COMMAND_TIMEOUT, use_signals=use_signals_in_timeout)  # timeout for each command
        def _get_next_command(command_channel, _redis):
            """
            A low level wrapper for obtaining the next command from a
            pre-agreed command channel.
            At the momment, the communication protocol uses lpush for pushing
            in commands, and brpop for reading out commands.
            """
            command = _redis.brpop(command_channel)[1]
            return command

        # NOTE(review): leftover scaffolding from a removed try/except;
        # `if True:` is a no-op kept to preserve the original indentation.
        # try:
        if True:
            _redis = self.get_redis_connection()
            command = _get_next_command(self.command_channel, _redis)
            if self.verbose or self.report:
                print("Command Service: ", command)

        if self.use_pickle:
            command = pickle.loads(command)
        else:
            command = msgpack.unpackb(
                command,
                object_hook=m.decode,
                strict_map_key=False,  # msgpack 1.0
                # NOTE(review): `encoding=` was removed in msgpack>=1.0;
                # passing both kwargs appears to target msgpack 0.6.x —
                # confirm the pinned msgpack version.
                encoding="utf8"  # msgpack 1.0
            )
        if self.verbose:
            print("Received Request : ", command)

        # Time spent by the command in the queue, folded into running stats
        message_queue_latency = time.time() - command["timestamp"]
        self.update_running_stats("message_queue_latency", message_queue_latency)
        return command

601
    def send_response(self, _command_response, command, suppress_logs=False):
602
603
604
        _redis = self.get_redis_connection()
        command_response_channel = command['response_channel']

605
        if self.verbose and not suppress_logs:
606
            print("Responding with : ", _command_response)
u214892's avatar
u214892 committed
607

608
609
610
611
        if self.use_pickle:
            sResponse = pickle.dumps(_command_response)
        else:
            sResponse = msgpack.packb(
u214892's avatar
u214892 committed
612
613
                _command_response,
                default=m.encode,
614
                use_bin_type=True)
615
616
617
        _redis.rpush(command_response_channel, sResponse)

    def send_error(self, error_dict, suppress_logs=False):
MasterScrat's avatar
MasterScrat committed
618
        """ For out-of-band errors like timeouts,
619
620
621
            where we do not have a command, so we have no response channel!
        """
        _redis = self.get_redis_connection()
622
        print("Sending error : ", error_dict)
623
624
625
626
627
628
629
630
631
632

        if self.use_pickle:
            sResponse = pickle.dumps(error_dict)
        else:
            sResponse = msgpack.packb(
                error_dict,
                default=m.encode,
                use_bin_type=True)

        _redis.rpush(self.error_channel, sResponse)
u214892's avatar
u214892 committed
633

634
635
636
637
    def handle_ping(self, command):
        """
        Reply to a client PING with a PONG, first checking that the client
        library version is one the service supports. On a mismatch an ERROR
        response is sent back and an Exception is raised to abort.
        """
        service_version = flatland.__version__
        # Clients older than 2.1.4 did not report their version at all;
        # 2.1.4 is when the version-mismatch check was introduced.
        client_version = command["payload"].get("version", "2.1.4")

        _command_response = {
            'type': messages.FLATLAND_RL.PONG,
            'payload': {},
        }

        if client_version not in SUPPORTED_CLIENT_VERSIONS:
            _command_response['type'] = messages.FLATLAND_RL.ERROR
            _command_response['payload']['message'] = \
                "Client-Server Version Mismatch => " + \
                "[ Client Version : {} ] ".format(client_version) + \
                "[ Server Version : {} ] ".format(service_version)
            self.send_response(_command_response, command)
            raise Exception(_command_response['payload']['message'])

        self.send_response(_command_response, command)
658
659

    def handle_env_create(self, command):
660
661
662
        """
        Handles a ENV_CREATE command from the client
        """
663
664
665

        # Check if the previous episode was finished
        if not self.simulation_done and not self.evaluation_done:
666
667
668
669
            _command_response = self._error_template("CAN'T CREATE NEW ENV BEFORE PREVIOUS IS DONE")
            self.send_response(_command_response, command)
            raise Exception(_command_response['payload'])

670
        self.simulation_count += 1
671
        self.simulation_done = False
672

673
674
675
676
        if self.simulation_count == 0:
            # Very first episode: start the overall timer
            self.overall_start_time = time.time()

677
678
679
        # reset the timeout flag / state.
        self.state_env_timed_out = False

680
681
682
683
        # Check if we have finished all the available envs
        if self.simulation_count >= len(self.env_file_paths):
            self.evaluation_done = True
            # Hack - just ensure these are set
MasterScrat's avatar
MasterScrat committed
684
            test_env_file_path = self.env_file_paths[self.simulation_count - 1]
685
686
687
688
            env_test, env_level = self.get_env_test_and_level(test_env_file_path)
        else:
            test_env_file_path = self.env_file_paths[self.simulation_count]
            env_test, env_level = self.get_env_test_and_level(test_env_file_path)
689
690
691
692
693
694
695
696
697
698
699

        # Did we just finish a test, and if yes did it reach high enough mean percentage done?
        if self.current_test != env_test and env_test != 0:
            if self.current_test not in self.simulation_percentage_complete_per_test:
                print("No environment was finished at all during test {}!".format(self.current_test))
                mean_test_complete_percentage = 0.0
            else:
                mean_test_complete_percentage = np.mean(self.simulation_percentage_complete_per_test[self.current_test])

            if mean_test_complete_percentage < TEST_MIN_PERCENTAGE_COMPLETE_MEAN:
                print("=" * 15)
MasterScrat's avatar
MasterScrat committed
700
701
                msg = "The mean percentage of done agents during the last Test ({} environments) was too low: {:.3f} < {}".format(
                    len(self.simulation_percentage_complete_per_test[self.current_test]),
702
703
704
705
706
707
708
709
                    mean_test_complete_percentage,
                    TEST_MIN_PERCENTAGE_COMPLETE_MEAN
                )
                print(msg, "Evaluation will stop.")
                self.termination_cause = msg
                self.evaluation_done = True

        if self.simulation_count < len(self.env_file_paths) and not self.evaluation_done:
710
711
712
            """
            There are still test envs left that are yet to be evaluated 
            """
713

MasterScrat's avatar
MasterScrat committed
714
            print("=" * 15)
715
            print("Evaluating {} ({}/{})".format(test_env_file_path, self.simulation_count, len(self.env_file_paths)))
716

717
718
719
720
            test_env_file_path = os.path.join(
                self.test_env_folder,
                test_env_file_path
            )
721
722
723
724

            self.current_test = env_test
            self.current_level = env_level

725
            del self.env
726
727

            self.env, _env_dict = RailEnvPersister.load_new(test_env_file_path)
728

729
730
            self.begin_simulation = time.time()

731
732
733
734
735
736
737
738
739
740
            # Update evaluation metadata for the previous episode
            self.update_evaluation_metadata()

            # Start adding placeholders for the new episode
            self.simulation_env_file_paths.append(
                os.path.relpath(
                    test_env_file_path,
                    self.test_env_folder
                ))  # relative path

741
            self.simulation_rewards.append(0)
742
            self.simulation_rewards_normalized.append(0)
743
            self.simulation_percentage_complete.append(0)
744
            self.simulation_times.append(0)
745
            self.simulation_steps.append(0)
746
            self.nb_malfunctioning_trains.append(0)
747

748
            self.current_step = 0
749

750
            _observation, _info = self.env.reset(
Erik Nygren's avatar
Erik Nygren committed
751
752
753
754
755
                regenerate_rail=True,
                regenerate_schedule=True,
                activate_agents=False,
                random_seed=RANDOM_SEED
            )
756

757
            if self.visualize:
758
759
760
761
762
                current_env_path = self.env_file_paths[self.simulation_count]
                if current_env_path in self.video_generation_envs:
                    self.env_renderer = RenderTool(self.env, gl="PILSVG", )
                elif self.env_renderer:
                    self.env_renderer = False
763

764
765
766
767
            _command_response = {}
            _command_response['type'] = messages.FLATLAND_RL.ENV_CREATE_RESPONSE
            _command_response['payload'] = {}
            _command_response['payload']['observation'] = _observation
768
            _command_response['payload']['env_file_path'] = self.env_file_paths[self.simulation_count]
769
770
            _command_response['payload']['info'] = _info
            _command_response['payload']['random_seed'] = RANDOM_SEED
771
772
773
774
775
        else:
            """
            All test env evaluations are complete
            """
            _command_response = {}
776
            _command_response['type'] = messages.FLATLAND_RL.ENV_CREATE_RESPONSE
777
778
            _command_response['payload'] = {}
            _command_response['payload']['observation'] = False
u214892's avatar
u214892 committed
779
            _command_response['payload']['env_file_path'] = False
780
            _command_response['payload']['info'] = False
781
            _command_response['payload']['random_seed'] = False
782
783

        self.send_response(_command_response, command)
784
785
786
        #####################################################################
        # Update evaluation state
        #####################################################################
787
        elapsed = time.time() - self.overall_start_time
788
        progress = np.clip(
789
            elapsed / OVERALL_TIMEOUT,
u214892's avatar
u214892 committed
790
            0, 1)
spmohanty's avatar
spmohanty committed
791

792
        mean_reward, mean_normalized_reward, sum_normalized_reward, mean_percentage_complete = self.compute_mean_scores()
spmohanty's avatar
spmohanty committed
793

794
795
796
        self.evaluation_state["state"] = "IN_PROGRESS"
        self.evaluation_state["progress"] = progress
        self.evaluation_state["simulation_count"] = self.simulation_count
797
798
        self.evaluation_state["score"]["score"] = sum_normalized_reward
        self.evaluation_state["score"]["score_secondary"] = mean_percentage_complete
799
        self.evaluation_state["meta"]["normalized_reward"] = mean_normalized_reward
800
        self.evaluation_state["meta"]["termination_cause"] = self.termination_cause
801
        self.handle_aicrowd_info_event(self.evaluation_state)
802
803

        self.episode_actions = []
804

805
    def handle_env_step(self, command):
        """
        Handles a ENV_STEP command from the client.

        Applies the client's submitted actions to the current environment,
        updates per-episode bookkeeping (rewards, step counts, malfunction
        counts, timing statistics) and, once the episode reports
        ``done['__all__']``, finalizes the episode metrics and persists any
        requested artifacts (actions, episode recording, merged env).
        Also ends the whole evaluation early when the overall wall-clock
        budget is exhausted.

        Parameters
        ----------
        command : dict
            Message from the client; ``command['payload']`` carries the
            ``action`` dict and the client's reported ``inference_time``.

        Raises
        ------
        Exception
            If no env has been created yet, or the client attempts to step
            an env whose episode is already done.
        """
        if self.state_env_timed_out or self.evaluation_done:
            print("Ignoring step command after timeout.")
            return

        _payload = command['payload']

        if not self.env:
            raise Exception("env_client.step called before env_client.env_create() call")
        if self.env.dones['__all__']:
            raise Exception(
                "Client attempted to perform an action on an Env which \
                has done['__all__']==True")

        # Stop the whole evaluation once the overall wall-clock budget is spent.
        overall_elapsed = (time.time() - self.overall_start_time)
        if overall_elapsed > OVERALL_TIMEOUT:
            msg = "Reached overall time limit: took {:.2f}s, limit is {:.2f}s.".format(
                overall_elapsed, OVERALL_TIMEOUT
            )
            self.termination_cause = msg
            self.evaluation_done = True

            print("=" * 15)
            print(msg, "Evaluation will stop.")
            return

        action = _payload['action']
        inference_time = _payload['inference_time']
        # We record this metric in two keys:
        #   - One for the current episode
        #   - One global
        self.update_running_stats("current_episode_controller_inference_time", inference_time)
        self.update_running_stats("controller_inference_time", inference_time)

        # Perform the step, timing how long the env itself takes.
        time_start = time.time()
        _observation, all_rewards, done, _info = self.env.step(action)
        time_diff = time.time() - time_start
        self.update_running_stats("internal_env_step_time", time_diff)

        self.current_step += 1

        cumulative_reward = sum(all_rewards.values())
        self.simulation_rewards[-1] += cumulative_reward
        self.simulation_steps[-1] += 1

        # The normalized rewards normalize the reward for an
        # episode by dividing the whole reward by max-time-steps
        # allowed in that episode, and the number of agents present in
        # that episode
        self.simulation_rewards_normalized[-1] += \
            (cumulative_reward / (
                self.env._max_episode_steps *
                self.env.get_num_agents()
            ))

        # We count the number of agents that malfunctioned by checking how many have 1 more steps left before recovery
        num_malfunctioning = sum(agent.malfunction_data['malfunction'] == 1 for agent in self.env.agents)

        if self.verbose and num_malfunctioning > 0:
            print("Step {}: {} agents have malfunctioned and will recover next step".format(self.current_step, num_malfunctioning))

        self.nb_malfunctioning_trains[-1] += num_malfunctioning

        # record the actions before checking for done
        if self.action_dir is not None:
            self.episode_actions.append(action)

        # Is the episode over?
        if done["__all__"]:
            self.simulation_done = True

            if self.begin_simulation:
                # If begin simulation has already been initialized at least once
                # This adds the simulation time for the previous episode
                self.simulation_times[-1] = time.time() - self.begin_simulation

            # Compute percentage complete: fraction of agents that reached
            # their target (status DONE_REMOVED).
            complete = sum(
                agent.status == RailAgentStatus.DONE_REMOVED
                for agent in self.env.agents
            )
            percentage_complete = complete * 1.0 / self.env.get_num_agents()
            self.simulation_percentage_complete[-1] = percentage_complete

            # adds 1.0 so we can add them up
            self.simulation_rewards_normalized[-1] += 1.0

            if self.current_test not in self.simulation_percentage_complete_per_test:
                self.simulation_percentage_complete_per_test[self.current_test] = []
            self.simulation_percentage_complete_per_test[self.current_test].append(percentage_complete)
            print("Percentage for test {}, level {}: {}".format(self.current_test, self.current_level, percentage_complete))

            if len(self.env.cur_episode) > 0:
                # NOTE(review): column 5 of the recorded episode array is
                # presumably the per-agent deadlock flag — confirm against
                # the episode recorder.
                g3Ep = np.array(self.env.cur_episode)
                self.nb_deadlocked_trains.append(np.sum(g3Ep[-1, :, 5]))
            else:
                self.nb_deadlocked_trains.append(np.nan)

            # BUGFIX: the format string previously had only five placeholders
            # for six arguments, silently dropping the deadlocked-trains count.
            print(
                "Evaluation finished in {} timesteps, {:.3f} seconds. Percentage agents done: {:.3f}. Normalized reward: {:.3f}. Number of malfunctions: {}. Number of deadlocked trains: {}.".format(
                    self.simulation_steps[-1],
                    self.simulation_times[-1],
                    self.simulation_percentage_complete[-1],
                    self.simulation_rewards_normalized[-1],
                    self.nb_malfunctioning_trains[-1],
                    self.nb_deadlocked_trains[-1]
                ))

            print("Total normalized reward so far: {:.3f}".format(sum(self.simulation_rewards_normalized)))

            # Write intermediate results
            if self.result_output_path:
                self.evaluation_metadata_df.to_csv(self.result_output_path)
                print("Wrote intermediate output results to : {}".format(self.result_output_path))

            if self.action_dir is not None:
                self.save_actions()

            if self.episode_dir is not None:
                self.save_episode()

            if self.merge_dir is not None:
                self.save_merged_env()

        # Record Frame
        if self.visualize:
            # Only generate and save the frames for environments which are
            # separately listed in the video_generation_envs param.
            current_env_path = self.env_file_paths[self.simulation_count]
            if current_env_path in self.video_generation_envs:
                self.env_renderer.render_env(
                    show=False,
                    show_observations=False,
                    show_predictions=False,
                    show_rowcols=False
                )

                self.env_renderer.gl.save_image(
                    os.path.join(
                        self.vizualization_folder_name,
                        "flatland_frame_{:04d}.png".format(self.record_frame_step)
                    ))
                self.record_frame_step += 1
    def save_actions(self):
        """Persist the actions recorded for the current episode as JSON.

        The output file mirrors the relative path of the current test
        environment inside ``self.action_dir``, with the ``.pkl`` suffix
        swapped for ``.json``.  The in-memory action buffer is cleared
        afterwards so the next episode starts fresh.
        """
        env_rel_path = self.env_file_paths[self.simulation_count]
        actions_path = "{}/{}".format(self.action_dir, env_rel_path.replace(".pkl", ".json"))

        print("env path: ", env_rel_path, " sfActions:", actions_path)

        # Create the (possibly nested) output directory on demand.
        actions_parent = os.path.dirname(actions_path)
        if not os.path.exists(actions_parent):
            os.makedirs(actions_parent)

        with open(actions_path, "w") as out_file:
            json.dump(self.episode_actions, out_file)

        self.episode_actions = []
    def save_episode(self):
        """Persist the full episode recording of ``self.env`` to disk.

        The target file lives under ``self.episode_dir`` and mirrors the
        relative path of the environment that was just evaluated.
        """
        sfEnv = self.env_file_paths[self.simulation_count]
        sfEpisode = self.episode_dir + "/" + sfEnv
        print("env path: ", sfEnv, " sfEpisode:", sfEpisode)

        # BUGFIX: unlike save_actions()/save_merged_env(), the target
        # directory was never created here, so saving failed for env paths
        # containing subdirectories (e.g. "Test_0/Level_0.pkl").
        target_dir = os.path.dirname(sfEpisode)
        if not os.path.exists(target_dir):
            os.makedirs(target_dir)

        RailEnvPersister.save_episode(self.env, sfEpisode)
    def save_merged_env(self):
        """Write the merged episode/env snapshot for the current test env.

        The output path mirrors the relative env path inside
        ``self.merge_dir``; missing parent directories are created first.
        """
        env_rel_path = self.env_file_paths[self.simulation_count]
        merge_path = "{}/{}".format(self.merge_dir, env_rel_path)

        merge_parent = os.path.dirname(merge_path)
        if not os.path.exists(merge_parent):
            os.makedirs(merge_parent)

        print("Input env path: ", env_rel_path, " Merge File:", merge_path)
        RailEnvPersister.save_episode(self.env, merge_path)