From 4e7241f76c14a28de7f37cf5adcdb715620e995e Mon Sep 17 00:00:00 2001
From: "S.P. Mohanty" <spmohanty91@gmail.com>
Date: Wed, 23 Oct 2019 14:43:40 +0200
Subject: [PATCH] Add mean divergence computation time stats

---
 flatland/evaluators/client.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/flatland/evaluators/client.py b/flatland/evaluators/client.py
index 7ae37526..f49e35db 100644
--- a/flatland/evaluators/client.py
+++ b/flatland/evaluators/client.py
@@ -90,6 +90,7 @@ class FlatlandRemoteClient(object):
         self.ping_pong()
 
         self.env_step_times = []
+        self.divergence_computation_times = []
 
     def get_redis_connection(self):
         return self.redis_conn
@@ -248,6 +249,7 @@ class FlatlandRemoteClient(object):
             self.env.step(action)
         self.env_step_times.append(time.time() - time_start)
 
+        time_start = time.time()
         if self.verbose:
             print(local_reward)
         if not are_dicts_equal(remote_reward, local_reward):
@@ -256,6 +258,7 @@ class FlatlandRemoteClient(object):
         if not are_dicts_equal(remote_done, local_done):
             print("Remote Done : ", remote_done, "Local Done : ", local_done)
             raise Exception("local and remote `done` are diverging")
+        self.divergence_computation_times.append(time.time() - time_start)
 
         # Return local_observation instead of remote_observation
         # as the remote_observation is build using a dummy observation
@@ -267,6 +270,7 @@ class FlatlandRemoteClient(object):
 
     def submit(self):
         print("Mean Env Step internal : ", np.array(self.env_step_times).mean())
+        print("Mean Divergence Computation Time : ", np.array(self.divergence_computation_times).mean())
         _request = {}
         _request['type'] = messages.FLATLAND_RL.ENV_SUBMIT
         _request['payload'] = {}
-- 
GitLab