Skip to content
Snippets Groups Projects
Commit afc43664 authored by mohanty's avatar mohanty
Browse files

Merge branch 'spm/remote-evaluation-service' into 'master'

Addresses #117 - Implement basic client, service protocol

See merge request flatland/flatland!120
parents 9da1d933 5de931f0
No related branches found
No related tags found
No related merge requests found
...@@ -111,3 +111,5 @@ ENV/ ...@@ -111,3 +111,5 @@ ENV/
images/test/ images/test/
test_save.dat test_save.dat
.visualizations
\ No newline at end of file
...@@ -73,3 +73,24 @@ class ObservationBuilder: ...@@ -73,3 +73,24 @@ class ObservationBuilder:
direction = np.zeros(4) direction = np.zeros(4)
direction[agent.direction] = 1 direction[agent.direction] = 1
return direction return direction
class DummyObservationBuilder(ObservationBuilder):
"""
DummyObservationBuilder class which returns dummy observations
This is used in the evaluation service
"""
def __init__(self):
self.observation_space = ()
def _set_env(self, env):
self.env = env
def reset(self):
pass
def get_many(self, handles=[]):
return True
def get(self, handle=0):
return True
# -*- coding: utf-8 -*-
"""Top-level package for flatland."""
__author__ = """S.P. Mohanty"""
__email__ = 'mohanty@aicrowd.com'
import os
import boto3
import uuid
import subprocess
###############################################################
# Expected Env Variables
###############################################################
# Default Values to be provided :
# AICROWD_IS_GRADING : true
# CROWDAI_IS_GRADING : true
# S3_BUCKET : aicrowd-production
# S3_UPLOAD_PATH_TEMPLATE : misc/flatland-rl-Media/{}.mp4
# AWS_ACCESS_KEY_ID
# AWS_SECRET_ACCESS_KEY
# http_proxy
# https_proxy
###############################################################
AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID", False)
AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY", False)
S3_BUCKET = os.getenv("S3_BUCKET", "aicrowd-production")
S3_UPLOAD_PATH_TEMPLATE = os.getenv("S3_UPLOAD_PATH_TEMPLATE", "misc/flatland-rl-Media/{}.mp4")
def get_boto_client():
if not AWS_ACCESS_KEY_ID or not AWS_SECRET_ACCESS_KEY:
raise Exception("AWS Credentials not provided..")
return boto3.client(
's3',
aws_access_key_id=AWS_ACCESS_KEY_ID,
aws_secret_access_key=AWS_SECRET_ACCESS_KEY
)
def is_aws_configured():
if not AWS_ACCESS_KEY_ID or not AWS_SECRET_ACCESS_KEY:
return False
else:
return True
def is_grading():
return os.getenv("CROWDAI_IS_GRADING", False) or \
os.getenv("AICROWD_IS_GRADING", False)
def upload_to_s3(localpath):
s3 = get_boto_client()
if not S3_UPLOAD_PATH_TEMPLATE:
raise Exception("S3_UPLOAD_PATH_TEMPLATE not provided...")
if not S3_BUCKET:
raise Exception("S3_BUCKET not provided...")
image_target_key = S3_UPLOAD_PATH_TEMPLATE.format(str(uuid.uuid4()))
s3.put_object(
ACL="public-read",
Bucket=S3_BUCKET,
Key=image_target_key,
Body=open(localpath, 'rb')
)
return image_target_key
def make_subprocess_call(command, shell=False):
result = subprocess.run(
command.split(),
shell=shell,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE
)
stdout = result.stdout.decode('utf-8')
stderr = result.stderr.decode('utf-8')
return result.returncode, stdout, stderr
def generate_movie_from_frames(frames_folder):
"""
Expects the frames in the frames_folder folder
and then use ffmpeg to generate the video
which writes the output to the frames_folder
"""
# Generate Thumbnail Video
print("Generating Thumbnail...")
frames_path = os.path.join(frames_folder, "flatland_frame_%04d.png")
thumb_output_path = os.path.join(frames_folder, "out_thumb.mp4")
return_code, output, output_err = make_subprocess_call(
"ffmpeg -r 7 -start_number 0 -i " +
frames_path +
" -c:v libx264 -vf fps=7 -pix_fmt yuv420p -s 320x320 " +
thumb_output_path
)
if return_code != 0:
raise Exception(output_err)
# Generate Normal Sized Video
print("Generating Normal Video...")
frames_path = os.path.join(frames_folder, "flatland_frame_%04d.png")
output_path = os.path.join(frames_folder, "out.mp4")
return_code, output, output_err = make_subprocess_call(
"ffmpeg -r 7 -start_number 0 -i " +
frames_path +
" -c:v libx264 -vf fps=7 -pix_fmt yuv420p -s 600x600 " +
output_path
)
if return_code != 0:
raise Exception(output_err)
return output_path, thumb_output_path
import redis
import json
import os
import numpy as np
import msgpack
import msgpack_numpy as m
import hashlib
import random
from flatland.evaluators import messages
from flatland.envs.rail_env import RailEnv
from flatland.envs.generators import rail_from_file
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
import time
import logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
m.patch()
def are_dicts_equal(d1, d2):
""" return True if all keys and values are the same """
return all(k in d2 and d1[k] == d2[k]
for k in d1) \
and all(k in d1 and d1[k] == d2[k]
for k in d2)
class FlatlandRemoteClient(object):
"""
Redis client to interface with flatland-rl remote-evaluation-service
The Docker container hosts a redis-server inside the container.
This client connects to the same redis-server,
and communicates with the service.
The service eventually will reside outside the docker container,
and will communicate
with the client only via the redis-server of the docker container.
On the instantiation of the docker container, one service will be
instantiated parallely.
The service will accepts commands at "`service_id`::commands"
where `service_id` is either provided as an `env` variable or is
instantiated to "flatland_rl_redis_service_id"
"""
def __init__(self,
remote_host='127.0.0.1',
remote_port=6379,
remote_db=0,
remote_password=None,
verbose=False):
self.remote_host = remote_host
self.remote_port = remote_port
self.remote_db = remote_db
self.remote_password = remote_password
self.redis_pool = redis.ConnectionPool(
host=remote_host,
port=remote_port,
db=remote_db,
password=remote_password)
self.namespace = "flatland-rl"
try:
self.service_id = os.environ['FLATLAND_RL_SERVICE_ID']
except KeyError:
self.service_id = "FLATLAND_RL_SERVICE_ID"
self.command_channel = "{}::{}::commands".format(
self.namespace,
self.service_id
)
self.verbose = verbose
self.env = None
self.ping_pong()
def get_redis_connection(self):
return redis.Redis(connection_pool=self.redis_pool)
def _generate_response_channel(self):
random_hash = hashlib.md5(
"{}".format(
random.randint(0, 10**10)
).encode('utf-8')).hexdigest()
response_channel = "{}::{}::response::{}".format(self.namespace,
self.service_id,
random_hash)
return response_channel
def _blocking_request(self, _request):
"""
request:
-command_type
-payload
-response_channel
response: (on response_channel)
- RESULT
* Send the payload on command_channel (self.namespace+"::command")
** redis-left-push (LPUSH)
* Keep listening on response_channel (BLPOP)
"""
assert isinstance(_request, dict)
_request['response_channel'] = self._generate_response_channel()
_redis = self.get_redis_connection()
"""
The client always pushes in the left
and the service always pushes in the right
"""
if self.verbose:
print("Request : ", _request)
# Push request in command_channels
# Note: The patched msgpack supports numpy arrays
payload = msgpack.packb(_request, default=m.encode, use_bin_type=True)
_redis.lpush(self.command_channel, payload)
# Wait with a blocking pop for the response
_response = _redis.blpop(_request['response_channel'])[1]
if self.verbose:
print("Response : ", _response)
_response = msgpack.unpackb(
_response,
object_hook=m.decode,
encoding="utf8")
if _response['type'] == messages.FLATLAND_RL.ERROR:
raise Exception(str(_response["payload"]))
else:
return _response
def ping_pong(self):
"""
Official Handshake with the evaluation service
Send a PING
and wait for PONG
If not PONG, raise error
"""
_request = {}
_request['type'] = messages.FLATLAND_RL.PING
_request['payload'] = {}
_response = self._blocking_request(_request)
if _response['type'] != messages.FLATLAND_RL.PONG:
raise Exception(
"Unable to perform handshake with the redis service. \
Expected PONG; received {}".format(json.dumps(_response)))
else:
return True
def env_create(self, obs_builder_object):
"""
Create a local env and remote env on which the
local agent can operate.
The observation builder is only used in the local env
and the remote env uses a DummyObservationBuilder
"""
_request = {}
_request['type'] = messages.FLATLAND_RL.ENV_CREATE
_request['payload'] = {}
_response = self._blocking_request(_request)
observation = _response['payload']['observation']
if not observation:
# If the observation is False,
# then the evaluations are complete
# hence return false
return observation
test_env_file_path = _response['payload']['env_file_path']
self.env = RailEnv(
width=1,
height=1,
rail_generator=rail_from_file(test_env_file_path),
obs_builder_object=obs_builder_object
)
self.env._max_episode_steps = \
int(1.5 * (self.env.width + self.env.height))
local_observation = self.env.reset()
# Use the local observation
# as the remote server uses a dummy observation builder
return local_observation
def env_step(self, action, render=False):
"""
Respond with [observation, reward, done, info]
"""
_request = {}
_request['type'] = messages.FLATLAND_RL.ENV_STEP
_request['payload'] = {}
_request['payload']['action'] = action
_response = self._blocking_request(_request)
_payload = _response['payload']
# remote_observation = _payload['observation']
remote_reward = _payload['reward']
remote_done = _payload['done']
remote_info = _payload['info']
# Replicate the action in the local env
local_observation, local_rewards, local_done, local_info = \
self.env.step(action)
assert are_dicts_equal(remote_reward, local_rewards)
assert are_dicts_equal(remote_done, local_done)
# Return local_observation instead of remote_observation
# as the remote_observation is build using a dummy observation
# builder
# We return the remote rewards and done as they are the
# once used by the evaluator
return [local_observation, remote_reward, remote_done, remote_info]
def submit(self):
_request = {}
_request['type'] = messages.FLATLAND_RL.ENV_SUBMIT
_request['payload'] = {}
_response = self._blocking_request(_request)
if os.getenv("AICROWD_BLOCKING_SUBMIT"):
"""
If the submission is supposed to happen as a blocking submit,
then wait indefinitely for the evaluator to decide what to
do with the container.
"""
while True:
time.sleep(10)
return _response['payload']
if __name__ == "__main__":
remote_client = FlatlandRemoteClient()
def my_controller(obs, _env):
_action = {}
for _idx, _ in enumerate(_env.agents):
_action[_idx] = np.random.randint(0, 5)
return _action
my_observation_builder = TreeObsForRailEnv(max_depth=3,
predictor=ShortestPathPredictorForRailEnv())
episode = 0
obs = True
while obs:
obs = remote_client.env_create(
obs_builder_object=my_observation_builder
)
if not obs:
"""
The remote env returns False as the first obs
when it is done evaluating all the individual episodes
"""
break
print("Episode : {}".format(episode))
episode += 1
print(remote_client.env.dones['__all__'])
while True:
action = my_controller(obs, remote_client.env)
observation, all_rewards, done, info = remote_client.env_step(action)
if done['__all__']:
print("Current Episode : ", episode)
print("Episode Done")
print("Reward : ", sum(list(all_rewards.values())))
break
print("Evaluation Complete...")
print(remote_client.submit())
class FLATLAND_RL:
PING = "FLATLAND_RL.PING"
PONG = "FLATLAND_RL.PONG"
ENV_CREATE = "FLATLAND_RL.ENV_CREATE"
ENV_CREATE_RESPONSE = "FLATLAND_RL.ENV_CREATE_RESPONSE"
ENV_RESET = "FLATLAND_RL.ENV_RESET"
ENV_RESET_RESPONSE = "FLATLAND_RL.ENV_RESET_RESPONSE"
ENV_STEP = "FLATLAND_RL.ENV_STEP"
ENV_STEP_RESPONSE = "FLATLAND_RL.ENV_STEP_RESPONSE"
ENV_SUBMIT = "FLATLAND_RL.ENV_SUBMIT"
ENV_SUBMIT_RESPONSE = "FLATLAND_RL.ENV_SUBMIT_RESPONSE"
ERROR = "FLATLAND_RL.ERROR"
\ No newline at end of file
This diff is collapsed.
...@@ -3,6 +3,8 @@ tox>=3.5.2 ...@@ -3,6 +3,8 @@ tox>=3.5.2
twine>=1.12.1 twine>=1.12.1
pytest>=3.8.2 pytest>=3.8.2
pytest-runner>=4.2 pytest-runner>=4.2
crowdai-api>=0.1.21
boto3>=1.9.194
numpy>=1.16.2 numpy>=1.16.2
recordtype>=1.3 recordtype>=1.3
xarray>=0.11.3 xarray>=0.11.3
...@@ -17,5 +19,6 @@ pyarrow>=0.13.0 ...@@ -17,5 +19,6 @@ pyarrow>=0.13.0
importlib-metadata>=0.17 importlib-metadata>=0.17
importlib-resources>=1.0.1 importlib-resources>=1.0.1
six>=1.12.0 six>=1.12.0
timeout-decorator>=0.4.1
attrs attrs
ushlex ushlex
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment