Skip to content
Snippets Groups Projects
Commit 4ccccc1e authored by spmohanty's avatar spmohanty
Browse files

Addresses #117 - Add ability to pass in custom observation builder

parent 93c99c05
No related branches found
No related tags found
No related merge requests found
...@@ -73,3 +73,24 @@ class ObservationBuilder: ...@@ -73,3 +73,24 @@ class ObservationBuilder:
direction = np.zeros(4) direction = np.zeros(4)
direction[agent.direction] = 1 direction[agent.direction] = 1
return direction return direction
class DummyObservationBuilder(ObservationBuilder):
    """
    Observation builder that returns placeholder observations.

    Every observation produced by this builder is the constant ``True``;
    nothing is read from the environment when building it.
    This is used in the evaluation service.
    """

    def __init__(self):
        # Empty tuple: the placeholder observation has no declared shape.
        self.observation_space = ()

    def _set_env(self, env):
        """Keep a reference to the environment this builder is attached to."""
        self.env = env

    def reset(self):
        """Do nothing; this builder holds no per-episode state."""
        pass

    def get_many(self, handles=None):
        """
        Return the placeholder observation for a group of agents.

        `handles` is accepted for interface compatibility and is ignored.
        It defaults to None rather than a shared mutable list.
        """
        return True

    def get(self, handle=0):
        """
        Return the placeholder observation for a single agent.

        `handle` is accepted for interface compatibility and is ignored.
        """
        return True
...@@ -18,6 +18,14 @@ logger.setLevel(logging.INFO) ...@@ -18,6 +18,14 @@ logger.setLevel(logging.INFO)
m.patch() m.patch()
def are_dicts_equal(d1, d2):
    """
    Return True if both dicts have the same keys and equal values.

    Values are compared with ``==`` key by key, so the result follows the
    values' own equality semantics (for example, a NaN value never compares
    equal to itself).
    """
    # Different key sets can never be equal; this also guarantees that
    # every lookup in the value comparison below succeeds.
    if set(d1) != set(d2):
        return False
    # With identical key sets a single pass over the values is sufficient.
    return all(d1[k] == d2[k] for k in d1)
class FlatlandRemoteClient(object): class FlatlandRemoteClient(object):
""" """
Redis client to interface with flatland-rl remote-evaluation-service Redis client to interface with flatland-rl remote-evaluation-service
...@@ -133,10 +141,16 @@ class FlatlandRemoteClient(object): ...@@ -133,10 +141,16 @@ class FlatlandRemoteClient(object):
else: else:
return True return True
def env_create(self, params={}): def env_create(self, obs_builder_object):
"""
Create a local env and remote env on which the
local agent can operate.
The observation builder is only used in the local env
and the remote env uses a DummyObservationBuilder
"""
_request = {} _request = {}
_request['type'] = messages.FLATLAND_RL.ENV_CREATE _request['type'] = messages.FLATLAND_RL.ENV_CREATE
_request['payload'] = params _request['payload'] = {}
_response = self._blocking_request(_request) _response = self._blocking_request(_request)
observation = _response['payload']['observation'] observation = _response['payload']['observation']
...@@ -151,17 +165,15 @@ class FlatlandRemoteClient(object): ...@@ -151,17 +165,15 @@ class FlatlandRemoteClient(object):
width=1, width=1,
height=1, height=1,
rail_generator=rail_from_file(test_env_file_path), rail_generator=rail_from_file(test_env_file_path),
obs_builder_object=TreeObsForRailEnv( obs_builder_object=obs_builder_object
max_depth=3,
predictor=ShortestPathPredictorForRailEnv()
)
) )
self.env._max_episode_steps = \ self.env._max_episode_steps = \
int(1.5 * (self.env.width + self.env.height)) int(1.5 * (self.env.width + self.env.height))
self.env.reset() local_observation = self.env.reset()
# Use the observation from the remote service instead # Use the local observation
return observation # as the remote server uses a dummy observation builder
return local_observation
def env_step(self, action, render=False): def env_step(self, action, render=False):
""" """
...@@ -173,11 +185,25 @@ class FlatlandRemoteClient(object): ...@@ -173,11 +185,25 @@ class FlatlandRemoteClient(object):
_request['payload']['action'] = action _request['payload']['action'] = action
_response = self._blocking_request(_request) _response = self._blocking_request(_request)
_payload = _response['payload'] _payload = _response['payload']
observation = _payload['observation']
# remote_observation = _payload['observation']
reward = _payload['reward'] reward = _payload['reward']
done = _payload['done'] done = _payload['done']
info = _payload['info'] info = _payload['info']
return [observation, reward, done, info]
# Replicate the action in the local env
local_observation, local_rewards, local_done, local_info = \
self.env.step(action)
assert are_dicts_equal(reward, local_rewards)
assert are_dicts_equal(done, local_done)
# Return local_observation instead of remote_observation
# as the remote_observation is built using a dummy observation
# builder
# We return the remote rewards and done as they are the
# ones used by the evaluator
return [local_observation, reward, done, info]
def submit(self): def submit(self):
_request = {} _request = {}
...@@ -196,28 +222,37 @@ class FlatlandRemoteClient(object): ...@@ -196,28 +222,37 @@ class FlatlandRemoteClient(object):
if __name__ == "__main__": if __name__ == "__main__":
env_client = FlatlandRemoteClient() remote_client = FlatlandRemoteClient()
def my_controller(obs, _env): def my_controller(obs, _env):
_action = {} _action = {}
for _idx, _ in enumerate(_env.agents): for _idx, _ in enumerate(_env.agents):
_action[_idx] = np.random.randint(0, 5) _action[_idx] = np.random.randint(0, 5)
return _action return _action
my_observation_builder = TreeObsForRailEnv(max_depth=3,
predictor=ShortestPathPredictorForRailEnv())
episode = 0 episode = 0
obs = True obs = True
while obs: while obs:
obs = env_client.env_create() obs = remote_client.env_create(
obs_builder_object=my_observation_builder
)
if not obs: if not obs:
"""
The remote env returns False as the first obs
when it is done evaluating all the individual episodes
"""
break break
print("Episode : {}".format(episode)) print("Episode : {}".format(episode))
episode += 1 episode += 1
print(env_client.env.dones['__all__']) print(remote_client.env.dones['__all__'])
while True: while True:
action = my_controller(obs, env_client.env) action = my_controller(obs, remote_client.env)
observation, all_rewards, done, info = env_client.env_step(action) observation, all_rewards, done, info = remote_client.env_step(action)
if done['__all__']: if done['__all__']:
print("Current Episode : ", episode) print("Current Episode : ", episode)
print("Episode Done") print("Episode Done")
...@@ -225,6 +260,6 @@ if __name__ == "__main__": ...@@ -225,6 +260,6 @@ if __name__ == "__main__":
break break
print("Evaluation Complete...") print("Evaluation Complete...")
print(env_client.submit()) print(remote_client.submit())
...@@ -3,8 +3,7 @@ from __future__ import print_function ...@@ -3,8 +3,7 @@ from __future__ import print_function
import redis import redis
from flatland.envs.generators import rail_from_file from flatland.envs.generators import rail_from_file
from flatland.envs.rail_env import RailEnv from flatland.envs.rail_env import RailEnv
from flatland.envs.observations import TreeObsForRailEnv from flatland.core.env_observation_builder import DummyObservationBuilder
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.evaluators import messages from flatland.evaluators import messages
import numpy as np import numpy as np
import msgpack import msgpack
...@@ -235,7 +234,6 @@ class FlatlandRemoteEvaluationService: ...@@ -235,7 +234,6 @@ class FlatlandRemoteEvaluationService:
Add a high level summary of everything that's Add a high level summary of everything that's
happening here. happening here.
""" """
env_params = command["payload"] # noqa F841
if self.simulation_count < len(self.env_file_paths): if self.simulation_count < len(self.env_file_paths):
""" """
...@@ -244,19 +242,11 @@ class FlatlandRemoteEvaluationService: ...@@ -244,19 +242,11 @@ class FlatlandRemoteEvaluationService:
test_env_file_path = self.env_file_paths[self.simulation_count] test_env_file_path = self.env_file_paths[self.simulation_count]
del self.env del self.env
# TODO : Use env_params dictionary to instantiate
# the RailEnv
# Maybe use a gin-like interface ?
# Needs discussion with Erik + Giacomo
# -Mohanty
self.env = RailEnv( self.env = RailEnv(
width=1, width=1,
height=1, height=1,
rail_generator=rail_from_file(test_env_file_path), rail_generator=rail_from_file(test_env_file_path),
obs_builder_object=TreeObsForRailEnv( obs_builder_object=DummyObservationBuilder()
max_depth=3,
predictor=ShortestPathPredictorForRailEnv()
)
) )
# Set max episode steps allowed # Set max episode steps allowed
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment