client.py 14.4 KB
Newer Older
u214892's avatar
u214892 committed
1
import hashlib
2
import json
u214892's avatar
u214892 committed
3
import logging
4
import os
u214892's avatar
u214892 committed
5
6
7
import random
import time

8
9
import msgpack
import msgpack_numpy as m
10
import pickle
u214892's avatar
u214892 committed
11
12
13
import numpy as np
import redis

14
import flatland
15
from flatland.envs.malfunction_generators import FileMalfunctionGen
u214892's avatar
u214892 committed
16
17
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import rail_from_file
18
from flatland.envs.line_generators import line_from_file
u214892's avatar
u214892 committed
19
from flatland.evaluators import messages
20
from flatland.core.env_observation_builder import DummyObservationBuilder
u214892's avatar
u214892 committed
21

22
23
# Module-level logger for this client. NOTE(review): it is not referenced
# anywhere else in the visible part of this file; all diagnostics use print().
logger = logging.getLogger(name=__name__)
logger.setLevel(level=logging.INFO)

# Globally patch msgpack with msgpack_numpy so numpy arrays can be packed and
# unpacked transparently.
m.patch()

26

27
28
29
30
31
class TimeoutException(StopAsyncIteration):
    """
    Custom exception for evaluation timeouts.

    Raised by the remote client when the evaluation service has posted a
    message on its out-of-band error channel. The exception argument is the
    message's "type" field.

    NOTE(review): deriving from StopAsyncIteration is unusual for a timeout
    signal; the base class is kept so existing ``except`` clauses still match.
    """


32
33
34
35
class FlatlandRemoteClient(object):
    """
        Redis client to interface with flatland-rl remote-evaluation-service

        The Docker container hosts a redis-server inside the container.
        This client connects to the same redis-server and communicates with
        the service. The service eventually will reside outside the docker
        container, and will communicate with the client only via the
        redis-server of the docker container. On the instantiation of the
        docker container, one service will be instantiated in parallel.

        The service accepts commands at "flatland-rl::`service_id`::commands"
        where `service_id` is read from the `FLATLAND_RL_SERVICE_ID`
        environment variable and defaults to "FLATLAND_RL_SERVICE_ID".

        Messages are serialized with msgpack (numpy-aware via msgpack_numpy),
        or with pickle when `use_pickle=True`.
    """

    def __init__(self,
                 remote_host='127.0.0.1',
                 remote_port=6379,
                 remote_db=0,
                 remote_password=None,
                 test_envs_root=None,
                 verbose=False,
                 use_pickle=False):
        """
            Connect to redis and perform the PING/PONG handshake.

            :param remote_host: host of the redis server
            :param remote_port: port of the redis server
            :param remote_db: redis database index
            :param remote_password: redis password (None for no auth)
            :param test_envs_root: folder containing the test env files; when
                falsy, falls back to the AICROWD_TESTS_FOLDER environment
                variable and then to "/tmp/flatland_envs"
            :param verbose: print every request and raw response
            :param use_pickle: use pickle instead of msgpack on the wire
            :raises Exception: if the service does not answer with PONG
        """
        self.use_pickle = use_pickle
        self.remote_host = remote_host
        self.remote_port = remote_port
        self.remote_db = remote_db
        self.remote_password = remote_password
        self.redis_pool = redis.ConnectionPool(
            host=remote_host,
            port=remote_port,
            db=remote_db,
            password=remote_password)
        self.redis_conn = redis.Redis(connection_pool=self.redis_pool)

        self.namespace = "flatland-rl"
        self.service_id = os.getenv(
            'FLATLAND_RL_SERVICE_ID',
            'FLATLAND_RL_SERVICE_ID'
        )
        self.command_channel = "{}::{}::commands".format(
            self.namespace,
            self.service_id
        )

        # for timeout messages sent out-of-band
        self.error_channel = "{}::{}::errors".format(
            self.namespace, self.service_id)

        if test_envs_root:
            self.test_envs_root = test_envs_root
        else:
            self.test_envs_root = os.getenv(
                'AICROWD_TESTS_FOLDER',
                '/tmp/flatland_envs'
            )
        self.current_env_path = None

        self.verbose = verbose

        self.env = None
        # Bookkeeping is set up before the handshake so the instance is
        # fully formed by the time the first remote request goes out.
        self.env_step_times = []
        self.stats = {}

        self.ping_pong()

    def update_running_stats(self, key, scalar):
        """
            Fold `scalar` into the running statistics tracked for `key`.

            Maintains "<key>_mean", "<key>_min", "<key>_max" and
            "<key>_counter" entries in `self.stats`.
        """
        mean_key = "{}_mean".format(key)
        counter_key = "{}_counter".format(key)
        min_key = "{}_min".format(key)
        max_key = "{}_max".format(key)

        if counter_key not in self.stats:
            # First observation for this key.
            self.stats[mean_key] = scalar
            self.stats[min_key] = scalar
            self.stats[max_key] = scalar
            self.stats[counter_key] = 1
            return

        count = self.stats[counter_key]
        # Incremental mean: (old_mean * n + x) / (n + 1)
        self.stats[mean_key] = \
            ((self.stats[mean_key] * count) + scalar) / (count + 1)
        if scalar < self.stats[min_key]:
            self.stats[min_key] = scalar
        if scalar > self.stats[max_key]:
            self.stats[max_key] = scalar
        self.stats[counter_key] = count + 1

    def get_redis_connection(self):
        """ Return the shared redis connection used for all requests. """
        return self.redis_conn

    def _generate_response_channel(self):
        """
            Return a fresh redis list name on which the service should push
            the response to a single request.
        """
        # os.urandom does not depend on the global `random` state, which
        # agent/environment code may seed. A seeded global RNG would make
        # channel names repeat, and a repeated name could pick up a stale
        # response left over from an earlier request.
        random_hash = os.urandom(16).hex()
        return "{}::{}::response::{}".format(self.namespace,
                                             self.service_id,
                                             random_hash)

    def _serialize(self, obj):
        """ Serialize `obj` for the wire (pickle or numpy-aware msgpack). """
        if self.use_pickle:
            return pickle.dumps(obj)
        # msgpack_numpy's encoder lets numpy arrays travel inside the payload
        return msgpack.packb(obj, default=m.encode, use_bin_type=True)

    def _deserialize(self, payload):
        """ Inverse of `_serialize` for bytes read from redis. """
        if self.use_pickle:
            # NOTE(security): pickle.loads can execute arbitrary code; this is
            # only acceptable because the redis peer is the trusted evaluator.
            return pickle.loads(payload)
        # `raw=False` decodes msgpack strings as UTF-8 on both msgpack 0.6.x
        # and 1.x. The previously used `encoding=` keyword was removed in
        # msgpack 1.0 and raises TypeError there.
        return msgpack.unpackb(
            payload,
            object_hook=m.decode,
            strict_map_key=False,
            raw=False
        )

    def _remote_request(self, _request, blocking=True):
        """
            Send a request to the evaluation service.

            request:
                -type
                -payload
                -response_channel (filled in here)
                -timestamp (filled in here)
            response: (on response_channel)
                - RESULT
            * Send the payload on command_channel
                ** redis-left-push (LPUSH)
            * Keep listening on response_channel (BLPOP) when `blocking`

            :returns: the deserialized response dict when `blocking`,
                otherwise None
            :raises TypeError: if `_request` is not a dict
            :raises TimeoutException: if the service has posted a message on
                the error channel
            :raises Exception: if the service answers with an ERROR message
        """
        if not isinstance(_request, dict):
            raise TypeError(
                "_request must be a dict, got {}".format(type(_request).__name__))
        _request['response_channel'] = self._generate_response_channel()
        _request['timestamp'] = time.time()

        # The client always pushes in the left
        # and the service always pushes in the right
        _redis = self.get_redis_connection()
        if self.verbose:
            print("Request : ", _request)

        # check for errors (essentially just timeouts, for now.)
        error_bytes = _redis.rpop(self.error_channel)
        if error_bytes is not None:
            error_dict = self._deserialize(error_bytes)
            print("Error received: ", error_dict)
            raise TimeoutException(error_dict["type"])

        # Push request in command_channels
        _redis.lpush(self.command_channel, self._serialize(_request))

        if not blocking:
            return None

        # Wait with a blocking pop for the response (no timeout)
        _response = _redis.blpop(_request['response_channel'])[1]
        if self.verbose:
            print("Response : ", _response)
        _response = self._deserialize(_response)

        if _response['type'] == messages.FLATLAND_RL.ERROR:
            raise Exception(str(_response["payload"]))
        return _response

    def ping_pong(self):
        """
            Official Handshake with the evaluation service.

            Send a PING carrying the local flatland version and wait for PONG.

            :returns: True on a successful handshake
            :raises Exception: if the reply is not PONG
        """
        _request = {
            'type': messages.FLATLAND_RL.PING,
            'payload': {
                "version": flatland.__version__
            }
        }
        _response = self._remote_request(_request)
        if _response['type'] != messages.FLATLAND_RL.PONG:
            # default=str keeps the diagnostic from failing on values JSON
            # cannot encode (e.g. bytes or numpy objects).
            raise Exception(
                "Unable to perform handshake with the evaluation service. "
                "Expected PONG; received {}".format(
                    json.dumps(_response, default=str)))
        return True

    def env_create(self, obs_builder_object):
        """
            Create a local env and remote env on which the
            local agent can operate.

            The observation builder is only used in the local env
            and the remote env uses a DummyObservationBuilder.

            :param obs_builder_object: observation builder for the local env
            :returns: (observation, info). When the service sends a falsy
                observation, evaluation is finished and that falsy value is
                returned unchanged.
            :raises Exception: if the env file sent by the service cannot be
                found under `self.test_envs_root`
        """
        time_start = time.time()
        _request = {
            'type': messages.FLATLAND_RL.ENV_CREATE,
            'payload': {}
        }
        _response = self._remote_request(_request)
        payload = _response['payload']
        observation = payload['observation']
        info = payload['info']
        random_seed = payload['random_seed']
        test_env_file_path = payload['env_file_path']
        self.update_running_stats("env_creation_wait_time", time.time() - time_start)

        if not observation:
            # If the observation is False,
            # then the evaluations are complete
            # hence return false
            return observation, info

        if self.verbose:
            print("Received Env : ", test_env_file_path)

        test_env_file_path = os.path.join(
            self.test_envs_root,
            test_env_file_path
        )
        if not os.path.exists(test_env_file_path):
            raise Exception(
                "\nWe cannot seem to find the env file paths at the required location.\n"
                "Did you remember to set the AICROWD_TESTS_FOLDER environment variable "
                "to point to the location of the Tests folder ? \n"
                "We are currently looking at `{}` for the tests".format(self.test_envs_root)
            )

        if self.verbose:
            print("Current env path : ", test_env_file_path)
        self.current_env_path = test_env_file_path
        self.env = RailEnv(width=1, height=1, rail_generator=rail_from_file(test_env_file_path),
                           line_generator=line_from_file(test_env_file_path),
                           malfunction_generator=FileMalfunctionGen(filename=test_env_file_path),
                           obs_builder_object=obs_builder_object)

        time_start = time.time()
        # Use the local observation
        # as the remote server uses a dummy observation builder
        local_observation, info = self.env.reset(
            regenerate_rail=True,
            regenerate_schedule=True,
            random_seed=random_seed
        )
        self.update_running_stats("internal_env_reset_time", time.time() - time_start)

        # We use the last_env_step_time as an approximate measure of the inference time
        self.last_env_step_time = time.time()
        return local_observation, info

    def env_step(self, action, render=False):
        """
            Relay `action` to the service and step the local env with it.

            `render` is accepted for API compatibility and is not used.

            :returns: [observation, reward, done, info] from the local env
            :raises TimeoutException: if the service reported a timeout
        """
        # We use the last_env_step_time as an approximate measure of the inference time
        approximate_inference_time = time.time() - self.last_env_step_time
        self.update_running_stats("inference_time(approx)", approximate_inference_time)

        _request = {
            'type': messages.FLATLAND_RL.ENV_STEP,
            'payload': {
                'action': action,
                'inference_time': approximate_inference_time
            }
        }

        # Relay the action in a non-blocking way to the server
        # so that it can start doing an env.step on it in ~ parallel
        # Note - this can throw a Timeout
        self._remote_request(_request, blocking=False)

        # Apply the action in the local env
        time_start = time.time()
        local_observation, local_reward, local_done, local_info = \
            self.env.step(action)
        # Compute a running mean of env step times
        self.update_running_stats("internal_env_step_time", time.time() - time_start)

        # We use the last_env_step_time as an approximate measure of the inference time
        self.last_env_step_time = time.time()

        return [local_observation, local_reward, local_done, local_info]

    def submit(self):
        """
            Tell the service the evaluation is finished and print local stats.

            :returns: the payload of the service's response. If the
                AICROWD_BLOCKING_SUBMIT environment variable is set, this
                never returns and sleeps indefinitely instead.
        """
        _request = {
            'type': messages.FLATLAND_RL.ENV_SUBMIT,
            'payload': {}
        }
        _response = self._remote_request(_request)

        ######################################################################
        # Print Local Stats
        ######################################################################
        print("=" * 100)
        print("=" * 100)
        print("## Client Performance Stats")
        print("=" * 100)

        suffix = "_mean"
        for _key in self.stats:
            if _key.endswith(suffix):
                # Strip only the trailing suffix so names that contain
                # "_mean" elsewhere are not mangled.
                metric_name = _key[:-len(suffix)]
                print("\t - {}\t => min: {} || mean: {} || max: {}".format(
                    metric_name,
                    self.stats["{}_min".format(metric_name)],
                    self.stats[_key],
                    self.stats["{}_max".format(metric_name)]))
        print("=" * 100)

        if os.getenv("AICROWD_BLOCKING_SUBMIT"):
            # If the submission is supposed to happen as a blocking submit,
            # then wait indefinitely for the evaluator to decide what to
            # do with the container.
            while True:
                time.sleep(10)
        return _response['payload']

355

356
if __name__ == "__main__":
    remote_client = FlatlandRemoteClient()

    def my_controller(obs, _env):
        """ Random baseline: one action drawn from [0, 5) per agent. """
        return {agent_idx: np.random.randint(0, 5)
                for agent_idx, _ in enumerate(_env.agents)}

    my_observation_builder = DummyObservationBuilder()

    episode = 0
    while True:
        obs, info = remote_client.env_create(
            obs_builder_object=my_observation_builder
        )
        if not obs:
            # The remote env returns False as the first obs
            # when it is done evaluating all the individual episodes
            break

        print("Episode : {}".format(episode))
        episode += 1

        print(remote_client.env.dones['__all__'])

        while True:
            action = my_controller(obs, remote_client.env)
            step_started = time.time()
            try:
                observation, all_rewards, done, info = remote_client.env_step(action)
            except TimeoutException as err:
                print("Timeout: ", err)
                break
            else:
                print("Step Time : ", time.time() - step_started)
                if done['__all__']:
                    print("Current Episode : ", episode)
                    print("Episode Done")
                    print("Reward : ", sum(list(all_rewards.values())))
                    break

    print("Evaluation Complete...")
    print(remote_client.submit())