Skip to content
Snippets Groups Projects
Commit 9dc4b86e authored by spmohanty's avatar spmohanty
Browse files

Add mean internal-env-step time statistics

parent fa75e4c1
No related branches found
No related tags found
No related merge requests found
...@@ -89,6 +89,8 @@ class FlatlandRemoteClient(object): ...@@ -89,6 +89,8 @@ class FlatlandRemoteClient(object):
self.env = None self.env = None
self.ping_pong() self.ping_pong()
self.env_step_times = []
def get_redis_connection(self): def get_redis_connection(self):
return self.redis_conn return self.redis_conn
...@@ -102,7 +104,7 @@ class FlatlandRemoteClient(object): ...@@ -102,7 +104,7 @@ class FlatlandRemoteClient(object):
random_hash) random_hash)
return response_channel return response_channel
def _blocking_request(self, _request): def _remote_request(self, _request, blocking=True):
""" """
request: request:
-command_type -command_type
...@@ -128,18 +130,20 @@ class FlatlandRemoteClient(object): ...@@ -128,18 +130,20 @@ class FlatlandRemoteClient(object):
# Note: The patched msgpack supports numpy arrays # Note: The patched msgpack supports numpy arrays
payload = msgpack.packb(_request, default=m.encode, use_bin_type=True) payload = msgpack.packb(_request, default=m.encode, use_bin_type=True)
_redis.lpush(self.command_channel, payload) _redis.lpush(self.command_channel, payload)
# Wait with a blocking pop for the response
_response = _redis.blpop(_request['response_channel'])[1] if blocking:
if self.verbose: # Wait with a blocking pop for the response
print("Response : ", _response) _response = _redis.blpop(_request['response_channel'])[1]
_response = msgpack.unpackb( if self.verbose:
_response, print("Response : ", _response)
object_hook=m.decode, _response = msgpack.unpackb(
encoding="utf8") _response,
if _response['type'] == messages.FLATLAND_RL.ERROR: object_hook=m.decode,
raise Exception(str(_response["payload"])) encoding="utf8")
else: if _response['type'] == messages.FLATLAND_RL.ERROR:
return _response raise Exception(str(_response["payload"]))
else:
return _response
def ping_pong(self): def ping_pong(self):
""" """
...@@ -153,7 +157,7 @@ class FlatlandRemoteClient(object): ...@@ -153,7 +157,7 @@ class FlatlandRemoteClient(object):
_request['payload'] = { _request['payload'] = {
"version": flatland.__version__ "version": flatland.__version__
} }
_response = self._blocking_request(_request) _response = self._remote_request(_request)
if _response['type'] != messages.FLATLAND_RL.PONG: if _response['type'] != messages.FLATLAND_RL.PONG:
raise Exception( raise Exception(
"Unable to perform handshake with the evaluation service. \ "Unable to perform handshake with the evaluation service. \
...@@ -171,7 +175,7 @@ class FlatlandRemoteClient(object): ...@@ -171,7 +175,7 @@ class FlatlandRemoteClient(object):
_request = {} _request = {}
_request['type'] = messages.FLATLAND_RL.ENV_CREATE _request['type'] = messages.FLATLAND_RL.ENV_CREATE
_request['payload'] = {} _request['payload'] = {}
_response = self._blocking_request(_request) _response = self._remote_request(_request)
observation = _response['payload']['observation'] observation = _response['payload']['observation']
info = _response['payload']['info'] info = _response['payload']['info']
random_seed = _response['payload']['random_seed'] random_seed = _response['payload']['random_seed']
...@@ -228,39 +232,44 @@ class FlatlandRemoteClient(object): ...@@ -228,39 +232,44 @@ class FlatlandRemoteClient(object):
_request['type'] = messages.FLATLAND_RL.ENV_STEP _request['type'] = messages.FLATLAND_RL.ENV_STEP
_request['payload'] = {} _request['payload'] = {}
_request['payload']['action'] = action _request['payload']['action'] = action
_response = self._blocking_request(_request)
_payload = _response['payload'] _response = self._remote_request(_request, blocking=False)
# _payload = _response['payload']
# remote_observation = _payload['observation'] # noqa # # remote_observation = _payload['observation'] # noqa
remote_reward = _payload['reward'] # remote_reward = _payload['reward']
remote_done = _payload['done'] # remote_done = _payload['done']
remote_info = _payload['info'] # remote_info = _payload['info']
# Replicate the action in the local env # Replicate the action in the local env
time_start = time.time()
local_observation, local_reward, local_done, local_info = \ local_observation, local_reward, local_done, local_info = \
self.env.step(action) self.env.step(action)
self.env_step_times.append(time.time() - time_start)
if self.verbose: if self.verbose:
print(local_reward) print(local_reward)
if not are_dicts_equal(remote_reward, local_reward): # if not are_dicts_equal(remote_reward, local_reward):
print("Remote Reward : ", remote_reward, "Local Reward : ", local_reward) # print("Remote Reward : ", remote_reward, "Local Reward : ", local_reward)
raise Exception("local and remote `reward` are diverging") # raise Exception("local and remote `reward` are diverging")
if not are_dicts_equal(remote_done, local_done): # if not are_dicts_equal(remote_done, local_done):
print("Remote Done : ", remote_done, "Local Done : ", local_done) # print("Remote Done : ", remote_done, "Local Done : ", local_done)
raise Exception("local and remote `done` are diverging") # raise Exception("local and remote `done` are diverging")
# Return local_observation instead of remote_observation # Return local_observation instead of remote_observation
# as the remote_observation is build using a dummy observation # as the remote_observation is build using a dummy observation
# builder # builder
# We return the remote rewards and done as they are the # We return the remote rewards and done as they are the
# once used by the evaluator # once used by the evaluator
return [local_observation, remote_reward, remote_done, remote_info] # return [local_observation, remote_reward, remote_done, remote_info]
return [local_observation, local_reward, local_done, local_info]
def submit(self): def submit(self):
print("Mean Env Step internal : ", np.array(self.env_step_times).mean())
_request = {} _request = {}
_request['type'] = messages.FLATLAND_RL.ENV_SUBMIT _request['type'] = messages.FLATLAND_RL.ENV_SUBMIT
_request['payload'] = {} _request['payload'] = {}
_response = self._blocking_request(_request) _response = self._remote_request(_request)
if os.getenv("AICROWD_BLOCKING_SUBMIT"): if os.getenv("AICROWD_BLOCKING_SUBMIT"):
""" """
If the submission is supposed to happen as a blocking submit, If the submission is supposed to happen as a blocking submit,
......
...@@ -134,6 +134,7 @@ class FlatlandRemoteEvaluationService: ...@@ -134,6 +134,7 @@ class FlatlandRemoteEvaluationService:
self.simulation_percentage_complete = [] self.simulation_percentage_complete = []
self.simulation_steps = [] self.simulation_steps = []
self.simulation_times = [] self.simulation_times = []
self.env_step_times = []
self.begin_simulation = False self.begin_simulation = False
self.current_step = 0 self.current_step = 0
self.visualize = visualize self.visualize = visualize
...@@ -401,7 +402,9 @@ class FlatlandRemoteEvaluationService: ...@@ -401,7 +402,9 @@ class FlatlandRemoteEvaluationService:
has done['__all__']==True") has done['__all__']==True")
action = _payload['action'] action = _payload['action']
time_start = time.time()
_observation, all_rewards, done, info = self.env.step(action) _observation, all_rewards, done, info = self.env.step(action)
self.env_step_times.append(time.time() - time_start)
cumulative_reward = np.sum(list(all_rewards.values())) cumulative_reward = np.sum(list(all_rewards.values()))
self.simulation_rewards[-1] += cumulative_reward self.simulation_rewards[-1] += cumulative_reward
...@@ -465,6 +468,8 @@ class FlatlandRemoteEvaluationService: ...@@ -465,6 +468,8 @@ class FlatlandRemoteEvaluationService:
""" """
_payload = command['payload'] _payload = command['payload']
print("Mean Env Step Time : ", np.array(self.env_step_times).mean())
# Register simulation time of the last episode # Register simulation time of the last episode
self.simulation_times.append(time.time() - self.begin_simulation) self.simulation_times.append(time.time() - self.begin_simulation)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment