Skip to content
Snippets Groups Projects
Commit b299b8c6 authored by hagrid67's avatar hagrid67
Browse files

draft / initial commit with running test_eval_timeout.py

parent 829f4b0c
No related branches found
No related tags found
No related merge requests found
......@@ -7,6 +7,7 @@ import time
import msgpack
import msgpack_numpy as m
import pickle
import numpy as np
import redis
......@@ -45,8 +46,9 @@ class FlatlandRemoteClient(object):
remote_db=0,
remote_password=None,
test_envs_root=None,
verbose=False):
verbose=False,
use_pickle=False):
self.use_pickle=use_pickle
self.remote_host = remote_host
self.remote_port = remote_port
self.remote_db = remote_db
......@@ -67,6 +69,11 @@ class FlatlandRemoteClient(object):
self.namespace,
self.service_id
)
# for timeout messages sent out-of-band
self.error_channel = "{}::{}::errors".format(
self.namespace, self.service_id)
if test_envs_root:
self.test_envs_root = test_envs_root
else:
......@@ -84,6 +91,8 @@ class FlatlandRemoteClient(object):
self.env_step_times = []
self.stats = {}
def update_running_stats(self, key, scalar):
"""
Computes the running mean for certain params
......@@ -147,9 +156,28 @@ class FlatlandRemoteClient(object):
"""
if self.verbose:
print("Request : ", _request)
# check for errors (essentially just timeouts, for now.)
error_bytes = _redis.rpop(self.error_channel)
if error_bytes is not None:
if self.use_pickle:
error_dict = pickle.loads(error_bytes)
else:
error_dict = msgpack.unpackb(
error_bytes,
object_hook=m.decode,
strict_map_key=False, # new for msgpack 1.0?
encoding="utf8" # remove for msgpack 1.0
)
print("error received: ", error_dict)
raise StopAsyncIteration(error_dict["type"])
# Push request in command_channels
# Note: The patched msgpack supports numpy arrays
payload = msgpack.packb(_request, default=m.encode, use_bin_type=True)
if self.use_pickle:
payload = pickle.dumps(_request)
else:
payload = msgpack.packb(_request, default=m.encode, use_bin_type=True)
_redis.lpush(self.command_channel, payload)
if blocking:
......@@ -157,10 +185,15 @@ class FlatlandRemoteClient(object):
_response = _redis.blpop(_request['response_channel'])[1]
if self.verbose:
print("Response : ", _response)
_response = msgpack.unpackb(
_response,
object_hook=m.decode,
encoding="utf8")
if self.use_pickle:
_response = pickle.loads(_response)
else:
_response = msgpack.unpackb(
_response,
object_hook=m.decode,
strict_map_key=False, # new for msgpack 1.0?
encoding="utf8" # remove for msgpack 1.0
)
if _response['type'] == messages.FLATLAND_RL.ERROR:
raise Exception(str(_response["payload"]))
else:
......@@ -266,6 +299,7 @@ class FlatlandRemoteClient(object):
# Relay the action in a non-blocking way to the server
# so that it can start doing an env.step on it in ~ parallel
# Note - this can throw a Timeout
self._remote_request(_request, blocking=False)
# Apply the action in the local env
......@@ -348,13 +382,18 @@ if __name__ == "__main__":
while True:
action = my_controller(obs, remote_client.env)
time_start = time.time()
observation, all_rewards, done, info = remote_client.env_step(action)
time_diff = time.time() - time_start
print("Step Time : ", time_diff)
if done['__all__']:
print("Current Episode : ", episode)
print("Episode Done")
print("Reward : ", sum(list(all_rewards.values())))
try:
observation, all_rewards, done, info = remote_client.env_step(action)
time_diff = time.time() - time_start
print("Step Time : ", time_diff)
if done['__all__']:
print("Current Episode : ", episode)
print("Episode Done")
print("Reward : ", sum(list(all_rewards.values())))
break
except StopAsyncIteration as err:
print("Timeout: ", err)
break
print("Evaluation Complete...")
......
......@@ -7,11 +7,16 @@ class FLATLAND_RL:
ENV_RESET = "FLATLAND_RL.ENV_RESET"
ENV_RESET_RESPONSE = "FLATLAND_RL.ENV_RESET_RESPONSE"
ENV_RESET_TIMEOUT = "FLATLAND_RL.ENV_RESET_TIMEOUT"
ENV_STEP = "FLATLAND_RL.ENV_STEP"
ENV_STEP_RESPONSE = "FLATLAND_RL.ENV_STEP_RESPONSE"
ENV_STEP_TIMEOUT = "FLATLAND_RL.ENV_STEP_TIMEOUT"
ENV_SUBMIT = "FLATLAND_RL.ENV_SUBMIT"
ENV_SUBMIT_RESPONSE = "FLATLAND_RL.ENV_SUBMIT_RESPONSE"
ERROR = "FLATLAND_RL.ERROR"
......@@ -7,10 +7,13 @@ import random
import shutil
import time
import traceback
import json
import itertools
import crowdai_api
import msgpack
import msgpack_numpy as m
import pickle
import numpy as np
import pandas as pd
import redis
......@@ -26,6 +29,8 @@ from flatland.envs.schedule_generators import schedule_from_file
from flatland.evaluators import aicrowd_helpers
from flatland.evaluators import messages
from flatland.utils.rendertools import RenderTool
from flatland.envs.rail_env_utils import load_flatland_environment_from_file
from flatland.envs.persistence import RailEnvPersister
use_signals_in_timeout = True
if os.name == 'nt':
......@@ -92,14 +97,33 @@ class FlatlandRemoteEvaluationService:
visualize=False,
video_generation_envs=[],
report=None,
verbose=False,
actionDir=None,
episodeDir=None,
mergeDir=None,
use_pickle=False,
shuffle=True,
missing_only=False,
result_output_path=None,
verbose=False):
):
self.actionDir = actionDir
if actionDir and not os.path.exists(self.actionDir):
os.makedirs(self.actionDir)
self.episodeDir = episodeDir
if episodeDir and not os.path.exists(self.episodeDir):
os.makedirs(self.episodeDir)
self.mergeDir = mergeDir
if mergeDir and not os.path.exists(self.mergeDir):
os.makedirs(self.mergeDir)
self.use_pickle = use_pickle
self.missing_only = missing_only
# Test Env folder Paths
self.test_env_folder = test_env_folder
self.video_generation_envs = video_generation_envs
self.env_file_paths = self.get_env_filepaths()
random.shuffle(self.env_file_paths)
if shuffle:
random.shuffle(self.env_file_paths)
print(self.env_file_paths)
# Shuffle all the env_file_paths for more exciting videos
# and for more uniform time progression
......@@ -109,6 +133,10 @@ class FlatlandRemoteEvaluationService:
self.verbose = verbose
self.report = report
# Use a state to swallow and ignore any steps after an env times out.
# this should be reset to False after env reset() to get the next env.
self.state_env_timed_out = False
self.result_output_path = result_output_path
# Communication Protocol Related vars
......@@ -119,6 +147,12 @@ class FlatlandRemoteEvaluationService:
self.service_id
)
self.error_channel = "{}::{}::errors".format(
self.namespace,
self.service_id
)
# Message Broker related vars
self.remote_host = remote_host
self.remote_port = remote_port
......@@ -248,6 +282,16 @@ class FlatlandRemoteEvaluationService:
x, self.test_env_folder
) for x in env_paths])
# if requested, only generate actions for those envs which don't already have them
if self.mergeDir and self.missing_only:
existing_paths = (itertools.chain.from_iterable(
[ glob.glob(os.path.join(self.mergeDir, f"envs/*.{ext}"))
for ext in ["pkl", "mpk"] ]))
existing_paths = [os.path.relpath(sPath, self.mergeDir) for sPath in existing_paths]
env_paths = sorted(set(env_paths) - set(existing_paths))
return env_paths
def instantiate_evaluation_metadata(self):
......@@ -417,23 +461,37 @@ class FlatlandRemoteEvaluationService:
command = _redis.brpop(command_channel)[1]
return command
try:
#try:
if True:
_redis = self.get_redis_connection()
command = _get_next_command(self.command_channel, _redis)
if self.verbose or self.report:
print("Command Service: ", command)
except timeout_decorator.timeout_decorator.TimeoutError:
raise Exception(
"Timeout of {}s in step {} of simulation {}".format(
COMMAND_TIMEOUT,
self.current_step,
self.simulation_count
))
command = msgpack.unpackb(
command,
object_hook=m.decode,
encoding="utf8"
)
#except timeout_decorator.timeout_decorator.TimeoutError:
#raise Exception(
# "Timeout of {}s in step {} of simulation {}".format(
# COMMAND_TIMEOUT,
# self.current_step,
# self.simulation_count
# ))
# print("Timeout of {}s in step {} of simulation {}".format(
# COMMAND_TIMEOUT,
# self.current_step,
# self.simulation_count
# ))
# return {"type":messages.FLATLAND_RL.ENV_STEP_TIMEOUT}
if self.use_pickle:
command = pickle.loads(command)
else:
command = msgpack.unpackb(
command,
object_hook=m.decode,
strict_map_key=False, # msgpack 1.0
encoding="utf8" # msgpack 1.0
)
if self.verbose:
print("Received Request : ", command)
......@@ -448,13 +506,34 @@ class FlatlandRemoteEvaluationService:
if self.verbose and not suppress_logs:
print("Responding with : ", _command_response)
_redis.rpush(
command_response_channel,
msgpack.packb(
if self.use_pickle:
sResponse = pickle.dumps(_command_response)
else:
sResponse = msgpack.packb(
_command_response,
default=m.encode,
use_bin_type=True)
)
_redis.rpush(command_response_channel, sResponse)
def send_error(self, error_dict, suppress_logs=False):
    """Push an out-of-band error (e.g. a timeout) onto ``self.error_channel``.

    This path is used when no command is being answered, so there is no
    per-command response channel to reply on.

    Args:
        error_dict: The error message to serialize and publish.
        suppress_logs: When true, skip the verbose log line.
    """
    connection = self.get_redis_connection()

    should_log = self.verbose and not suppress_logs
    if should_log:
        print("Responding with : ", error_dict)

    # Serialize with pickle when configured, otherwise with msgpack using
    # the msgpack_numpy encoder hook.
    serialized = (
        pickle.dumps(error_dict)
        if self.use_pickle
        else msgpack.packb(error_dict, default=m.encode, use_bin_type=True)
    )
    connection.rpush(self.error_channel, serialized)
def handle_ping(self, command):
"""
......@@ -487,6 +566,10 @@ class FlatlandRemoteEvaluationService:
TODO: Add a high level summary of everything thats happening here.
"""
self.simulation_count += 1
# reset the timeout flag / state.
self.state_env_timed_out = False
if self.simulation_count < len(self.env_file_paths):
"""
There are still test envs left that are yet to be evaluated
......@@ -498,10 +581,12 @@ class FlatlandRemoteEvaluationService:
test_env_file_path
)
del self.env
self.env = RailEnv(width=1, height=1, rail_generator=rail_from_file(test_env_file_path),
self.env = RailEnv(width=1, height=1,
rail_generator=rail_from_file(test_env_file_path),
schedule_generator=schedule_from_file(test_env_file_path),
malfunction_generator_and_process_data=malfunction_from_file(test_env_file_path),
obs_builder_object=DummyObservationBuilder())
obs_builder_object=DummyObservationBuilder(),
record_steps=True)
if self.begin_simulation:
# If begin simulation has already been initialized
......@@ -577,12 +662,18 @@ class FlatlandRemoteEvaluationService:
self.evaluation_state["score"]["score_secondary"] = mean_reward
self.evaluation_state["meta"]["normalized_reward"] = mean_normalized_reward
self.handle_aicrowd_info_event(self.evaluation_state)
self.lActions = []
def handle_env_step(self, command):
"""
Handles a ENV_STEP command from the client
TODO: Add a high level summary of everything thats happening here.
"""
if self.state_env_timed_out:
print("ignoring step command after timeout")
return
_payload = command['payload']
if not self.env:
......@@ -623,6 +714,9 @@ class FlatlandRemoteEvaluationService:
self.env.get_num_agents()
)
# record the actions before checking for done
if self.actionDir is not None:
self.lActions.append(action)
if done["__all__"]:
# Compute percentage complete
complete = 0
......@@ -633,6 +727,15 @@ class FlatlandRemoteEvaluationService:
percentage_complete = complete * 1.0 / self.env.get_num_agents()
self.simulation_percentage_complete[-1] = percentage_complete
if self.actionDir is not None:
self.save_actions()
if self.episodeDir is not None:
self.save_episode()
if self.mergeDir is not None:
self.save_merged_env()
# Record Frame
if self.visualize:
"""
......@@ -654,6 +757,54 @@ class FlatlandRemoteEvaluationService:
))
self.record_frame_step += 1
def send_env_step_timeout(self, command):
    """Send an ENV_STEP_TIMEOUT error via ``self.send_error``.

    Args:
        command: Accepted for the caller's convenience; not read here.
    """
    print("handle_env_step_timeout")

    # Placeholder payload (probably unnecessary): every field is False.
    placeholder_payload = {
        "observation": False,
        "env_file_path": False,
        "info": False,
        "random_seed": False,
    }
    timeout_error = {
        "type": messages.FLATLAND_RL.ENV_STEP_TIMEOUT,
        "payload": placeholder_payload,
    }
    self.send_error(timeout_error)
def save_actions(self):
    """Write the recorded actions of the current env to a JSON file.

    The output path is ``self.actionDir`` joined with the current env's
    path from ``self.env_file_paths``, with the env file's extension
    replaced by ``.json``. Parent directories are created as needed, and
    ``self.lActions`` is reset to an empty list afterwards.
    """
    sfEnv = self.env_file_paths[self.simulation_count]
    # Swap only the final extension: this handles ".mpk" envs as well as
    # ".pkl", and leaves any ".pkl" inside a directory name untouched.
    sfActions = self.actionDir + "/" + os.path.splitext(sfEnv)[0] + ".json"
    print("env path: ", sfEnv, " sfActions:", sfActions)

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(os.path.dirname(sfActions), exist_ok=True)

    with open(sfActions, "w") as fOut:
        json.dump(self.lActions, fOut)

    self.lActions = []
def save_episode(self):
    """Save the current env's episode under ``self.episodeDir``.

    The output path is ``self.episodeDir`` joined with the current env's
    path from ``self.env_file_paths``. The parent directory is created
    first, matching what the sibling save methods do, because env paths
    may include a sub-folder that does not yet exist under the output
    directory.
    """
    sfEnv = self.env_file_paths[self.simulation_count]
    sfEpisode = self.episodeDir + "/" + sfEnv

    # Ensure the target folder exists before the persister writes to it.
    os.makedirs(os.path.dirname(sfEpisode), exist_ok=True)

    print("env path: ", sfEnv, " sfEpisode:", sfEpisode)
    RailEnvPersister.save_episode(self.env, sfEpisode)
def save_merged_env(self):
    """Save the current env's episode under ``self.mergeDir``.

    The output path is ``self.mergeDir`` joined with the current env's
    path from ``self.env_file_paths``; the parent directory is created if
    it does not exist yet.
    """
    env_rel_path = self.env_file_paths[self.simulation_count]
    merge_path = "/".join((self.mergeDir, env_rel_path))

    merge_parent = os.path.dirname(merge_path)
    if not os.path.exists(merge_parent):
        os.makedirs(merge_parent)

    print("Input env path: ", env_rel_path, " Merge File:", merge_path)
    RailEnvPersister.save_episode(self.env, merge_path)
def handle_env_submit(self, command):
"""
Handles a ENV_SUBMIT command from the client
......@@ -841,7 +992,15 @@ class FlatlandRemoteEvaluationService:
print("Listening at : ", self.command_channel)
MESSAGE_QUEUE_LATENCY = []
while True:
command = self.get_next_command()
try:
command = self.get_next_command()
except timeout_decorator.timeout_decorator.TimeoutError:
if self.previous_command['type'] == messages.FLATLAND_RL.ENV_STEP:
self.send_env_step_timeout({"error":messages.FLATLAND_RL.ENV_STEP_TIMEOUT})
self.state_env_timed_out = True
continue
if "timestamp" in command.keys():
latency = time.time() - command["timestamp"]
MESSAGE_QUEUE_LATENCY.append(latency)
......@@ -885,15 +1044,27 @@ class FlatlandRemoteEvaluationService:
print("Overall Message Queue Latency : ", np.array(MESSAGE_QUEUE_LATENCY).mean())
self.handle_env_submit(command)
elif command['type'] == messages.FLATLAND_RL.ENV_STEP_TIMEOUT:
"""
ENV_STEP_TIMEOUT
The client took too long to give us the next command.
"""
print("client env_step timeout")
self.handle_env_step_timeout(command)
else:
_error = self._error_template(
"UNKNOWN_REQUEST:{}".format(
str(command)))
if self.verbose:
print("Responding with : ", _error)
self.report_error(
_error,
command['response_channel'])
if "response_channel" in command:
self.report_error(
_error,
command['response_channel'])
return _error
###########################################
# We keep a record of the previous command
......@@ -908,9 +1079,10 @@ class FlatlandRemoteEvaluationService:
except Exception as e:
print("Error : ", str(e))
print(traceback.format_exc())
self.report_error(
self._error_template(str(e)),
command['response_channel'])
if ("response_channel" in command):
self.report_error(
self._error_template(str(e)),
command['response_channel'])
return self._error_template(str(e))
......@@ -927,6 +1099,49 @@ if __name__ == "__main__":
default="../../../submission-scoring/Envs-Small",
help="Folder containing the files for the test envs",
required=False)
parser.add_argument('--actionDir',
dest='actionDir',
default=None,
help="deprecated - use mergeDir. Folder containing the files for the test envs",
required=False)
parser.add_argument('--episodeDir',
dest='episodeDir',
default=None,
help="deprecated - use mergeDir. Folder containing the files for the test envs",
required=False)
parser.add_argument('--mergeDir',
dest='mergeDir',
default=None,
help="Folder to store merged envs, actions, episodes.",
required=False)
parser.add_argument('--pickle',
default=False,
action="store_true",
help="use pickle instead of msgpack",
required=False)
parser.add_argument('--noShuffle',
default=False,
action="store_true",
help="don't shuffle the envs. Default is to shuffle.",
required=False)
parser.add_argument('--missingOnly',
default=False,
action="store_true",
help="only request the envs/actions which are missing",
required=False)
parser.add_argument('--verbose',
default=False,
action="store_true",
help="verbose debug messages",
required=False)
args = parser.parse_args()
test_folder = args.test_folder
......@@ -934,10 +1149,16 @@ if __name__ == "__main__":
grader = FlatlandRemoteEvaluationService(
test_env_folder=test_folder,
flatland_rl_service_id=args.service_id,
verbose=False,
verbose=args.verbose,
visualize=True,
video_generation_envs=["Test_0/Level_100.pkl"],
result_output_path="/tmp/output.csv"
result_output_path="/tmp/output.csv",
actionDir=args.actionDir,
episodeDir=args.episodeDir,
mergeDir=args.mergeDir,
use_pickle=args.pickle,
shuffle=not args.noShuffle,
missing_only=args.missingOnly,
)
result = grader.run()
if result['type'] == messages.FLATLAND_RL.ENV_SUBMIT_RESPONSE:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment