Skip to content
Snippets Groups Projects
Commit 6c6e884e authored by spmohanty's avatar spmohanty
Browse files

Addresses #117 - Fix linter errors in client.py

parent cd84462b
No related branches found
No related tags found
No related merge requests found
import redis import redis
import json import json
import os import os
import glob
import pkg_resources
import sys
import numpy as np import numpy as np
import msgpack import msgpack
import msgpack_numpy as m import msgpack_numpy as m
m.patch()
import hashlib import hashlib
import random import random
from flatland.evaluators import messages from flatland.evaluators import messages
from flatland.envs.rail_env import RailEnv from flatland.envs.rail_env import RailEnv
from flatland.envs.generators import rail_from_file from flatland.envs.generators import rail_from_file
from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv
import time import time
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
m.patch()
class FlatlandRemoteClient(object): class FlatlandRemoteClient(object):
""" """
...@@ -43,8 +38,7 @@ class FlatlandRemoteClient(object): ...@@ -43,8 +38,7 @@ class FlatlandRemoteClient(object):
remote_port=6379, remote_port=6379,
remote_db=0, remote_db=0,
remote_password=None, remote_password=None,
verbose=False verbose=False):
):
self.remote_host = remote_host self.remote_host = remote_host
self.remote_port = remote_port self.remote_port = remote_port
...@@ -57,7 +51,7 @@ class FlatlandRemoteClient(object): ...@@ -57,7 +51,7 @@ class FlatlandRemoteClient(object):
password=remote_password) password=remote_password)
self.namespace = "flatland-rl" self.namespace = "flatland-rl"
try: try:
self.service_id = os.environ['FLATLAND_RL_SERVICE_ID'] self.service_id = os.environ['FLATLAND_RL_SERVICE_ID']
except KeyError: except KeyError:
self.service_id = "FLATLAND_RL_SERVICE_ID" self.service_id = "FLATLAND_RL_SERVICE_ID"
self.command_channel = "{}::{}::commands".format( self.command_channel = "{}::{}::commands".format(
...@@ -77,9 +71,9 @@ class FlatlandRemoteClient(object): ...@@ -77,9 +71,9 @@ class FlatlandRemoteClient(object):
"{}".format( "{}".format(
random.randint(0, 10**10) random.randint(0, 10**10)
).encode('utf-8')).hexdigest() ).encode('utf-8')).hexdigest()
response_channel = "{}::{}::response::{}".format( self.namespace, response_channel = "{}::{}::response::{}".format(self.namespace,
self.service_id, self.service_id,
random_hash) random_hash)
return response_channel return response_channel
def _blocking_request(self, _request): def _blocking_request(self, _request):
...@@ -94,7 +88,7 @@ class FlatlandRemoteClient(object): ...@@ -94,7 +88,7 @@ class FlatlandRemoteClient(object):
** redis-left-push (LPUSH) ** redis-left-push (LPUSH)
* Keep listening on response_channel (BLPOP) * Keep listening on response_channel (BLPOP)
""" """
assert type(_request) ==type({}) assert isinstance(_request, dict)
_request['response_channel'] = self._generate_response_channel() _request['response_channel'] = self._generate_response_channel()
_redis = self.get_redis_connection() _redis = self.get_redis_connection()
...@@ -102,14 +96,16 @@ class FlatlandRemoteClient(object): ...@@ -102,14 +96,16 @@ class FlatlandRemoteClient(object):
The client always pushes in the left The client always pushes in the left
and the service always pushes in the right and the service always pushes in the right
""" """
if self.verbose: print("Request : ", _response) if self.verbose:
print("Request : ", _request)
# Push request in command_channels # Push request in command_channels
# Note: The patched msgpack supports numpy arrays # Note: The patched msgpack supports numpy arrays
payload = msgpack.packb(_request, default=m.encode, use_bin_type=True) payload = msgpack.packb(_request, default=m.encode, use_bin_type=True)
_redis.lpush(self.command_channel, payload) _redis.lpush(self.command_channel, payload)
# Wait with a blocking pop for the response # Wait with a blocking pop for the response
_response = _redis.blpop(_request['response_channel'])[1] _response = _redis.blpop(_request['response_channel'])[1]
if self.verbose: print("Response : ", _response) if self.verbose:
print("Response : ", _response)
_response = msgpack.unpackb( _response = msgpack.unpackb(
_response, _response,
object_hook=m.decode, object_hook=m.decode,
...@@ -163,7 +159,7 @@ class FlatlandRemoteClient(object): ...@@ -163,7 +159,7 @@ class FlatlandRemoteClient(object):
self.env._max_episode_steps = \ self.env._max_episode_steps = \
int(1.5 * (self.env.width + self.env.height)) int(1.5 * (self.env.width + self.env.height))
_ = self.env.reset() self.env.reset()
# Use the observation from the remote service instead # Use the observation from the remote service instead
return observation return observation
...@@ -198,8 +194,10 @@ class FlatlandRemoteClient(object): ...@@ -198,8 +194,10 @@ class FlatlandRemoteClient(object):
time.sleep(10) time.sleep(10)
return _response['payload'] return _response['payload']
if __name__ == "__main__": if __name__ == "__main__":
env_client = FlatlandRemoteClient() env_client = FlatlandRemoteClient()
def my_controller(obs, _env): def my_controller(obs, _env):
_action = {} _action = {}
for _idx, _ in enumerate(_env.agents): for _idx, _ in enumerate(_env.agents):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment