Commit 6e97f18d authored by nilabha's avatar nilabha
Browse files

Changes to include support for video upload to wandb

parent 8aa6ce9b
......@@ -127,4 +127,8 @@ dmypy.json
# misc
\ No newline at end of file
# custom extras
\ No newline at end of file
......@@ -12,7 +12,7 @@ from ray.rllib import MultiAgentEnv
from envs.flatland import get_generator_config
from envs.flatland.observations import make_obs
from envs.flatland.utils.rllib_wrapper import FlatlandRllibWrapper
from gym.wrappers.monitor import Monitor
from gym.wrappers import monitor
from datetime import datetime
import time
import os
......@@ -65,7 +65,7 @@ class FlatlandSparse(MultiAgentEnv):
self._env = FlatlandRllibWrapper(
render=env_config['render'], # TODO need to fix gl compatibility first
render=env_config.get('render'), # TODO need to fix gl compatibility first
......@@ -117,7 +117,7 @@ class FlatlandSparse(MultiAgentEnv):
# Should Below line be commented as here the env tries different configs,
# hence opening it can be wasteful, morever the render has to be closed
except ValueError as e:
......@@ -134,16 +134,17 @@ class FlatlandSparse(MultiAgentEnv):
return obs, all_rewards, done, info
def reset(self):
if self._env_config['render']:
_cur_date ='%d-%b-%Y (%H:%M:%S.%f')
folder = f"video_{_cur_date}"
Monitor._after_step =_after_step
self._env = Monitor(self._env, folder, resume=True)
if self._env_config.get('render', None):
monitor.FILE_PREFIX = env_name
folder = self._env_config.get('video_dir',env_name)
monitor.Monitor._after_step =_after_step
self._env = monitor.Monitor(self._env, folder, resume=True)
return self._env.reset()
def render(self,mode='human'):
return self._env.render(self._env_config['render'])
return self._env.render(self._env_config.get('render'))
def close(self):
......@@ -2,7 +2,7 @@ flatland-render-test:
run: PPO
env: flatland_sparse
training_iteration: 3
training_iteration: 6
# timesteps_total: 5000 # 1e7
checkpoint_freq: 10
checkpoint_at_end: True
......@@ -30,6 +30,9 @@ flatland-render-test:
observation: tree
render: human
# For saving videos in custom folder and to wandb.
# By default if not specified folder is flatland
video_dir: small_tree_video
max_depth: 2
shortest_path_max_depth: 30
......@@ -42,9 +45,10 @@ flatland-render-test:
entity: nilabha2007
tags: ["small_v0", "tree_obs"] # TODO should be set programmatically
reinit: True
# monitor_gym: True
# monitor_gym: True # Wandb video doesn't seem to work
fcnet_activation: relu
fcnet_hiddens: [256, 256]
vf_share_layers: True # False
......@@ -4,6 +4,8 @@ import wandb
from ray import tune
from datetime import datetime
import os, fnmatch
import shutil
def find(pattern, path):
result = []
......@@ -31,6 +33,33 @@ class WandbLogger(tune.logger.Logger):
def _init(self):
self._config = None
wandb.init(**self.config.get("env_config", {}).get("wandb", {}))
if self.config.get("env_config", {}).get("render", None):
Cleans or Resume folder based on resume
To avoid overwriting current video folder
Add `resume: True` under `env_config`
resume = self.config.get("env_config", {}).get("resume", False)
def clear_folders(self, resume):
self._save_folder = self.config.get("env_config", {}).get("video_dir", env_name)
if not resume:
if os.path.exists(self._save_folder):
except OSError as e:
print ("Error: %s - %s." % (e.filename, e.strerror))
def reset_state(self):
# Holds list of uploaded/moved files so that we dont upload them again
self._upload_files = []
# Holds information of env state and put them in an unique file name
# and maps it to the original video file
self._file_map = {}
self._save_folder = None
def on_result(self, result):
config = result.get("config")
......@@ -50,16 +79,57 @@ class WandbLogger(tune.logger.Logger):
metrics[key] = value
# _current_pid = result['pid']
# _video_file = f'{_current_pid}.video*.mp4'
_video_file = f'*.mp4'
print("Current PID:",os.getpid())
_found_videos = find(_video_file,".")
for _found_video in _found_videos:
_splits = _found_video.split(os.sep)
if "video_" in _splits[-2]:
_video_file_name = _splits[-1]
wandb.log({_video_file_name: wandb.Video(_found_video, format="mp4")})
if self.config.get("env_config", {}).get("render", None):
# uploading relevant videos to wandb
# we do this step to ensure any config changes done during training is incorporated
# resume is set to True as we don't want to delete any older videos
self.clear_folders(resume= True)
if self._save_folder:
iterations = result['training_iteration']
steps = result['timesteps_total']
perc_comp_mean = result['custom_metrics'].get('percentage_complete_mean',0)*100
# We ignore *1.mp4 videos which just has the last frame
_video_file = f'*0.mp4'
_found_videos = find(_video_file,self._save_folder)
_found_videos = list(set(_found_videos) - set(self._upload_files))
# Sort by create time for uploading to wandb
for _found_video in _found_videos:
_splits = _found_video.split(os.sep)
_check_file = os.stat(_found_video)
_video_file = _splits[-1]
_file_split = _video_file.split('.')
_video_file_name = _file_split[0] + "-" + str(iterations)
_original_name = ".".join(_file_split[2:-1])
_video_file_name = ".".join([str(_video_file_name),str(steps),str(int(perc_comp_mean)),_original_name,str(_check_file.st_ctime)])
_key = _found_video # Use the video file path as key to identify the video_name
if not self._file_map.get(_key):
# Allocate steps, iteration, completion rate to the earliest case
# when the video file was first created. Discard recent file names
# TODO: Cannot match exact env details on which video was created
# and hence defaulting to the env details from when the video was first created
# To help identify we must record the video file with the env iteration or/and steps etc.
# Using the env details when the video was created may be useful when recording video during evaluation
# where we are more interested in the current training state
# We only move videos that have been flushed out.
# This is done by checking against a threshold size of 1000 bytes
if _check_file.st_size > 1000:
_video_file_name = self._file_map.get(_key,"Unknown")
wandb.log({_video_file_name: wandb.Video(_found_video, format="mp4")})
# Move upload videos and their meta data once done to the logdir
src = _found_video
dst = os.path.join(self.logdir,_video_file)
shutil.move(src, dst)
shutil.move(src.replace("mp4","meta.json"), dst.replace("mp4","meta.json"))
except OSError as e:
print ("Error: %s - %s." % (e.filename, e.strerror))
def close(self):
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment