Skip to content
Snippets Groups Projects
Commit 6c6e884e authored by spmohanty's avatar spmohanty
Browse files

Addresses #117 - Fix linter errors in client.py

parent cd84462b
No related branches found
No related tags found
No related merge requests found
import redis
import json
import os
import glob
import pkg_resources
import sys
import numpy as np
import msgpack
import msgpack_numpy as m
m.patch()
import hashlib
import random
from flatland.evaluators import messages
from flatland.envs.rail_env import RailEnv
from flatland.envs.generators import rail_from_file
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
import time
import logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
m.patch()
class FlatlandRemoteClient(object):
"""
......@@ -43,8 +38,7 @@ class FlatlandRemoteClient(object):
remote_port=6379,
remote_db=0,
remote_password=None,
verbose=False
):
verbose=False):
self.remote_host = remote_host
self.remote_port = remote_port
......@@ -57,7 +51,7 @@ class FlatlandRemoteClient(object):
password=remote_password)
self.namespace = "flatland-rl"
try:
self.service_id = os.environ['FLATLAND_RL_SERVICE_ID']
self.service_id = os.environ['FLATLAND_RL_SERVICE_ID']
except KeyError:
self.service_id = "FLATLAND_RL_SERVICE_ID"
self.command_channel = "{}::{}::commands".format(
......@@ -77,9 +71,9 @@ class FlatlandRemoteClient(object):
"{}".format(
random.randint(0, 10**10)
).encode('utf-8')).hexdigest()
response_channel = "{}::{}::response::{}".format( self.namespace,
self.service_id,
random_hash)
response_channel = "{}::{}::response::{}".format(self.namespace,
self.service_id,
random_hash)
return response_channel
def _blocking_request(self, _request):
......@@ -94,7 +88,7 @@ class FlatlandRemoteClient(object):
** redis-left-push (LPUSH)
* Keep listening on response_channel (BLPOP)
"""
assert type(_request) ==type({})
assert isinstance(_request, dict)
_request['response_channel'] = self._generate_response_channel()
_redis = self.get_redis_connection()
......@@ -102,14 +96,16 @@ class FlatlandRemoteClient(object):
The client always pushes in the left
and the service always pushes in the right
"""
if self.verbose: print("Request : ", _response)
if self.verbose:
print("Request : ", _request)
# Push request in command_channels
# Note: The patched msgpack supports numpy arrays
payload = msgpack.packb(_request, default=m.encode, use_bin_type=True)
_redis.lpush(self.command_channel, payload)
# Wait with a blocking pop for the response
_response = _redis.blpop(_request['response_channel'])[1]
if self.verbose: print("Response : ", _response)
if self.verbose:
print("Response : ", _response)
_response = msgpack.unpackb(
_response,
object_hook=m.decode,
......@@ -163,7 +159,7 @@ class FlatlandRemoteClient(object):
self.env._max_episode_steps = \
int(1.5 * (self.env.width + self.env.height))
_ = self.env.reset()
self.env.reset()
# Use the observation from the remote service instead
return observation
......@@ -198,8 +194,10 @@ class FlatlandRemoteClient(object):
time.sleep(10)
return _response['payload']
if __name__ == "__main__":
env_client = FlatlandRemoteClient()
def my_controller(obs, _env):
_action = {}
for _idx, _ in enumerate(_env.agents):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment