Skip to content
Snippets Groups Projects
Commit ad9d154c authored by spmohanty's avatar spmohanty
Browse files

Enable video generation only for a list of hand selected envs from the test set

parent b40368a1
No related branches found
No related tags found
No related merge requests found
......@@ -53,12 +53,16 @@ class FlatlandRemoteEvaluationService:
remote_db=0,
remote_password=None,
visualize=False,
video_generation_envs=[],
report=None,
verbose=False):
# Test Env folder Paths
self.test_env_folder = test_env_folder
self.video_generation_envs = video_generation_envs
self.video_generation_indices = []
self.env_file_paths = self.get_env_filepaths()
print(self.video_generation_indices)
# Logging and Reporting related vars
self.verbose = verbose
......@@ -141,6 +145,21 @@ class FlatlandRemoteEvaluationService:
env_paths.append(
os.path.join(root, file)
)
env_paths = sorted(env_paths)
print(self.video_generation_envs)
for _idx, env_path in enumerate(env_paths):
"""
Here we collect the indices of the environments for which
we need to generate the videos
We increment the simulation count on env_create
so the 1st simulation has an index of 1, when comparing in
env_step
"""
for vg_env in self.video_generation_envs:
print(vg_env, env_path)
if vg_env in env_path:
self.video_generation_indices.append(_idx+1)
return sorted(env_paths)
def instantiate_redis_connection_pool(self):
......@@ -367,13 +386,18 @@ class FlatlandRemoteEvaluationService:
# Record Frame
if self.visualize:
self.env_renderer.render_env(show=True, show_observations=False, show_predictions=False)
self.env_renderer.gl.save_image(
os.path.join(
self.vizualization_folder_name,
"flatland_frame_{:04d}.png".format(self.record_frame_step)
))
self.record_frame_step += 1
self.env_renderer.render_env(show=False, show_observations=False, show_predictions=False)
"""
Only save the frames for environments which are separately provided
in video_generation_indices param
"""
if self.simulation_count in self.video_generation_indices:
self.env_renderer.gl.save_image(
os.path.join(
self.vizualization_folder_name,
"flatland_frame_{:04d}.png".format(self.record_frame_step)
))
self.record_frame_step += 1
# Build and send response
_command_response = {}
......@@ -575,7 +599,8 @@ if __name__ == "__main__":
test_env_folder=test_folder,
flatland_rl_service_id=args.service_id,
verbose=True,
visualize=True
visualize=True,
video_generation_envs=["Test_0/Level_1.pkl"]
)
result = grader.run()
if result['type'] == messages.FLATLAND_RL.ENV_SUBMIT_RESPONSE:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment