client.py 14.5 KB
Newer Older
u214892's avatar
u214892 committed
1
import hashlib
2
import json
u214892's avatar
u214892 committed
3
import logging
4
import os
u214892's avatar
u214892 committed
5
6
7
import random
import time

8
9
import msgpack
import msgpack_numpy as m
10
import pickle
u214892's avatar
u214892 committed
11
12
13
import numpy as np
import redis

14
import flatland
15
from flatland.envs.malfunction_generators import malfunction_from_file
u214892's avatar
u214892 committed
16
17
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import rail_from_file
18
from flatland.envs.line_generators import line_from_file
u214892's avatar
u214892 committed
19
from flatland.evaluators import messages
20
from flatland.core.env_observation_builder import DummyObservationBuilder
u214892's avatar
u214892 committed
21

22
23
# Enable transparent numpy-array support in msgpack before any message is
# (de)serialised by the client below.
m.patch()

# Module-level logger, pinned to INFO verbosity.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

26

27
28
29
30
31
class TimeoutException(StopAsyncIteration):
    """
    Raised by the client when the evaluator reports an out-of-band error
    (currently only timeouts) on the error channel.

    NOTE(review): deriving from StopAsyncIteration rather than Exception is
    unusual (it can silently end an ``async for``); the base is kept only so
    existing ``except`` clauses keep matching.
    """


32
33
34
35
class FlatlandRemoteClient(object):
    """
        Redis client to interface with flatland-rl remote-evaluation-service.

        The Docker container hosts a redis-server inside the container.
        This client connects to the same redis-server and communicates with
        the service through it. The service eventually will reside outside
        the docker container, and will communicate with the client only via
        the redis-server of the docker container. On the instantiation of
        the docker container, one service will be instantiated in parallel.

        The service accepts commands at
        "flatland-rl::`service_id`::commands", where `service_id` is read from
        the `FLATLAND_RL_SERVICE_ID` environment variable and defaults to
        "FLATLAND_RL_SERVICE_ID".
    """

    def __init__(self,
                 remote_host='127.0.0.1',
                 remote_port=6379,
                 remote_db=0,
                 remote_password=None,
                 test_envs_root=None,
                 verbose=False,
                 use_pickle=False):
        """
            Connect to the redis server and perform the PING/PONG handshake.

            :param remote_host: host of the redis server
            :param remote_port: port of the redis server
            :param remote_db: redis database index
            :param remote_password: redis password (None for no auth)
            :param test_envs_root: folder holding the test env files; falls
                back to $AICROWD_TESTS_FOLDER, then '/tmp/flatland_envs'
            :param verbose: print every request/response
            :param use_pickle: serialise messages with pickle instead of the
                numpy-patched msgpack
            :raises Exception: if the handshake with the service fails
        """
        self.use_pickle = use_pickle
        self.remote_host = remote_host
        self.remote_port = remote_port
        self.remote_db = remote_db
        self.remote_password = remote_password
        self.redis_pool = redis.ConnectionPool(
            host=remote_host,
            port=remote_port,
            db=remote_db,
            password=remote_password)
        self.redis_conn = redis.Redis(connection_pool=self.redis_pool)

        self.namespace = "flatland-rl"
        self.service_id = os.getenv(
            'FLATLAND_RL_SERVICE_ID',
            'FLATLAND_RL_SERVICE_ID'
        )
        self.command_channel = "{}::{}::commands".format(
            self.namespace,
            self.service_id
        )

        # for timeout messages sent out-of-band
        self.error_channel = "{}::{}::errors".format(
            self.namespace, self.service_id)

        if test_envs_root:
            self.test_envs_root = test_envs_root
        else:
            self.test_envs_root = os.getenv(
                'AICROWD_TESTS_FOLDER',
                '/tmp/flatland_envs'
            )
        self.current_env_path = None

        self.verbose = verbose

        self.env = None
        self.ping_pong()

        self.env_step_times = []
        self.stats = {}

    def update_running_stats(self, key, scalar):
        """
        Fold `scalar` into the running statistics kept for `key`.

        Maintains "<key>_mean", "<key>_min", "<key>_max" and "<key>_counter"
        in self.stats; the first observation of a key initialises all four.
        """
        mean_key = "{}_mean".format(key)
        counter_key = "{}_counter".format(key)
        min_key = "{}_min".format(key)
        max_key = "{}_max".format(key)

        try:
            # Update Mean
            self.stats[mean_key] = \
                ((self.stats[mean_key] * self.stats[counter_key]) + scalar) / (self.stats[counter_key] + 1)
            # Update min
            if scalar < self.stats[min_key]:
                self.stats[min_key] = scalar
            # Update max
            if scalar > self.stats[max_key]:
                self.stats[max_key] = scalar

            self.stats[counter_key] += 1
        except KeyError:
            # First sample for this key.
            self.stats[mean_key] = scalar
            self.stats[min_key] = scalar
            self.stats[max_key] = scalar
            self.stats[counter_key] = 1

    def get_redis_connection(self):
        """Return the shared redis connection of this client."""
        return self.redis_conn

    def _generate_response_channel(self):
        """
        Build a fresh "<namespace>::<service_id>::response::<md5 hex>"
        channel name on which the service will push its reply.
        """
        random_hash = hashlib.md5(
            "{}".format(
                random.randint(0, 10 ** 10)
            ).encode('utf-8')).hexdigest()
        response_channel = "{}::{}::response::{}".format(self.namespace,
                                                         self.service_id,
                                                         random_hash)
        return response_channel

    def _encode_message(self, message):
        """Serialise `message` for redis (pickle or numpy-patched msgpack)."""
        if self.use_pickle:
            return pickle.dumps(message)
        # Note: The patched msgpack supports numpy arrays
        return msgpack.packb(message, default=m.encode, use_bin_type=True)

    def _decode_message(self, message_bytes):
        """Deserialise bytes read from redis (pickle or numpy-patched msgpack)."""
        if self.use_pickle:
            # NOTE(review): pickle.loads can execute arbitrary code for a
            # crafted payload; only safe while the redis server is trusted.
            return pickle.loads(message_bytes)
        # `raw=False` decodes msgpack strings to `str`. It replaces the old
        # `encoding="utf8"` argument, which msgpack 1.0 removed (TypeError).
        # `strict_map_key=False` allows non-str map keys (e.g. agent ids).
        return msgpack.unpackb(
            message_bytes,
            object_hook=m.decode,
            strict_map_key=False,
            raw=False
        )

    def _remote_request(self, _request, blocking=True):
        """
            Send `_request` to the service and optionally wait for its reply.

            request:
                -type
                -payload
                -response_channel (added here)
                -timestamp (added here)
            response: (on response_channel)
                - RESULT
            * Send the payload on command_channel
                ** redis-left-push (LPUSH)
            * Keep listening on response_channel (BLPOP)

            :param _request: dict describing the command; mutated in place
            :param blocking: wait for and return the response when True;
                return None right after pushing the request otherwise
            :raises TimeoutException: if an error was posted on the error
                channel before this request was sent
            :raises Exception: if the service answers with an ERROR message
        """
        assert isinstance(_request, dict)
        _request['response_channel'] = self._generate_response_channel()
        _request['timestamp'] = time.time()

        _redis = self.get_redis_connection()
        # The client always pushes in the left
        # and the service always pushes in the right
        if self.verbose:
            print("Request : ", _request)

        # check for errors (essentially just timeouts, for now.)
        error_bytes = _redis.rpop(self.error_channel)
        if error_bytes is not None:
            error_dict = self._decode_message(error_bytes)
            print("Error received: ", error_dict)
            raise TimeoutException(error_dict["type"])

        # Push request in command_channels
        _redis.lpush(self.command_channel, self._encode_message(_request))

        if blocking:
            # Wait with a blocking pop for the response
            _response = _redis.blpop(_request['response_channel'])[1]
            if self.verbose:
                print("Response : ", _response)
            _response = self._decode_message(_response)
            if _response['type'] == messages.FLATLAND_RL.ERROR:
                raise Exception(str(_response["payload"]))
            return _response
        return None

    def ping_pong(self):
        """
            Official Handshake with the evaluation service
            Send a PING (carrying the local flatland version)
            and wait for PONG
            If not PONG, raise error

            :returns: True on a successful handshake
            :raises Exception: if the service does not answer with PONG
        """
        _request = {}
        _request['type'] = messages.FLATLAND_RL.PING
        _request['payload'] = {
            "version": flatland.__version__
        }
        _response = self._remote_request(_request)
        if _response['type'] != messages.FLATLAND_RL.PONG:
            # default=str keeps the diagnostic from failing on values json
            # cannot encode (bytes, numpy arrays) and masking the real error.
            raise Exception(
                "Unable to perform handshake with the evaluation service. "
                "Expected PONG; received {}".format(json.dumps(_response, default=str)))
        return True

    def env_create(self, obs_builder_object):
        """
            Create a local env and remote env on which the
            local agent can operate.
            The observation builder is only used in the local env
            and the remote env uses a DummyObservationBuilder

            :param obs_builder_object: observation builder for the local env
            :returns: (observation, info); observation is falsy when the
                service has no more episodes to evaluate
            :raises Exception: if the env file is missing under test_envs_root
        """
        time_start = time.time()
        _request = {}
        _request['type'] = messages.FLATLAND_RL.ENV_CREATE
        _request['payload'] = {}
        _response = self._remote_request(_request)
        observation = _response['payload']['observation']
        info = _response['payload']['info']
        random_seed = _response['payload']['random_seed']
        test_env_file_path = _response['payload']['env_file_path']
        time_diff = time.time() - time_start
        self.update_running_stats("env_creation_wait_time", time_diff)

        if not observation:
            # If the observation is False,
            # then the evaluations are complete
            # hence return false
            return observation, info

        if self.verbose:
            print("Received Env : ", test_env_file_path)

        test_env_file_path = os.path.join(
            self.test_envs_root,
            test_env_file_path
        )
        if not os.path.exists(test_env_file_path):
            raise Exception(
                "\nWe cannot seem to find the env file paths at the required location.\n"
                "Did you remember to set the AICROWD_TESTS_FOLDER environment variable "
                "to point to the location of the Tests folder ? \n"
                "We are currently looking at `{}` for the tests".format(self.test_envs_root)
            )

        if self.verbose:
            print("Current env path : ", test_env_file_path)
        self.current_env_path = test_env_file_path
        # flatland 3 replaced schedule generators with line generators
        # (`line_generator=` / `line_from_file`, which is what is imported).
        self.env = RailEnv(width=1, height=1, rail_generator=rail_from_file(test_env_file_path),
                           line_generator=line_from_file(test_env_file_path),
                           malfunction_generator_and_process_data=malfunction_from_file(test_env_file_path),
                           obs_builder_object=obs_builder_object)

        time_start = time.time()
        # Use the local observation
        # as the remote server uses a dummy observation builder
        local_observation, info = self.env.reset(
            regenerate_rail=True,
            regenerate_schedule=True,
            random_seed=random_seed
        )
        time_diff = time.time() - time_start
        self.update_running_stats("internal_env_reset_time", time_diff)

        # We use the last_env_step_time as an approximate measure of the inference time
        self.last_env_step_time = time.time()
        return local_observation, info

    def env_step(self, action, render=False):
        """
            Relay `action` to the remote env and apply it to the local env.

            Respond with [observation, reward, done, info] from the local env.
            Must be called after env_create (uses self.env and
            self.last_env_step_time set there).

            :raises TimeoutException: if the service reported a timeout
        """
        # We use the last_env_step_time as an approximate measure of the inference time
        approximate_inference_time = time.time() - self.last_env_step_time
        self.update_running_stats("inference_time(approx)", approximate_inference_time)

        _request = {}
        _request['type'] = messages.FLATLAND_RL.ENV_STEP
        _request['payload'] = {}
        _request['payload']['action'] = action
        _request['payload']['inference_time'] = approximate_inference_time

        # Relay the action in a non-blocking way to the server
        # so that it can start doing an env.step on it in ~ parallel
        # Note - this can throw a Timeout
        self._remote_request(_request, blocking=False)

        # Apply the action in the local env
        time_start = time.time()
        local_observation, local_reward, local_done, local_info = \
            self.env.step(action)
        time_diff = time.time() - time_start
        # Compute a running mean of env step times
        self.update_running_stats("internal_env_step_time", time_diff)

        # We use the last_env_step_time as an approximate measure of the inference time
        self.last_env_step_time = time.time()

        return [local_observation, local_reward, local_done, local_info]

    def submit(self):
        """
            Send ENV_SUBMIT, print the local timing stats and return the
            response payload. When $AICROWD_BLOCKING_SUBMIT is set, never
            returns (sleeps forever so the evaluator controls the container).
        """
        _request = {}
        _request['type'] = messages.FLATLAND_RL.ENV_SUBMIT
        _request['payload'] = {}
        _response = self._remote_request(_request)

        ######################################################################
        # Print Local Stats
        ######################################################################
        print("=" * 100)
        print("=" * 100)
        print("## Client Performance Stats")
        print("=" * 100)
        mean_suffix = "_mean"
        for _key in self.stats:
            if _key.endswith(mean_suffix):
                # Strip only the suffix; str.replace would also remove any
                # "_mean" occurring inside the metric name itself.
                metric_name = _key[:-len(mean_suffix)]
                min_key = "{}_min".format(metric_name)
                max_key = "{}_max".format(metric_name)
                print("\t - {}\t => min: {} || mean: {} || max: {}".format(
                    metric_name,
                    self.stats[min_key],
                    self.stats[_key],
                    self.stats[max_key]))
        print("=" * 100)
        if os.getenv("AICROWD_BLOCKING_SUBMIT"):
            # If the submission is supposed to happen as a blocking submit,
            # then wait indefinitely for the evaluator to decide what to
            # do with the container.
            while True:
                time.sleep(10)
        return _response['payload']

356

357
if __name__ == "__main__":
    remote_client = FlatlandRemoteClient()

    def my_controller(obs, _env):
        """Return a uniformly random action in [0, 5) for every agent of `_env`."""
        return {_idx: np.random.randint(0, 5) for _idx in range(len(_env.agents))}

    my_observation_builder = DummyObservationBuilder()

    episode = 0
    while True:
        obs, info = remote_client.env_create(
            obs_builder_object=my_observation_builder
        )
        if not obs:
            # The remote env returns False as the first obs
            # when it is done evaluating all the individual episodes
            break
        print("Episode : {}".format(episode))
        episode += 1

        print(remote_client.env.dones['__all__'])

        while True:
            action = my_controller(obs, remote_client.env)
            time_start = time.time()

            try:
                # Rebind `obs` so the controller sees the latest observation
                # on the next step, not the one returned by env_create.
                obs, all_rewards, done, info = remote_client.env_step(action)
                time_diff = time.time() - time_start
                print("Step Time : ", time_diff)
                if done['__all__']:
                    print("Current Episode : ", episode)
                    print("Episode Done")
                    print("Reward : ", sum(list(all_rewards.values())))
                    break
            except TimeoutException as err:
                print("Timeout: ", err)
                break

    print("Evaluation Complete...")
    print(remote_client.submit())