NOTE(review): the line breaks and indentation of this patch were rebuilt from a
whitespace-collapsed copy. Blank context lines and exact leading whitespace are
a best-effort reconstruction; verify them against the target files before
applying. Added lines were amended in review, so the post-image blob hashes on
the `index` lines no longer match the resulting files.

diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py
index 53e7a068b73f9907217777251bce0fdd704603be..060785f537251484ab4fd1c520fc92bc8a564cbc 100644
--- a/flatland/core/env_observation_builder.py
+++ b/flatland/core/env_observation_builder.py
@@ -73,3 +73,30 @@ class ObservationBuilder:
         direction = np.zeros(4)
         direction[agent.direction] = 1
         return direction
+
+
+class DummyObservationBuilder(ObservationBuilder):
+    """
+    Observation builder that returns constant dummy observations.
+
+    This is used in the evaluation service: `get` and `get_many`
+    always return True instead of computing an observation.
+    """
+
+    def __init__(self):
+        self.observation_space = ()
+
+    def _set_env(self, env):
+        self.env = env
+
+    def reset(self):
+        # Nothing to reset: no per-episode state is kept here.
+        pass
+
+    def get_many(self, handles=None):
+        # `handles` is unused; None avoids a shared mutable default.
+        return True
+
+    def get(self, handle=0):
+        # `handle` is unused; the dummy observation is always True.
+        return True
diff --git a/flatland/evaluators/client.py b/flatland/evaluators/client.py
index 0853f498705d8322cbdfcd223c701e4a587aedf0..e8a8e623c8198698a5f074dc773b0a8cc8f2b69c 100644
--- a/flatland/evaluators/client.py
+++ b/flatland/evaluators/client.py
@@ -18,6 +18,12 @@ logger.setLevel(logging.INFO)
 m.patch()
 
 
+def are_dicts_equal(d1, d2):
+    """Return True if both dicts have the same keys and equal values."""
+    # Key sets must match; then every value is compared exactly once.
+    return d1.keys() == d2.keys() and all(d1[k] == d2[k] for k in d1)
+
+
 class FlatlandRemoteClient(object):
     """
         Redis client to interface with flatland-rl remote-evaluation-service
@@ -133,10 +139,16 @@ class FlatlandRemoteClient(object):
         else:
             return True
 
-    def env_create(self, params={}):
+    def env_create(self, obs_builder_object):
+        """
+            Create a local env and remote env on which the
+            local agent can operate.
+            The observation builder is only used in the local env
+            and the remote env uses a DummyObservationBuilder.
+        """
         _request = {}
         _request['type'] = messages.FLATLAND_RL.ENV_CREATE
-        _request['payload'] = params
+        _request['payload'] = {}
         _response = self._blocking_request(_request)
         observation = _response['payload']['observation']
 
@@ -151,17 +163,15 @@ class FlatlandRemoteClient(object):
             width=1,
             height=1,
             rail_generator=rail_from_file(test_env_file_path),
-            obs_builder_object=TreeObsForRailEnv(
-                max_depth=3,
-                predictor=ShortestPathPredictorForRailEnv()
-            )
+            obs_builder_object=obs_builder_object
         )
         self.env._max_episode_steps = \
             int(1.5 * (self.env.width + self.env.height))
 
-        self.env.reset()
-        # Use the observation from the remote service instead
-        return observation
+        local_observation = self.env.reset()
+        # Use the local observation
+        # as the remote server uses a dummy observation builder
+        return local_observation
 
     def env_step(self, action, render=False):
         """
@@ -173,11 +183,31 @@ class FlatlandRemoteClient(object):
         _request['payload']['action'] = action
         _response = self._blocking_request(_request)
         _payload = _response['payload']
-        observation = _payload['observation']
+
+        # The remote observation in _payload['observation'] is not used:
+        # it is built using a dummy observation builder
         reward = _payload['reward']
         done = _payload['done']
         info = _payload['info']
-        return [observation, reward, done, info]
+
+        # Replicate the action in the local env
+        local_observation, local_rewards, local_done, local_info = \
+            self.env.step(action)
+
+        # Explicit checks rather than `assert`, which is stripped under -O
+        if not are_dicts_equal(reward, local_rewards):
+            raise AssertionError(
+                "Remote and local rewards differ after env_step")
+        if not are_dicts_equal(done, local_done):
+            raise AssertionError(
+                "Remote and local done flags differ after env_step")
+
+        # Return local_observation instead of remote_observation
+        # as the remote_observation is built using a dummy observation
+        # builder
+        # We return the remote rewards and done as they are the
+        # ones used by the evaluator
+        return [local_observation, reward, done, info]
 
     def submit(self):
         _request = {}
@@ -196,28 +226,41 @@ class FlatlandRemoteClient(object):
 
 
 if __name__ == "__main__":
-    env_client = FlatlandRemoteClient()
+    remote_client = FlatlandRemoteClient()
 
     def my_controller(obs, _env):
         _action = {}
         for _idx, _ in enumerate(_env.agents):
             _action[_idx] = np.random.randint(0, 5)
         return _action
 
+    my_observation_builder = TreeObsForRailEnv(
+        max_depth=3,
+        predictor=ShortestPathPredictorForRailEnv()
+    )
+
     episode = 0
     obs = True
     while obs:
-        obs = env_client.env_create()
+        obs = remote_client.env_create(
+            obs_builder_object=my_observation_builder
+        )
         if not obs:
+            # The remote env returns False as the first obs
+            # when it is done evaluating all the individual episodes
             break
         print("Episode : {}".format(episode))
         episode += 1
 
-        print(env_client.env.dones['__all__'])
+        print(remote_client.env.dones['__all__'])
 
+        # Track the latest observation so the controller is not fed
+        # the stale initial one on every step
+        observation = obs
         while True:
-            action = my_controller(obs, env_client.env)
-            observation, all_rewards, done, info = env_client.env_step(action)
+            action = my_controller(observation, remote_client.env)
+            observation, all_rewards, done, info = \
+                remote_client.env_step(action)
             if done['__all__']:
                 print("Current Episode : ", episode)
                 print("Episode Done")
@@ -225,6 +268,6 @@ if __name__ == "__main__":
                 break
 
     print("Evaluation Complete...")
-    print(env_client.submit())
+    print(remote_client.submit())
 
 
diff --git a/flatland/evaluators/service.py b/flatland/evaluators/service.py
index 1729196ce01cc7507d93f3de20d139e29ce9f06e..3c1fb2e9013896990204e291e239c4de4becc21c 100644
--- a/flatland/evaluators/service.py
+++ b/flatland/evaluators/service.py
@@ -3,8 +3,7 @@ from __future__ import print_function
 import redis
 from flatland.envs.generators import rail_from_file
 from flatland.envs.rail_env import RailEnv
-from flatland.envs.observations import TreeObsForRailEnv
-from flatland.envs.predictions import ShortestPathPredictorForRailEnv
+from flatland.core.env_observation_builder import DummyObservationBuilder
 from flatland.evaluators import messages
 import numpy as np
 import msgpack
@@ -235,7 +234,6 @@ class FlatlandRemoteEvaluationService:
             Add a high level summary of everything thats hapenning here.
         """
 
-        env_params = command["payload"]  # noqa F841
 
         if self.simulation_count < len(self.env_file_paths):
             """
@@ -244,19 +242,11 @@ class FlatlandRemoteEvaluationService:
 
             test_env_file_path = self.env_file_paths[self.simulation_count]
             del self.env
-            # TODO : Use env_params dictionary to instantiate
-            # the RailEnv
-            # Maybe use a gin-like interface ?
-            # Needs discussion with Erik + Giacomo
-            # -Mohanty
             self.env = RailEnv(
                 width=1,
                 height=1,
                 rail_generator=rail_from_file(test_env_file_path),
-                obs_builder_object=TreeObsForRailEnv(
-                    max_depth=3,
-                    predictor=ShortestPathPredictorForRailEnv()
-                )
+                obs_builder_object=DummyObservationBuilder()
             )
 
             # Set max episode steps allowed