diff --git a/flatland/evaluators/client.py b/flatland/evaluators/client.py index ee6c9f6eb67eaa97dfa8315cb76021ff392e25ee..c095db9edc0aa7917f7b7e6e880e92a4d517356e 100644 --- a/flatland/evaluators/client.py +++ b/flatland/evaluators/client.py @@ -11,7 +11,6 @@ m.patch() import hashlib import random from flatland.evaluators import messages -from flatland.evaluators.utils import get_all_env_pickle_files from flatland.envs.rail_env import RailEnv from flatland.envs.generators import rail_from_file @@ -110,7 +109,7 @@ class FlatlandRemoteClient(object): if self.verbose: print("Response : ", _response) _response = msgpack.unpackb(_response, object_hook=m.decode, encoding="utf8") if _response['type'] == messages.FLATLAND_RL.ERROR: - raise Exception(str(_response)) + raise Exception(str(_response["payload"])) else: return _response @@ -195,12 +194,27 @@ if __name__ == "__main__": _action[_idx] = np.random.randint(0, 5) return _action - obs = True episode = 0 + obs = True while obs: obs = env_client.env_create() + if not obs: + break print("Episode : {}".format(episode)) - print(obs) - print(env_client.env.width) - print(env_client.env.height) episode += 1 + + print(env_client.env.dones['__all__']) + + while True: + action = my_controller(obs, env_client.env) + observation, all_rewards, done, info = env_client.env_step(action) + if done['__all__']: + print("Current Episode : ", episode) + print("Episode Done") + print("Reward : ", sum(list(all_rewards.values()))) + break + + print("Evaluation Complete...") + print(env_client.submit()) + + diff --git a/flatland/evaluators/service.py b/flatland/evaluators/service.py index 2c4f60f25f619c806cdf7c85a52a00d7e0b84673..d6c11c8c5eabd3e4b50d19ad48ccf803451b1632 100644 --- a/flatland/evaluators/service.py +++ b/flatland/evaluators/service.py @@ -21,6 +21,8 @@ import os import timeout_decorator import time +import traceback + ######################################################## # CONSTANTS ######################################################## @@ -69,6 +71,8 @@ class FlatlandRemoteEvaluationService: self.reward = 0 self.simulation_count = 0 self.simualation_rewards = [] + self.simulation_percentage_complete = [] + self.simulation_steps = [] self.simulation_times = [] self.begin_simulation = False self.current_step = 0 @@ -205,6 +209,9 @@ class FlatlandRemoteEvaluationService: self.begin_simulation = time.time() self.simualation_rewards.append(0) + self.simulation_percentage_complete.append(0) + self.simulation_steps.append(0) + self.current_step = 0 _observation = self.env.reset() @@ -227,7 +234,7 @@ class FlatlandRemoteEvaluationService: All test env evaluations are complete """ _command_response = {} - _command_response['type'] = messages.FLATLAND_RL.ENV_RESET_RESPONSE + _command_response['type'] = messages.FLATLAND_RL.ENV_CREATE_RESPONSE _command_response['payload'] = {} _command_response['payload']['observation'] = False _command_response['payload']['env_file_path'] = False @@ -240,6 +247,90 @@ class FlatlandRemoteEvaluationService: use_bin_type=True) ) + def handle_env_step(self, command): + _redis = self.get_redis_connection() + command_response_channel = command['response_channel'] + _payload = command['payload'] + + if self.env.dones['__all__']: + raise Exception("Client attempted to perform an action on an Env which has done['__all__']==True") + + action = _payload['action'] + _observation, all_rewards, done, info = self.env.step(action) + + cumulative_reward = np.sum(list(all_rewards.values())) + self.simualation_rewards[-1] += cumulative_reward + self.simulation_steps[-1] += 1 + + if done["__all__"]: + # Compute percentage complete + complete = 0 + for i_agent in range(self.env.get_num_agents()): + agent = self.env.agents[i_agent] + if agent.position == agent.target: + complete += 1 + percentage_complete = complete * 1.0 / self.env.get_num_agents() + self.simulation_percentage_complete[-1] = percentage_complete + + # Build and send response + _command_response = {} + _command_response['type'] = messages.FLATLAND_RL.ENV_STEP_RESPONSE + _command_response['payload'] = {} + _command_response['payload']['observation'] = _observation + _command_response['payload']['reward'] = all_rewards + _command_response['payload']['done'] = done + _command_response['payload']['info'] = info + if self.verbose: + # print("Responding with : ", _command_response) + print("Current Step : ", self.simulation_steps[-1]) + _redis.rpush( + command_response_channel, + msgpack.packb( + _command_response, + default=m.encode, + use_bin_type=True) + ) + + def handle_env_submit(self, command): + _redis = self.get_redis_connection() + command_response_channel = command['response_channel'] + _payload = command['payload'] + + # Register simulation time of the last episode + self.simulation_times.append(time.time()-self.begin_simulation) + + _response = {} + _response['type'] = messages.FLATLAND_RL.ENV_SUBMIT_RESPONSE + _payload = {} + _payload['mean_reward'] = np.mean(self.simualation_rewards) + _payload['mean_percentage_complete'] = \ + np.mean(self.simulation_percentage_complete) + + _response['payload'] = _payload + if self.verbose: + print("Responding with : ", _response) + print("Registering Env Submit call") + _redis.rpush( + command_response_channel, + msgpack.packb( + _response, + default=m.encode, + use_bin_type=True) + ) + + def report_error(self, error_message, command_response_channel): + _redis = self.get_redis_connection() + _response = {} + _response['type'] = messages.FLATLAND_RL.ERROR + _response['payload'] = error_message + _redis.rpush( + command_response_channel, + msgpack.packb( + _response, + default=m.encode, + use_bin_type=True) + ) + def run(self): print("Listening for commands at : ", self.command_channel) @@ -269,117 +360,36 @@ class FlatlandRemoteEvaluationService: Respond with an internal _env object """ self.handle_env_create(command) - elif command['type'] == messages.FLATLAND_RL.ENV_RESET: - """ - ENV_RESET - - Respond with observation from next simulation or - False if no simulations are left - """ - self.simulation_count += 1 - if self.begin_simulation: - self.simulation_times.append(time.time()-self.begin_simulation) - self.begin_simulation = time.time() - if self.seed_map and self.simulation_count < len(self.seed_map): - _observation = self.env.reset(seed=self.seed_map[self.simulation_count], project=False) - self.simualation_rewards.append(0) - self.env_available = True - self.current_step = 0 - #_observation = list(_observation) - - _command_response = {} - _command_response['type'] = messages.FLATLAND_RL.ENV_RESET_RESPONSE - _command_response['payload'] = {} - _command_response['payload']['observation'] = _observation - if self.verbose: print("Responding with : ", _command_response) - _redis.rpush(command_response_channel, msgpack.packb(_command_response, default=m.encode, use_bin_type=True)) - else: - _command_response = {} - _command_response['type'] = messages.FLATLAND_RL.ENV_RESET_RESPONSE - _command_response['payload'] = {} - _command_response['payload']['observation'] = False - if self.verbose: print("Responding with : ", _command_response) - _redis.rpush(command_response_channel, msgpack.packb(_command_response, default=m.encode, use_bin_type=True)) elif command['type'] == messages.FLATLAND_RL.ENV_STEP: """ ENV_STEP - Request : Action array + Request : Action dict Respond with updated [observation,reward,done,info] after step """ - args = command['payload'] - action = args['action'] - if self.env and self.env_available: - [_observation, reward, done, info] = self.env.step(action) - else: - if self.env: - raise Exception("Attempt to call `step` function after max_steps={} in a single simulation. Please reset your environment before calling the `step` function after max_step s".format(self.max_steps)) - else: - raise Exception("Attempt to call `step` function on a non existent `env`") - self.reward += reward - self.simualation_rewards[-1] += reward - self.current_step += 1 - #_observation = np.array(_observation).tolist() - - if self.current_step >= self.max_steps: - _command_response = {} - _command_response['type'] = messages.FLATLAND_RL.ENV_STEP_RESPONSE - _command_response['payload'] = {} - _command_response['payload']['observation'] = _observation - _command_response['payload']['reward'] = reward - _command_response['payload']['done'] = True - _command_response['payload']['info'] = info - - """ - Mark env as unavailable until next reset - """ - self.env_available = False - else: - _command_response = {} - _command_response['type'] = messages.FLATLAND_RL.ENV_STEP_RESPONSE - _command_response['payload'] = {} - _command_response['payload']['observation'] = _observation - _command_response['payload']['reward'] = reward - _command_response['payload']['done'] = done - _command_response['payload']['info'] = info - - if done: - """ - Mark env as unavailable until next reset - """ - self.env_available = False - if self.verbose: print("Responding with : ", _command_response) - if self.verbose: print("Current Step : ", self.current_step) - _redis.rpush(command_response_channel, msgpack.packb(_command_response, default=m.encode, use_bin_type=True)) + self.handle_env_step(command) elif command['type'] == messages.FLATLAND_RL.ENV_SUBMIT: """ ENV_SUBMIT Submit the final cumulative reward """ - _response = {} - _response['type'] = messages.FLATLAND_RL.ENV_SUBMIT_RESPONSE - _payload = {} - _payload['mean_reward'] = np.float(self.reward)/len(self.seed_map) #Mean reward - _payload['simulation_rewards'] = self.simualation_rewards - _payload['simulation_times'] = self.simulation_times - _response['payload'] = _payload - _redis.rpush(command_response_channel, msgpack.packb(_response, default=m.encode, use_bin_type=True)) - elif command['type'] == messages.FLATLAND_RL.ENV_SUBMIT: - if self.verbose: print("Responding with : ", _response) - return _response + self.handle_env_submit(command) else: _error = self._error_template( "UNKNOWN_REQUEST:{}".format( str(command))) if self.verbose:print("Responding with : ", _error) - _redis.rpush(command_response_channel, msgpack.packb(_error, default=m.encode, use_bin_type=True)) + self.report_error( + self._error_template(str(e)), + command['response_channel']) return _error - except Exception as e: print("Error : ", str(e)) - _redis.rpush( command_response_channel, - msgpack.packb(self._error_template(str(e)), default=m.encode, use_bin_type=True)) + print(traceback.format_exc()) + self.report_error( + self._error_template(str(e)), + command['response_channel']) return self._error_template(str(e)) if __name__ == "__main__":