Skip to content
Snippets Groups Projects
Commit 2dfc3174 authored by spmohanty's avatar spmohanty
Browse files

Addresses #117 - Implement basic client, service protocol

parent 080b5d03
No related branches found
No related tags found
No related merge requests found
# -*- coding: utf-8 -*-
"""Top-level package for flatland."""
__author__ = """S.P. Mohanty"""
__email__ = 'mohanty@aicrowd.com'
import redis
import json
import os
import glob
import pkg_resources
import sys
import numpy as np
import msgpack
import msgpack_numpy as m
m.patch()
import hashlib
import random
from flatland.evaluators import messages
from flatland.evaluators.utils import get_all_env_pickle_files
from flatland.envs.rail_env import RailEnv
from flatland.envs.generators import rail_from_file
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
import time
import logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class FlatlandRemoteClient(object):
"""
Redis client to interface with flatland-rl remote-evaluation-service
The Docker container hosts a redis-server inside the container.
This client connects to the same redis-server, and communicates with the service.
The service eventually will reside outside the docker container, and will communicate
with the client only via the redis-server of the docker container.
On the instantiation of the docker container, one service will be instantiated parallely.
The service will accepts commands at "`service_id`::commands"
where `service_id` is either provided as an `env` variable or is
instantiated to "flatland_rl_redis_service_id"
"""
def __init__(self,
remote_host='127.0.0.1',
remote_port=6379,
remote_db=0,
remote_password=None,
verbose=False
):
self.remote_host = remote_host
self.remote_port = remote_port
self.remote_db = remote_db
self.remote_password = remote_password
self.redis_pool = redis.ConnectionPool(
host=remote_host,
port=remote_port,
db=remote_db,
password=remote_password)
self.namespace = "flatland-rl"
try:
self.service_id = os.environ['FLATLAND_RL_SERVICE_ID']
except KeyError:
self.service_id = "FLATLAND_RL_SERVICE_ID"
self.command_channel = "{}::{}::commands".format(
self.namespace,
self.service_id
)
self.verbose = verbose
self.env = None
self.ping_pong()
def get_redis_connection(self):
return redis.Redis(connection_pool=self.redis_pool)
def _generate_response_channel(self):
random_hash = hashlib.md5(
"{}".format(
random.randint(0, 10**10)
).encode('utf-8')).hexdigest()
response_channel = "{}::{}::response::{}".format( self.namespace,
self.service_id,
random_hash)
return response_channel
def _blocking_request(self, _request):
"""
request:
-command_type
-payload
-response_channel
response: (on response_channel)
- RESULT
* Send the payload on command_channel (self.namespace+"::command")
** redis-left-push (LPUSH)
* Keep listening on response_channel (BLPOP)
"""
assert type(_request) ==type({})
_request['response_channel'] = self._generate_response_channel()
_redis = self.get_redis_connection()
"""
The client always pushes in the left
and the service always pushes in the right
"""
if self.verbose: print("Request : ", _response)
# Push request in command_channels
# Note: The patched msgpack supports numpy arrays
payload = msgpack.packb(_request, default=m.encode, use_bin_type=True)
_redis.lpush(self.command_channel, payload)
# Wait with a blocking pop for the response
_response = _redis.blpop(_request['response_channel'])[1]
if self.verbose: print("Response : ", _response)
_response = msgpack.unpackb(_response, object_hook=m.decode, encoding="utf8")
if _response['type'] == messages.FLATLAND_RL.ERROR:
raise Exception(str(_response))
else:
return _response
def ping_pong(self):
"""
Official Handshake with the evaluation service
Send a PING
and wait for PONG
If not PONG, raise error
"""
_request = {}
_request['type'] = messages.FLATLAND_RL.PING
_request['payload'] = {}
_response = self._blocking_request(_request)
if _response['type'] != messages.FLATLAND_RL.PONG:
raise Exception(
"Unable to perform handshake with the redis service. \
Expected PONG; received {}".format(json.dumps(_response)))
else:
return True
def env_create(self):
_request = {}
_request['type'] = messages.FLATLAND_RL.ENV_CREATE
_request['payload'] = {}
_response = self._blocking_request(_request)
observation = _response['payload']['observation']
test_env_file_path = _response['payload']['env_file_path']
self.env = RailEnv(
width=1,
height=1,
rail_generator=rail_from_file(test_env_file_path),
obs_builder_object=TreeObsForRailEnv(
max_depth=3,
predictor=ShortestPathPredictorForRailEnv()
)
)
self.env._max_episode_steps = \
int(1.5 * (self.env.width + self.env.height))
_ = self.env.reset()
# Use the observation from the remote service instead
return observation
def env_step(self, action, render=False):
"""
Respond with [observation, reward, done, info]
"""
_request = {}
_request['type'] = messages.FLATLAND_RL.ENV_STEP
_request['payload'] = {}
_request['payload']['action'] = action
_response = self._blocking_request(_request)
_payload = _response['payload']
observation = _payload['observation']
reward = _payload['reward']
done = _payload['done']
info = _payload['info']
return [observation, reward, done, info]
def submit(self):
_request = {}
_request['type'] = messages.FLATLAND_RL.ENV_SUBMIT
_request['payload'] = {}
_response = self._blocking_request(_request)
if os.getenv("AICROWD_BLOCKING_SUBMIT"):
"""
If the submission is supposed to happen as a blocking submit,
then wait indefinitely for the evaluator to decide what to
do with the container.
"""
while True:
time.sleep(10)
return _response['payload']
if __name__ == "__main__":
env_client = FlatlandRemoteClient()
def my_controller(obs, _env):
_action = {}
for _idx, _ in enumerate(_env.agents):
_action[_idx] = np.random.randint(0, 5)
return _action
obs = True
episode = 0
while obs:
obs = env_client.env_create()
print("Episode : {}".format(episode))
print(obs)
print(env_client.env.width)
print(env_client.env.height)
episode += 1
class FLATLAND_RL:
PING = "FLATLAND_RL.PING"
PONG = "FLATLAND_RL.PONG"
ENV_CREATE = "FLATLAND_RL.ENV_CREATE"
ENV_CREATE_RESPONSE = "FLATLAND_RL.ENV_CREATE_RESPONSE"
ENV_RESET = "FLATLAND_RL.ENV_RESET"
ENV_RESET_RESPONSE = "FLATLAND_RL.ENV_RESET_RESPONSE"
ENV_STEP = "FLATLAND_RL.ENV_STEP"
ENV_STEP_RESPONSE = "FLATLAND_RL.ENV_STEP_RESPONSE"
ENV_SUBMIT = "FLATLAND_RL.ENV_SUBMIT"
ENV_SUBMIT_RESPONSE = "FLATLAND_RL.ENV_SUBMIT_RESPONSE"
ERROR = "FLATLAND_RL.ERROR"
\ No newline at end of file
#!/usr/bin/env python
from __future__ import print_function
import redis
from flatland.envs.generators import rail_from_file
from flatland.envs.rail_env import RailEnv
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.evaluators import messages
import json
import numpy as np
import msgpack
import msgpack_numpy as m
m.patch()
import flatland
import os
import timeout_decorator
import time
########################################################
# CONSTANTS
########################################################
PER_STEP_TIMEOUT = 5*60 # 5 minutes
class FlatlandRemoteEvaluationService:
def __init__( self,
test_env_folder="/tmp",
flatland_rl_service_id = 'FLATLAND_RL_SERVICE_ID',
remote_host = '127.0.0.1',
remote_port = 6379,
remote_db = 0,
remote_password = None,
visualize = False,
report = None,
verbose = False):
# Test Env folder Paths
self.test_env_folder = test_env_folder
self.env_file_paths = self.get_env_filepaths()
# Logging and Reporting related vars
self.verbose = verbose
self.report = report
# Communication Protocol Related vars
self.namespace = "flatland-rl"
self.service_id = flatland_rl_service_id
self.command_channel = "{}::{}::commands".format(
self.namespace,
self.service_id
)
# Message Broker related vars
self.remote_host = remote_host
self.remote_port = remote_port
self.remote_db = remote_db
self.remote_password = remote_password
self.instantiate_redis_connection_pool()
# RailEnv specific variables
self.env = False
self.env_available = False
self.reward = 0
self.simulation_count = 0
self.simualation_rewards = []
self.simulation_times = []
self.begin_simulation = False
self.current_step = 0
self.visualize = visualize
def get_env_filepaths(self):
env_paths = []
folder_path = self.test_env_folder
for root, dirs, files in os.walk(folder_path):
for file in files:
if file.endswith(".pkl"):
env_paths.append(
os.path.join(root, file)
)
return sorted(env_paths)
def instantiate_redis_connection_pool(self):
if self.verbose or self.report:
print("Attempting to connect to redis server at {}:{}/{}".format(
self.remote_host,
self.remote_port,
self.remote_db)
)
self.redis_pool = redis.ConnectionPool(
host=self.remote_host,
port=self.remote_port,
db=self.remote_db,
password=self.remote_password
)
def get_redis_connection(self):
redis_conn = redis.Redis(connection_pool=self.redis_pool)
try:
redis_conn.ping()
except:
raise Exception(
"Unable to connect to redis server at {}:{} ."
"Are you sure there is a redis-server running at the "
"specified location ?".format(
self.remote_host,
self.remote_port
)
)
return redis_conn
def _error_template(self, payload):
_response = {}
_response['type'] = messages.FLATLAND_RL.ERROR
_response['payload'] = payload
return _response
@timeout_decorator.timeout(PER_STEP_TIMEOUT)# timeout for each command
def _get_next_command(self, _redis):
command = _redis.brpop(self.command_channel)[1]
return command
def get_next_command(self):
try:
_redis = self.get_redis_connection()
command = self._get_next_command(_redis)
if self.verbose or self.report:
print("Command Service: ", command)
except timeout_decorator.timeout_decorator.TimeoutError:
raise Exception(
"Timeout in step {} of simulation {}".format(
self.current_step,
self.simulation_count
))
command_response_channel = "default_response_channel"
command = msgpack.unpackb(
command,
object_hook=m.decode,
encoding="utf8"
)
if self.verbose:
print("Received Request : ", command)
return command
def handle_ping(self, command):
_redis = self.get_redis_connection()
command_response_channel = command['response_channel']
_command_response = {}
_command_response['type'] = messages.FLATLAND_RL.PONG
_command_response['payload'] = {}
if self.verbose: print("Responding with : ", _command_response)
_redis.rpush(
command_response_channel,
msgpack.packb(
_command_response,
default=m.encode,
use_bin_type=True)
)
def handle_env_create(self, command):
_redis = self.get_redis_connection()
command_response_channel = command['response_channel']
_payload = command['payload']
if self.simulation_count < len(self.env_file_paths):
"""
There are still test envs left that are yet to be evaluated
"""
test_env_file_path = self.env_file_paths[self.simulation_count]
del self.env
self.env = RailEnv(
width=1,
height=1,
rail_generator=rail_from_file(test_env_file_path),
obs_builder_object=TreeObsForRailEnv(
max_depth=3,
predictor=ShortestPathPredictorForRailEnv()
)
)
# Set max episode steps allowed
self.env._max_episode_steps = \
int(1.5 * (self.env.width + self.env.height))
self.env_available = True
self.simulation_count += 1
if self.begin_simulation:
# If begin simulation has already been initialized
# atleast once
self.simulation_times.append(time.time()-self.begin_simulation)
self.begin_simulation = time.time()
self.simualation_rewards.append(0)
self.current_step = 0
_observation = self.env.reset()
_command_response = {}
_command_response['type'] = messages.FLATLAND_RL.ENV_CREATE_RESPONSE
_command_response['payload'] = {}
_command_response['payload']['observation'] = _observation
_command_response['payload']['env_file_path'] = test_env_file_path
if self.verbose: print("Responding with : ", _command_response)
_redis.rpush(
command_response_channel,
msgpack.packb(
_command_response,
default=m.encode,
use_bin_type=True)
)
else:
"""
All test env evaluations are complete
"""
_command_response = {}
_command_response['type'] = messages.FLATLAND_RL.ENV_RESET_RESPONSE
_command_response['payload'] = {}
_command_response['payload']['observation'] = False
_command_response['payload']['env_file_path'] = False
if self.verbose: print("Responding with : ", _command_response)
_redis.rpush(
command_response_channel,
msgpack.packb(
_command_response,
default=m.encode,
use_bin_type=True)
)
def run(self):
print("Listening for commands at : ", self.command_channel)
while True:
command = self.get_next_command()
if self.verbose:
print("Self.Reward : ", self.reward)
print("Current Simulation : ", self.simulation_count)
if self.env_file_paths \
and self.simulation_count < len(self.env_file_paths):
print("Current Env Path : ", \
self.env_file_paths[self.simulation_count]
)
try:
if command['type'] == messages.FLATLAND_RL.PING:
"""
INITIAL HANDSHAKE : Respond with PONG
"""
self.handle_ping(command)
elif command['type'] == messages.FLATLAND_RL.ENV_CREATE:
"""
ENV_CREATE
Respond with an internal _env object
"""
self.handle_env_create(command)
elif command['type'] == messages.FLATLAND_RL.ENV_RESET:
"""
ENV_RESET
Respond with observation from next simulation or
False if no simulations are left
"""
self.simulation_count += 1
if self.begin_simulation:
self.simulation_times.append(time.time()-self.begin_simulation)
self.begin_simulation = time.time()
if self.seed_map and self.simulation_count < len(self.seed_map):
_observation = self.env.reset(seed=self.seed_map[self.simulation_count], project=False)
self.simualation_rewards.append(0)
self.env_available = True
self.current_step = 0
#_observation = list(_observation)
_command_response = {}
_command_response['type'] = messages.FLATLAND_RL.ENV_RESET_RESPONSE
_command_response['payload'] = {}
_command_response['payload']['observation'] = _observation
if self.verbose: print("Responding with : ", _command_response)
_redis.rpush(command_response_channel, msgpack.packb(_command_response, default=m.encode, use_bin_type=True))
else:
_command_response = {}
_command_response['type'] = messages.FLATLAND_RL.ENV_RESET_RESPONSE
_command_response['payload'] = {}
_command_response['payload']['observation'] = False
if self.verbose: print("Responding with : ", _command_response)
_redis.rpush(command_response_channel, msgpack.packb(_command_response, default=m.encode, use_bin_type=True))
elif command['type'] == messages.FLATLAND_RL.ENV_STEP:
"""
ENV_STEP
Request : Action array
Respond with updated [observation,reward,done,info] after step
"""
args = command['payload']
action = args['action']
if self.env and self.env_available:
[_observation, reward, done, info] = self.env.step(action)
else:
if self.env:
raise Exception("Attempt to call `step` function after max_steps={} in a single simulation. Please reset your environment before calling the `step` function after max_step s".format(self.max_steps))
else:
raise Exception("Attempt to call `step` function on a non existent `env`")
self.reward += reward
self.simualation_rewards[-1] += reward
self.current_step += 1
#_observation = np.array(_observation).tolist()
if self.current_step >= self.max_steps:
_command_response = {}
_command_response['type'] = messages.FLATLAND_RL.ENV_STEP_RESPONSE
_command_response['payload'] = {}
_command_response['payload']['observation'] = _observation
_command_response['payload']['reward'] = reward
_command_response['payload']['done'] = True
_command_response['payload']['info'] = info
"""
Mark env as unavailable until next reset
"""
self.env_available = False
else:
_command_response = {}
_command_response['type'] = messages.FLATLAND_RL.ENV_STEP_RESPONSE
_command_response['payload'] = {}
_command_response['payload']['observation'] = _observation
_command_response['payload']['reward'] = reward
_command_response['payload']['done'] = done
_command_response['payload']['info'] = info
if done:
"""
Mark env as unavailable until next reset
"""
self.env_available = False
if self.verbose: print("Responding with : ", _command_response)
if self.verbose: print("Current Step : ", self.current_step)
_redis.rpush(command_response_channel, msgpack.packb(_command_response, default=m.encode, use_bin_type=True))
elif command['type'] == messages.FLATLAND_RL.ENV_SUBMIT:
"""
ENV_SUBMIT
Submit the final cumulative reward
"""
_response = {}
_response['type'] = messages.FLATLAND_RL.ENV_SUBMIT_RESPONSE
_payload = {}
_payload['mean_reward'] = np.float(self.reward)/len(self.seed_map) #Mean reward
_payload['simulation_rewards'] = self.simualation_rewards
_payload['simulation_times'] = self.simulation_times
_response['payload'] = _payload
_redis.rpush(command_response_channel, msgpack.packb(_response, default=m.encode, use_bin_type=True))
elif command['type'] == messages.FLATLAND_RL.ENV_SUBMIT:
if self.verbose: print("Responding with : ", _response)
return _response
else:
_error = self._error_template(
"UNKNOWN_REQUEST:{}".format(
str(command)))
if self.verbose:print("Responding with : ", _error)
_redis.rpush(command_response_channel, msgpack.packb(_error, default=m.encode, use_bin_type=True))
return _error
except Exception as e:
print("Error : ", str(e))
_redis.rpush( command_response_channel,
msgpack.packb(self._error_template(str(e)), default=m.encode, use_bin_type=True))
return self._error_template(str(e))
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description='Submit the result to AIcrowd')
parser.add_argument('--service_id', dest='service_id', default='FLATLAND_RL_SERVICE_ID', required=False)
parser.add_argument('--test_folder',
dest='test_folder',
default="/Users/spmohanty/work/SBB/submission-scoring/Envs-Small",
help="Folder containing the pickle files for the test envs",
required=False)
args = parser.parse_args()
test_folder = args.test_folder
grader = FlatlandRemoteEvaluationService(
test_env_folder=test_folder,
flatland_rl_service_id=args.service_id,
verbose=True
)
result = grader.run()
if result['type'] == messages.FLATLAND_RL.ENV_SUBMIT_RESPONSE:
cumulative_results = result['payload']
print("Results : ", cumulative_results)
elif result['type'] == messages.FLATLAND_RL.ERROR:
error = result['payload']
raise Exception("Evaluation Failed : {}".format(str(error)))
else:
#Evaluation failed
print("Evaluation Failed : ", result['payload'])
...@@ -3,6 +3,7 @@ tox>=3.5.2 ...@@ -3,6 +3,7 @@ tox>=3.5.2
twine>=1.12.1 twine>=1.12.1
pytest>=3.8.2 pytest>=3.8.2
pytest-runner>=4.2 pytest-runner>=4.2
crowdai-api>=0.1.21
numpy>=1.16.2 numpy>=1.16.2
recordtype>=1.3 recordtype>=1.3
xarray>=0.11.3 xarray>=0.11.3
...@@ -17,5 +18,6 @@ pyarrow>=0.13.0 ...@@ -17,5 +18,6 @@ pyarrow>=0.13.0
importlib-metadata>=0.17 importlib-metadata>=0.17
importlib-resources>=1.0.1 importlib-resources>=1.0.1
six>=1.12.0 six>=1.12.0
timeout-decorator>=0.4.1
attrs attrs
ushlex ushlex
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment