diff --git a/flatland/evaluators/client.py b/flatland/evaluators/client.py index 7ae3752606212c036537a88ad7ea49327c9a9265..f49e35dba2703ddf377ae7975741ae6ab4722124 100644 --- a/flatland/evaluators/client.py +++ b/flatland/evaluators/client.py @@ -90,6 +90,7 @@ class FlatlandRemoteClient(object): self.ping_pong() self.env_step_times = [] + self.divergence_computation_times = [] def get_redis_connection(self): return self.redis_conn @@ -248,6 +249,7 @@ class FlatlandRemoteClient(object): self.env.step(action) self.env_step_times.append(time.time() - time_start) + time_start = time.time() if self.verbose: print(local_reward) if not are_dicts_equal(remote_reward, local_reward): @@ -256,6 +258,7 @@ class FlatlandRemoteClient(object): if not are_dicts_equal(remote_done, local_done): print("Remote Done : ", remote_done, "Local Done : ", local_done) raise Exception("local and remote `done` are diverging") + self.divergence_computation_times.append(time.time() - time_start) # Return local_observation instead of remote_observation # as the remote_observation is build using a dummy observation @@ -267,6 +270,7 @@ class FlatlandRemoteClient(object): def submit(self): print("Mean Env Step internal : ", np.array(self.env_step_times).mean()) + print("Mean Divergence Computation Time : ", np.array(self.divergence_computation_times).mean()) _request = {} _request['type'] = messages.FLATLAND_RL.ENV_SUBMIT _request['payload'] = {}