Skip to content
Snippets Groups Projects
Commit 5b5d979a authored by spmohanty's avatar spmohanty
Browse files

Addresses #123 - Uses relative paths in evaluator service

parent 9dfef6cd
No related branches found
No related tags found
No related merge requests found
......@@ -46,6 +46,7 @@ class FlatlandRemoteClient(object):
remote_port=6379,
remote_db=0,
remote_password=None,
test_envs_root=None,
verbose=False):
self.remote_host = remote_host
......@@ -58,14 +59,22 @@ class FlatlandRemoteClient(object):
db=remote_db,
password=remote_password)
self.namespace = "flatland-rl"
try:
self.service_id = os.environ['FLATLAND_RL_SERVICE_ID']
except KeyError:
self.service_id = "FLATLAND_RL_SERVICE_ID"
self.service_id = os.getenv(
'FLATLAND_RL_SERVICE_ID',
'FLATLAND_RL_SERVICE_ID'
)
self.command_channel = "{}::{}::commands".format(
self.namespace,
self.service_id
)
if test_envs_root:
self.test_envs_root = test_envs_root
else:
self.test_envs_root = os.getenv(
'AICROWD_TESTS_FOLDER',
'/tmp/flatland_envs'
)
self.verbose = verbose
self.env = None
......@@ -161,6 +170,12 @@ class FlatlandRemoteClient(object):
return observation
test_env_file_path = _response['payload']['env_file_path']
print("Received Env : ", test_env_file_path)
test_env_file_path = os.path.join(
self.test_envs_root,
test_env_file_path
)
print("Current env path : ", test_env_file_path)
self.env = RailEnv(
width=1,
height=1,
......@@ -192,11 +207,15 @@ class FlatlandRemoteClient(object):
remote_info = _payload['info']
# Replicate the action in the local env
local_observation, local_rewards, local_done, local_info = \
local_observation, local_reward, local_done, local_info = \
self.env.step(action)
assert are_dicts_equal(remote_reward, local_rewards)
assert are_dicts_equal(remote_done, local_done)
print(local_reward)
if not are_dicts_equal(remote_reward, local_reward):
raise Exception("local and remote `reward` are diverging")
print(remote_reward, local_reward)
if not are_dicts_equal(remote_done, local_done):
raise Exception("local and remote `done` are diverging")
# Return local_observation instead of remote_observation
# as the remote_observation is build using a dummy observation
......
......@@ -61,9 +61,8 @@ class FlatlandRemoteEvaluationService:
# Test Env folder Paths
self.test_env_folder = test_env_folder
self.video_generation_envs = video_generation_envs
self.video_generation_indices = []
self.env_file_paths = self.get_env_filepaths()
print(self.video_generation_indices)
print(self.env_file_paths)
# Logging and Reporting related vars
self.verbose = verbose
......@@ -101,7 +100,7 @@ class FlatlandRemoteEvaluationService:
self.env = False
self.env_renderer = False
self.reward = 0
self.simulation_count = 0
self.simulation_count = -1
self.simulation_rewards = []
self.simulation_percentage_complete = []
self.simulation_steps = []
......@@ -151,18 +150,6 @@ class FlatlandRemoteEvaluationService:
x, self.test_env_folder
) for x in env_paths])
for _idx, env_path in enumerate(env_paths):
"""
Here we collect the indices of the environments for which
we need to generate the videos
We increment the simulation count on env_create
so the 1st simulation has an index of 1, when comparing in
env_step
"""
for vg_env in self.video_generation_envs:
if vg_env in env_path:
self.video_generation_indices.append(_idx+1)
return env_paths
def instantiate_redis_connection_pool(self):
......@@ -283,13 +270,19 @@ class FlatlandRemoteEvaluationService:
Add a high level summary of everything thats
hapenning here.
"""
self.simulation_count += 1
if self.simulation_count < len(self.env_file_paths):
"""
There are still test envs left that are yet to be evaluated
"""
test_env_file_path = self.env_file_paths[self.simulation_count]
print("__ Env Path : ", test_env_file_path)
test_env_file_path = os.path.join(
self.test_env_folder,
test_env_file_path
)
print("__ Processed Env Path : ", test_env_file_path)
del self.env
self.env = RailEnv(
width=1,
......@@ -299,15 +292,13 @@ class FlatlandRemoteEvaluationService:
)
if self.visualize:
if self.env_renderer:
del self.env_renderer
del self.env_renderer
self.env_renderer = RenderTool(self.env, gl="PILSVG", )
# Set max episode steps allowed
self.env._max_episode_steps = \
int(1.5 * (self.env.width + self.env.height))
self.simulation_count += 1
if self.begin_simulation:
# If begin simulation has already been initialized
# atleast once
......@@ -326,7 +317,7 @@ class FlatlandRemoteEvaluationService:
_command_response['type'] = messages.FLATLAND_RL.ENV_CREATE_RESPONSE
_command_response['payload'] = {}
_command_response['payload']['observation'] = _observation
_command_response['payload']['env_file_path'] = test_env_file_path
_command_response['payload']['env_file_path'] = self.env_file_paths[self.simulation_count]
else:
"""
All test env evaluations are complete
......@@ -389,12 +380,17 @@ class FlatlandRemoteEvaluationService:
# Record Frame
if self.visualize:
self.env_renderer.render_env(show=False, show_observations=False, show_predictions=False)
self.env_renderer.render_env(
show=False,
show_observations=False,
show_predictions=False
)
"""
Only save the frames for environments which are separately provided
in video_generation_indices param
"""
if self.simulation_count in self.video_generation_indices:
current_env_path = self.env_file_paths[self.simulation_count]
if current_env_path in self.video_generation_envs:
self.env_renderer.gl.save_image(
os.path.join(
self.vizualization_folder_name,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment