From b299b8c64d424d61f6f1ef4a32b2707a8129fac1 Mon Sep 17 00:00:00 2001
From: hagrid67 <jdhwatson@gmail.com>
Date: Wed, 8 Jul 2020 21:20:18 +0100
Subject: [PATCH] draft / initial commit with running test_eval_timeout.py

---
 flatland/evaluators/client.py   |  67 ++++++--
 flatland/evaluators/messages.py |   5 +
 flatland/evaluators/service.py  | 283 ++++++++++++++++++++++++++++----
 3 files changed, 310 insertions(+), 45 deletions(-)

diff --git a/flatland/evaluators/client.py b/flatland/evaluators/client.py
index 545d52a7..3b47cab2 100644
--- a/flatland/evaluators/client.py
+++ b/flatland/evaluators/client.py
@@ -7,6 +7,7 @@ import time
 
 import msgpack
 import msgpack_numpy as m
+import pickle
 import numpy as np
 import redis
 
@@ -45,8 +46,9 @@ class FlatlandRemoteClient(object):
                  remote_db=0,
                  remote_password=None,
                  test_envs_root=None,
-                 verbose=False):
-
+                 verbose=False,
+                 use_pickle=False):
+        self.use_pickle=use_pickle
         self.remote_host = remote_host
         self.remote_port = remote_port
         self.remote_db = remote_db
@@ -67,6 +69,11 @@ class FlatlandRemoteClient(object):
             self.namespace,
             self.service_id
         )
+
+        # for timeout messages sent out-of-band
+        self.error_channel = "{}::{}::errors".format(
+            self.namespace, self.service_id)
+
         if test_envs_root:
             self.test_envs_root = test_envs_root
         else:
@@ -84,6 +91,8 @@ class FlatlandRemoteClient(object):
         self.env_step_times = []
         self.stats = {}
 
+
+
     def update_running_stats(self, key, scalar):
         """
         Computes the running mean for certain params
@@ -147,9 +156,28 @@ class FlatlandRemoteClient(object):
         """
         if self.verbose:
             print("Request : ", _request)
+
+        # check for errors (essentially just timeouts, for now.)
+        error_bytes = _redis.rpop(self.error_channel)
+        if error_bytes is not None:
+            if self.use_pickle:
+                error_dict = pickle.loads(error_bytes)
+            else:
+                error_dict = msgpack.unpackb(
+                    error_bytes,
+                    object_hook=m.decode,
+                    strict_map_key=False,  # new for msgpack 1.0?
+                    encoding="utf8"  # remove for msgpack 1.0
+                    )
+            print("error received: ", error_dict)
+            raise StopAsyncIteration(error_dict["type"])
+
         # Push request in command_channels
         # Note: The patched msgpack supports numpy arrays
-        payload = msgpack.packb(_request, default=m.encode, use_bin_type=True)
+        if self.use_pickle:
+            payload = pickle.dumps(_request)
+        else:
+            payload = msgpack.packb(_request, default=m.encode, use_bin_type=True)
         _redis.lpush(self.command_channel, payload)
 
         if blocking:
@@ -157,10 +185,15 @@ class FlatlandRemoteClient(object):
             _response = _redis.blpop(_request['response_channel'])[1]
             if self.verbose:
                 print("Response : ", _response)
-            _response = msgpack.unpackb(
-                _response,
-                object_hook=m.decode,
-                encoding="utf8")
+            if self.use_pickle:
+                _response = pickle.loads(_response)
+            else:
+                _response = msgpack.unpackb(
+                    _response,
+                    object_hook=m.decode,
+                    strict_map_key=False,  # new for msgpack 1.0?
+                    encoding="utf8"  # remove for msgpack 1.0
+                    )
             if _response['type'] == messages.FLATLAND_RL.ERROR:
                 raise Exception(str(_response["payload"]))
             else:
@@ -266,6 +299,7 @@ class FlatlandRemoteClient(object):
 
         # Relay the action in a non-blocking way to the server
         # so that it can start doing an env.step on it in ~ parallel
+        # Note - this can throw a Timeout
         self._remote_request(_request, blocking=False)
 
         # Apply the action in the local env
@@ -348,13 +382,18 @@ if __name__ == "__main__":
         while True:
             action = my_controller(obs, remote_client.env)
             time_start = time.time()
-            observation, all_rewards, done, info = remote_client.env_step(action)
-            time_diff = time.time() - time_start
-            print("Step Time : ", time_diff)
-            if done['__all__']:
-                print("Current Episode : ", episode)
-                print("Episode Done")
-                print("Reward : ", sum(list(all_rewards.values())))
+
+            try:
+                observation, all_rewards, done, info = remote_client.env_step(action)
+                time_diff = time.time() - time_start
+                print("Step Time : ", time_diff)
+                if done['__all__']:
+                    print("Current Episode : ", episode)
+                    print("Episode Done")
+                    print("Reward : ", sum(list(all_rewards.values())))
+                    break
+            except StopAsyncIteration as err:
+                print("Timeout: ", err)
                 break
 
     print("Evaluation Complete...")
diff --git a/flatland/evaluators/messages.py b/flatland/evaluators/messages.py
index 35c8b372..e084dec6 100644
--- a/flatland/evaluators/messages.py
+++ b/flatland/evaluators/messages.py
@@ -7,11 +7,16 @@ class FLATLAND_RL:
 
     ENV_RESET = "FLATLAND_RL.ENV_RESET"
     ENV_RESET_RESPONSE = "FLATLAND_RL.ENV_RESET_RESPONSE"
+    ENV_RESET_TIMEOUT = "FLATLAND_RL.ENV_RESET_TIMEOUT"
 
     ENV_STEP = "FLATLAND_RL.ENV_STEP"
     ENV_STEP_RESPONSE = "FLATLAND_RL.ENV_STEP_RESPONSE"
+    ENV_STEP_TIMEOUT = "FLATLAND_RL.ENV_STEP_TIMEOUT"
 
     ENV_SUBMIT = "FLATLAND_RL.ENV_SUBMIT"
     ENV_SUBMIT_RESPONSE = "FLATLAND_RL.ENV_SUBMIT_RESPONSE"
 
     ERROR = "FLATLAND_RL.ERROR"
+
+
+
diff --git a/flatland/evaluators/service.py b/flatland/evaluators/service.py
index d658dc9b..21cbd684 100644
--- a/flatland/evaluators/service.py
+++ b/flatland/evaluators/service.py
@@ -7,10 +7,13 @@ import random
 import shutil
 import time
 import traceback
+import json
+import itertools
 
 import crowdai_api
 import msgpack
 import msgpack_numpy as m
+import pickle
 import numpy as np
 import pandas as pd
 import redis
@@ -26,6 +29,8 @@ from flatland.envs.schedule_generators import schedule_from_file
 from flatland.evaluators import aicrowd_helpers
 from flatland.evaluators import messages
 from flatland.utils.rendertools import RenderTool
+from flatland.envs.rail_env_utils import load_flatland_environment_from_file
+from flatland.envs.persistence import RailEnvPersister
 
 use_signals_in_timeout = True
 if os.name == 'nt':
@@ -92,14 +97,33 @@ class FlatlandRemoteEvaluationService:
                  visualize=False,
                  video_generation_envs=[],
                  report=None,
+                 verbose=False,
+                 actionDir=None,   
+                 episodeDir=None,  
+                 mergeDir=None,
+                 use_pickle=False,
+                 shuffle=True,
+                 missing_only=False,
                  result_output_path=None,
-                 verbose=False):
-
+                 ):
+
+        self.actionDir = actionDir
+        if actionDir and not os.path.exists(self.actionDir):
+            os.makedirs(self.actionDir)
+        self.episodeDir = episodeDir
+        if episodeDir and not os.path.exists(self.episodeDir):
+            os.makedirs(self.episodeDir)
+        self.mergeDir = mergeDir
+        if mergeDir and not os.path.exists(self.mergeDir):
+            os.makedirs(self.mergeDir)
+        self.use_pickle = use_pickle
+        self.missing_only = missing_only
         # Test Env folder Paths
         self.test_env_folder = test_env_folder
         self.video_generation_envs = video_generation_envs
         self.env_file_paths = self.get_env_filepaths()
-        random.shuffle(self.env_file_paths)
+        if shuffle:
+            random.shuffle(self.env_file_paths)
         print(self.env_file_paths)
         # Shuffle all the env_file_paths for more exciting videos
         # and for more uniform time progression
@@ -109,6 +133,10 @@ class FlatlandRemoteEvaluationService:
         self.verbose = verbose
         self.report = report
 
+        # Use a state to swallow and ignore any steps after an env times out.
+        # this should be reset to False after env reset() to get the next env.
+        self.state_env_timed_out = False
+
         self.result_output_path = result_output_path
 
         # Communication Protocol Related vars
@@ -119,6 +147,12 @@ class FlatlandRemoteEvaluationService:
             self.service_id
         )
 
+        self.error_channel = "{}::{}::errors".format(
+            self.namespace,
+            self.service_id
+        )
+
+
         # Message Broker related vars
         self.remote_host = remote_host
         self.remote_port = remote_port
@@ -248,6 +282,16 @@ class FlatlandRemoteEvaluationService:
             x, self.test_env_folder
         ) for x in env_paths])
 
+
+ 		# if requested, only generate actions for those envs which don't already have them
+        if self.mergeDir and self.missing_only:
+            existing_paths = (itertools.chain.from_iterable(
+                [ glob.glob(os.path.join(self.mergeDir, f"envs/*.{ext}")) 
+                    for ext in ["pkl", "mpk"] ]))
+            existing_paths = [os.path.relpath(sPath, self.mergeDir) for sPath in existing_paths]
+            env_paths = sorted(set(env_paths) - set(existing_paths))
+
+
         return env_paths
 
     def instantiate_evaluation_metadata(self):
@@ -417,23 +461,37 @@ class FlatlandRemoteEvaluationService:
             command = _redis.brpop(command_channel)[1]
             return command
 
-        try:
+        #try:
+        if True:
             _redis = self.get_redis_connection()
             command = _get_next_command(self.command_channel, _redis)
             if self.verbose or self.report:
                 print("Command Service: ", command)
-        except timeout_decorator.timeout_decorator.TimeoutError:
-            raise Exception(
-                "Timeout of {}s in step {} of simulation {}".format(
-                    COMMAND_TIMEOUT,
-                    self.current_step,
-                    self.simulation_count
-                ))
-        command = msgpack.unpackb(
-            command,
-            object_hook=m.decode,
-            encoding="utf8"
-        )
+        #except timeout_decorator.timeout_decorator.TimeoutError:
+            #raise Exception(
+            #    "Timeout of {}s in step {} of simulation {}".format(
+            #        COMMAND_TIMEOUT,
+            #        self.current_step,
+            #        self.simulation_count
+            #    ))
+        
+        #    print("Timeout of {}s in step {} of simulation {}".format(
+        #        COMMAND_TIMEOUT,
+        #        self.current_step,
+        #        self.simulation_count
+        #        ))
+        #    return {"type":messages.FLATLAND_RL.ENV_STEP_TIMEOUT}
+
+
+        if self.use_pickle:
+            command = pickle.loads(command)
+        else:
+            command = msgpack.unpackb(
+                command,
+                object_hook=m.decode,
+                strict_map_key=False,   # msgpack 1.0
+                encoding="utf8"        # msgpack 1.0
+            )
         if self.verbose:
             print("Received Request : ", command)
 
@@ -448,13 +506,34 @@ class FlatlandRemoteEvaluationService:
         if self.verbose and not suppress_logs:
             print("Responding with : ", _command_response)
 
-        _redis.rpush(
-            command_response_channel,
-            msgpack.packb(
+        if self.use_pickle:
+            sResponse = pickle.dumps(_command_response)
+        else:
+            sResponse = msgpack.packb(
                 _command_response,
                 default=m.encode,
                 use_bin_type=True)
-        )
+        _redis.rpush(command_response_channel, sResponse)
+
+    def send_error(self, error_dict, suppress_logs=False):
+        """ For out-of-band errors like timeouts, 
+            where we do not have a command, so we have no response channel!
+        """
+        _redis = self.get_redis_connection()
+        #command_response_channel = command['response_channel']
+
+        if self.verbose and not suppress_logs:
+            print("Responding with : ", error_dict)
+
+        if self.use_pickle:
+            sResponse = pickle.dumps(error_dict)
+        else:
+            sResponse = msgpack.packb(
+                error_dict,
+                default=m.encode,
+                use_bin_type=True)
+
+        _redis.rpush(self.error_channel, sResponse)
 
     def handle_ping(self, command):
         """
@@ -487,6 +566,10 @@ class FlatlandRemoteEvaluationService:
         TODO: Add a high level summary of everything thats happening here.
         """
         self.simulation_count += 1
+
+        # reset the timeout flag / state.
+        self.state_env_timed_out = False
+
         if self.simulation_count < len(self.env_file_paths):
             """
             There are still test envs left that are yet to be evaluated 
@@ -498,10 +581,12 @@ class FlatlandRemoteEvaluationService:
                 test_env_file_path
             )
             del self.env
-            self.env = RailEnv(width=1, height=1, rail_generator=rail_from_file(test_env_file_path),
+            self.env = RailEnv(width=1, height=1,
+                               rail_generator=rail_from_file(test_env_file_path),
                                schedule_generator=schedule_from_file(test_env_file_path),
                                malfunction_generator_and_process_data=malfunction_from_file(test_env_file_path),
-                               obs_builder_object=DummyObservationBuilder())
+                               obs_builder_object=DummyObservationBuilder(),
+                               record_steps=True)
 
             if self.begin_simulation:
                 # If begin simulation has already been initialized
@@ -577,12 +662,18 @@ class FlatlandRemoteEvaluationService:
         self.evaluation_state["score"]["score_secondary"] = mean_reward
         self.evaluation_state["meta"]["normalized_reward"] = mean_normalized_reward
         self.handle_aicrowd_info_event(self.evaluation_state)
+        self.lActions = []
 
     def handle_env_step(self, command):
         """
         Handles a ENV_STEP command from the client
         TODO: Add a high level summary of everything thats happening here.
         """
+
+        if self.state_env_timed_out:
+            print("ignoring step command after timeout")
+            return
+
         _payload = command['payload']
 
         if not self.env:
@@ -623,6 +714,9 @@ class FlatlandRemoteEvaluationService:
                 self.env.get_num_agents()
             )
 
+        # record the actions before checking for done
+        if self.actionDir is not None:
+            self.lActions.append(action)
         if done["__all__"]:
             # Compute percentage complete
             complete = 0
@@ -633,6 +727,15 @@ class FlatlandRemoteEvaluationService:
             percentage_complete = complete * 1.0 / self.env.get_num_agents()
             self.simulation_percentage_complete[-1] = percentage_complete
 
+            if self.actionDir is not None:
+                self.save_actions()
+            
+            if self.episodeDir is not None:
+                self.save_episode()
+
+            if self.mergeDir is not None:
+                self.save_merged_env()
+
         # Record Frame
         if self.visualize:
             """
@@ -654,6 +757,54 @@ class FlatlandRemoteEvaluationService:
                     ))
                 self.record_frame_step += 1
 
+    def send_env_step_timeout(self, command):
+        print("handle_env_step_timeout")
+        error_dict = dict(
+            type=messages.FLATLAND_RL.ENV_STEP_TIMEOUT,
+
+            # payload probably unnecessary
+            payload=dict(
+                observation=False,
+                env_file_path=False,
+                info=False,
+                random_seed=False
+            ))
+        
+        self.send_error(error_dict)
+
+    def save_actions(self):
+        sfEnv = self.env_file_paths[self.simulation_count]
+        
+        sfActions = self.actionDir + "/" + sfEnv.replace(".pkl", ".json")
+
+        print("env path: ", sfEnv, " sfActions:", sfActions)
+
+        if not os.path.exists(os.path.dirname(sfActions)):
+            os.makedirs(os.path.dirname(sfActions))
+        
+        with open(sfActions, "w") as fOut:
+            json.dump(self.lActions, fOut)
+
+        self.lActions = []
+    
+    def save_episode(self):
+        sfEnv = self.env_file_paths[self.simulation_count]
+        sfEpisode = self.episodeDir + "/" + sfEnv
+        print("env path: ", sfEnv, " sfEpisode:", sfEpisode)
+        RailEnvPersister.save_episode(self.env, sfEpisode)
+        #self.env.save_episode(sfEpisode)
+    
+    def save_merged_env(self):
+        sfEnv = self.env_file_paths[self.simulation_count]
+        sfMergeEnv = self.mergeDir + "/" + sfEnv
+
+        if not os.path.exists(os.path.dirname(sfMergeEnv)):
+            os.makedirs(os.path.dirname(sfMergeEnv))
+
+        print("Input env path: ", sfEnv, " Merge File:", sfMergeEnv)
+        RailEnvPersister.save_episode(self.env, sfMergeEnv)
+        #self.env.save_episode(sfMergeEnv)
+
     def handle_env_submit(self, command):
         """
         Handles a ENV_SUBMIT command from the client
@@ -841,7 +992,15 @@ class FlatlandRemoteEvaluationService:
         print("Listening at : ", self.command_channel)
         MESSAGE_QUEUE_LATENCY = []
         while True:
-            command = self.get_next_command()
+
+            try:
+                command = self.get_next_command()
+            except timeout_decorator.timeout_decorator.TimeoutError:
+                if self.previous_command['type'] == messages.FLATLAND_RL.ENV_STEP:
+                    self.send_env_step_timeout({"error":messages.FLATLAND_RL.ENV_STEP_TIMEOUT})
+                    self.state_env_timed_out = True
+                    continue
+
             if "timestamp" in command.keys():
                 latency = time.time() - command["timestamp"]
                 MESSAGE_QUEUE_LATENCY.append(latency)
@@ -885,15 +1044,27 @@ class FlatlandRemoteEvaluationService:
 
                     print("Overall Message Queue Latency : ", np.array(MESSAGE_QUEUE_LATENCY).mean())
                     self.handle_env_submit(command)
+                elif command['type'] == messages.FLATLAND_RL.ENV_STEP_TIMEOUT:
+                    """
+                        ENV_STEP_TIMEOUT
+
+                        The client took too long to give us the next command.
+
+                    """
+
+                    print("client env_step timeout")
+                    self.handle_env_step_timeout(command)
+                
                 else:
                     _error = self._error_template(
                         "UNKNOWN_REQUEST:{}".format(
                             str(command)))
                     if self.verbose:
                         print("Responding with : ", _error)
-                    self.report_error(
-                        _error,
-                        command['response_channel'])
+                    if "response_channel" in command:
+                        self.report_error(
+                            _error,
+                            command['response_channel'])
                     return _error
                 ###########################################
                 # We keep a record of the previous command
@@ -908,9 +1079,10 @@ class FlatlandRemoteEvaluationService:
             except Exception as e:
                 print("Error : ", str(e))
                 print(traceback.format_exc())
-                self.report_error(
-                    self._error_template(str(e)),
-                    command['response_channel'])
+                if ("response_channel" in command):
+                    self.report_error(
+                        self._error_template(str(e)),
+                        command['response_channel'])
                 return self._error_template(str(e))
 
 
@@ -927,6 +1099,49 @@ if __name__ == "__main__":
                         default="../../../submission-scoring/Envs-Small",
                         help="Folder containing the files for the test envs",
                         required=False)
+
+    parser.add_argument('--actionDir',
+                        dest='actionDir',
+                        default=None,
+                        help="deprecated - use mergeDir.  Folder containing the files for the test envs",
+                        required=False)
+    
+    parser.add_argument('--episodeDir',
+                        dest='episodeDir',
+                        default=None,
+                        help="deprecated - use mergeDir.   Folder containing the files for the test envs",
+                        required=False)
+    
+    parser.add_argument('--mergeDir',
+                        dest='mergeDir',
+                        default=None,
+                        help="Folder to store merged envs, actions, episodes.",
+                        required=False)
+
+    parser.add_argument('--pickle',
+                        default=False,
+                        action="store_true",
+                        help="use pickle instead of msgpack",
+                        required=False)
+
+    parser.add_argument('--noShuffle',
+                        default=False,
+                        action="store_true",
+                        help="don't shuffle the envs.  Default is to shuffle.",
+                        required=False)
+
+    parser.add_argument('--missingOnly',
+                        default=False,
+                        action="store_true",
+                        help="only request the envs/actions which are missing",
+                        required=False)
+
+
+    parser.add_argument('--verbose',
+                        default=False,
+                        action="store_true",
+                        help="verbose debug messages",
+                        required=False)
     args = parser.parse_args()
 
     test_folder = args.test_folder
@@ -934,10 +1149,16 @@ if __name__ == "__main__":
     grader = FlatlandRemoteEvaluationService(
         test_env_folder=test_folder,
         flatland_rl_service_id=args.service_id,
-        verbose=False,
+        verbose=args.verbose,
         visualize=True,
         video_generation_envs=["Test_0/Level_100.pkl"],
-        result_output_path="/tmp/output.csv"
+        result_output_path="/tmp/output.csv",
+        actionDir=args.actionDir,
+        episodeDir=args.episodeDir,
+        mergeDir=args.mergeDir,
+        use_pickle=args.pickle,
+        shuffle=not args.noShuffle,
+        missing_only=args.missingOnly,
     )
     result = grader.run()
     if result['type'] == messages.FLATLAND_RL.ENV_SUBMIT_RESPONSE:
-- 
GitLab