Skip to content
Snippets Groups Projects
Commit 9dc4b86e authored by spmohanty's avatar spmohanty
Browse files

Add mean internal-env-step time statistics

parent fa75e4c1
No related branches found
No related tags found
No related merge requests found
......@@ -89,6 +89,8 @@ class FlatlandRemoteClient(object):
self.env = None
self.ping_pong()
self.env_step_times = []
def get_redis_connection(self):
return self.redis_conn
......@@ -102,7 +104,7 @@ class FlatlandRemoteClient(object):
random_hash)
return response_channel
def _blocking_request(self, _request):
def _remote_request(self, _request, blocking=True):
"""
request:
-command_type
......@@ -128,18 +130,20 @@ class FlatlandRemoteClient(object):
# Note: The patched msgpack supports numpy arrays
payload = msgpack.packb(_request, default=m.encode, use_bin_type=True)
_redis.lpush(self.command_channel, payload)
# Wait with a blocking pop for the response
_response = _redis.blpop(_request['response_channel'])[1]
if self.verbose:
print("Response : ", _response)
_response = msgpack.unpackb(
_response,
object_hook=m.decode,
encoding="utf8")
if _response['type'] == messages.FLATLAND_RL.ERROR:
raise Exception(str(_response["payload"]))
else:
return _response
if blocking:
# Wait with a blocking pop for the response
_response = _redis.blpop(_request['response_channel'])[1]
if self.verbose:
print("Response : ", _response)
_response = msgpack.unpackb(
_response,
object_hook=m.decode,
encoding="utf8")
if _response['type'] == messages.FLATLAND_RL.ERROR:
raise Exception(str(_response["payload"]))
else:
return _response
def ping_pong(self):
"""
......@@ -153,7 +157,7 @@ class FlatlandRemoteClient(object):
_request['payload'] = {
"version": flatland.__version__
}
_response = self._blocking_request(_request)
_response = self._remote_request(_request)
if _response['type'] != messages.FLATLAND_RL.PONG:
raise Exception(
"Unable to perform handshake with the evaluation service. \
......@@ -171,7 +175,7 @@ class FlatlandRemoteClient(object):
_request = {}
_request['type'] = messages.FLATLAND_RL.ENV_CREATE
_request['payload'] = {}
_response = self._blocking_request(_request)
_response = self._remote_request(_request)
observation = _response['payload']['observation']
info = _response['payload']['info']
random_seed = _response['payload']['random_seed']
......@@ -228,39 +232,44 @@ class FlatlandRemoteClient(object):
_request['type'] = messages.FLATLAND_RL.ENV_STEP
_request['payload'] = {}
_request['payload']['action'] = action
_response = self._blocking_request(_request)
_payload = _response['payload']
_response = self._remote_request(_request, blocking=False)
# _payload = _response['payload']
# remote_observation = _payload['observation'] # noqa
remote_reward = _payload['reward']
remote_done = _payload['done']
remote_info = _payload['info']
# # remote_observation = _payload['observation'] # noqa
# remote_reward = _payload['reward']
# remote_done = _payload['done']
# remote_info = _payload['info']
# Replicate the action in the local env
time_start = time.time()
local_observation, local_reward, local_done, local_info = \
self.env.step(action)
self.env_step_times.append(time.time() - time_start)
if self.verbose:
print(local_reward)
if not are_dicts_equal(remote_reward, local_reward):
print("Remote Reward : ", remote_reward, "Local Reward : ", local_reward)
raise Exception("local and remote `reward` are diverging")
if not are_dicts_equal(remote_done, local_done):
print("Remote Done : ", remote_done, "Local Done : ", local_done)
raise Exception("local and remote `done` are diverging")
# if not are_dicts_equal(remote_reward, local_reward):
# print("Remote Reward : ", remote_reward, "Local Reward : ", local_reward)
# raise Exception("local and remote `reward` are diverging")
# if not are_dicts_equal(remote_done, local_done):
# print("Remote Done : ", remote_done, "Local Done : ", local_done)
# raise Exception("local and remote `done` are diverging")
# Return local_observation instead of remote_observation
# as the remote_observation is build using a dummy observation
# builder
# We return the remote rewards and done as they are the
# once used by the evaluator
return [local_observation, remote_reward, remote_done, remote_info]
# return [local_observation, remote_reward, remote_done, remote_info]
return [local_observation, local_reward, local_done, local_info]
def submit(self):
print("Mean Env Step internal : ", np.array(self.env_step_times).mean())
_request = {}
_request['type'] = messages.FLATLAND_RL.ENV_SUBMIT
_request['payload'] = {}
_response = self._blocking_request(_request)
_response = self._remote_request(_request)
if os.getenv("AICROWD_BLOCKING_SUBMIT"):
"""
If the submission is supposed to happen as a blocking submit,
......
......@@ -134,6 +134,7 @@ class FlatlandRemoteEvaluationService:
self.simulation_percentage_complete = []
self.simulation_steps = []
self.simulation_times = []
self.env_step_times = []
self.begin_simulation = False
self.current_step = 0
self.visualize = visualize
......@@ -401,7 +402,9 @@ class FlatlandRemoteEvaluationService:
has done['__all__']==True")
action = _payload['action']
time_start = time.time()
_observation, all_rewards, done, info = self.env.step(action)
self.env_step_times.append(time.time() - time_start)
cumulative_reward = np.sum(list(all_rewards.values()))
self.simulation_rewards[-1] += cumulative_reward
......@@ -465,6 +468,8 @@ class FlatlandRemoteEvaluationService:
"""
_payload = command['payload']
print("Mean Env Step Time : ", np.array(self.env_step_times).mean())
# Register simulation time of the last episode
self.simulation_times.append(time.time() - self.begin_simulation)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment