Skip to content
Snippets Groups Projects
Commit ad9d154c authored by spmohanty's avatar spmohanty
Browse files

Enable video generation only for a list of hand selected envs from the test set

parent b40368a1
No related branches found
No related tags found
No related merge requests found
...@@ -53,12 +53,16 @@ class FlatlandRemoteEvaluationService: ...@@ -53,12 +53,16 @@ class FlatlandRemoteEvaluationService:
remote_db=0, remote_db=0,
remote_password=None, remote_password=None,
visualize=False, visualize=False,
video_generation_envs=[],
report=None, report=None,
verbose=False): verbose=False):
# Test Env folder Paths # Test Env folder Paths
self.test_env_folder = test_env_folder self.test_env_folder = test_env_folder
self.video_generation_envs = video_generation_envs
self.video_generation_indices = []
self.env_file_paths = self.get_env_filepaths() self.env_file_paths = self.get_env_filepaths()
print(self.video_generation_indices)
# Logging and Reporting related vars # Logging and Reporting related vars
self.verbose = verbose self.verbose = verbose
...@@ -141,6 +145,21 @@ class FlatlandRemoteEvaluationService: ...@@ -141,6 +145,21 @@ class FlatlandRemoteEvaluationService:
env_paths.append( env_paths.append(
os.path.join(root, file) os.path.join(root, file)
) )
env_paths = sorted(env_paths)
print(self.video_generation_envs)
for _idx, env_path in enumerate(env_paths):
"""
Here we collect the indices of the environments for which
we need to generate the videos
We increment the simulation count on env_create
so the 1st simulation has an index of 1, when comparing in
env_step
"""
for vg_env in self.video_generation_envs:
print(vg_env, env_path)
if vg_env in env_path:
self.video_generation_indices.append(_idx+1)
return sorted(env_paths) return sorted(env_paths)
def instantiate_redis_connection_pool(self): def instantiate_redis_connection_pool(self):
...@@ -367,13 +386,18 @@ class FlatlandRemoteEvaluationService: ...@@ -367,13 +386,18 @@ class FlatlandRemoteEvaluationService:
# Record Frame # Record Frame
if self.visualize: if self.visualize:
self.env_renderer.render_env(show=True, show_observations=False, show_predictions=False) self.env_renderer.render_env(show=False, show_observations=False, show_predictions=False)
self.env_renderer.gl.save_image( """
os.path.join( Only save the frames for environments which are separately provided
self.vizualization_folder_name, in video_generation_indices param
"flatland_frame_{:04d}.png".format(self.record_frame_step) """
)) if self.simulation_count in self.video_generation_indices:
self.record_frame_step += 1 self.env_renderer.gl.save_image(
os.path.join(
self.vizualization_folder_name,
"flatland_frame_{:04d}.png".format(self.record_frame_step)
))
self.record_frame_step += 1
# Build and send response # Build and send response
_command_response = {} _command_response = {}
...@@ -575,7 +599,8 @@ if __name__ == "__main__": ...@@ -575,7 +599,8 @@ if __name__ == "__main__":
test_env_folder=test_folder, test_env_folder=test_folder,
flatland_rl_service_id=args.service_id, flatland_rl_service_id=args.service_id,
verbose=True, verbose=True,
visualize=True visualize=True,
video_generation_envs=["Test_0/Level_1.pkl"]
) )
result = grader.run() result = grader.run()
if result['type'] == messages.FLATLAND_RL.ENV_SUBMIT_RESPONSE: if result['type'] == messages.FLATLAND_RL.ENV_SUBMIT_RESPONSE:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment