diff --git a/flatland/evaluators/client.py b/flatland/evaluators/client.py index f9d8351fce0a8dabfdc8d0d785a692462ec86091..0eaacb57cacfc2a6c9e66919fd79bc18afb9f8c1 100644 --- a/flatland/evaluators/client.py +++ b/flatland/evaluators/client.py @@ -1,27 +1,22 @@ import redis import json import os -import glob -import pkg_resources -import sys import numpy as np import msgpack import msgpack_numpy as m -m.patch() import hashlib import random from flatland.evaluators import messages - from flatland.envs.rail_env import RailEnv from flatland.envs.generators import rail_from_file from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv - import time - import logging logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) +m.patch() + class FlatlandRemoteClient(object): """ @@ -43,8 +38,7 @@ class FlatlandRemoteClient(object): remote_port=6379, remote_db=0, remote_password=None, - verbose=False - ): + verbose=False): self.remote_host = remote_host self.remote_port = remote_port @@ -57,7 +51,7 @@ class FlatlandRemoteClient(object): password=remote_password) self.namespace = "flatland-rl" try: - self.service_id = os.environ['FLATLAND_RL_SERVICE_ID'] + self.service_id = os.environ['FLATLAND_RL_SERVICE_ID'] except KeyError: self.service_id = "FLATLAND_RL_SERVICE_ID" self.command_channel = "{}::{}::commands".format( @@ -77,9 +71,9 @@ class FlatlandRemoteClient(object): "{}".format( random.randint(0, 10**10) ).encode('utf-8')).hexdigest() - response_channel = "{}::{}::response::{}".format( self.namespace, - self.service_id, - random_hash) + response_channel = "{}::{}::response::{}".format(self.namespace, + self.service_id, + random_hash) return response_channel def _blocking_request(self, _request): @@ -94,7 +88,7 @@ class FlatlandRemoteClient(object): ** redis-left-push (LPUSH) * Keep listening on response_channel (BLPOP) """ - assert type(_request) ==type({}) + assert isinstance(_request, dict) _request['response_channel'] = self._generate_response_channel() _redis = self.get_redis_connection() @@ -102,14 +96,16 @@ class FlatlandRemoteClient(object): The client always pushes in the left and the service always pushes in the right """ - if self.verbose: print("Request : ", _response) + if self.verbose: + print("Request : ", _request) # Push request in command_channels # Note: The patched msgpack supports numpy arrays payload = msgpack.packb(_request, default=m.encode, use_bin_type=True) _redis.lpush(self.command_channel, payload) # Wait with a blocking pop for the response _response = _redis.blpop(_request['response_channel'])[1] - if self.verbose: print("Response : ", _response) + if self.verbose: + print("Response : ", _response) _response = msgpack.unpackb( _response, object_hook=m.decode, @@ -163,7 +159,7 @@ class FlatlandRemoteClient(object): self.env._max_episode_steps = \ int(1.5 * (self.env.width + self.env.height)) - _ = self.env.reset() + self.env.reset() # Use the observation from the remote service instead return observation @@ -198,8 +194,10 @@ class FlatlandRemoteClient(object): time.sleep(10) return _response['payload'] + if __name__ == "__main__": env_client = FlatlandRemoteClient() + def my_controller(obs, _env): _action = {} for _idx, _ in enumerate(_env.agents):