Commit 6e97f18d authored by nilabha's avatar nilabha
Browse files

Changes to include support for video upload to wandb

parent 8aa6ce9b
...@@ -127,4 +127,8 @@ dmypy.json ...@@ -127,4 +127,8 @@ dmypy.json
.pyre/ .pyre/
# misc # misc
.idea .idea
\ No newline at end of file
# custom extras
small_tree_video/
test.yaml
\ No newline at end of file
...@@ -12,7 +12,7 @@ from ray.rllib import MultiAgentEnv ...@@ -12,7 +12,7 @@ from ray.rllib import MultiAgentEnv
from envs.flatland import get_generator_config from envs.flatland import get_generator_config
from envs.flatland.observations import make_obs from envs.flatland.observations import make_obs
from envs.flatland.utils.rllib_wrapper import FlatlandRllibWrapper from envs.flatland.utils.rllib_wrapper import FlatlandRllibWrapper
from gym.wrappers.monitor import Monitor from gym.wrappers import monitor
from datetime import datetime from datetime import datetime
import time import time
import os import os
...@@ -65,7 +65,7 @@ class FlatlandSparse(MultiAgentEnv): ...@@ -65,7 +65,7 @@ class FlatlandSparse(MultiAgentEnv):
pprint(self._config) pprint(self._config)
self._env = FlatlandRllibWrapper( self._env = FlatlandRllibWrapper(
rail_env=self._launch(), rail_env=self._launch(),
render=env_config['render'], # TODO need to fix gl compatibility first render=env_config.get('render'), # TODO need to fix gl compatibility first
regenerate_rail_on_reset=self._config['regenerate_rail_on_reset'], regenerate_rail_on_reset=self._config['regenerate_rail_on_reset'],
regenerate_schedule_on_reset=self._config['regenerate_schedule_on_reset'] regenerate_schedule_on_reset=self._config['regenerate_schedule_on_reset']
) )
...@@ -117,7 +117,7 @@ class FlatlandSparse(MultiAgentEnv): ...@@ -117,7 +117,7 @@ class FlatlandSparse(MultiAgentEnv):
random_seed=self._config['seed'], random_seed=self._config['seed'],
# Should Below line be commented as here the env tries different configs, # Should Below line be commented as here the env tries different configs,
# hence opening it can be wasteful, morever the render has to be closed # hence opening it can be wasteful, morever the render has to be closed
use_renderer=self._env_config['render'] use_renderer=self._env_config.get('render')
) )
env.reset() env.reset()
except ValueError as e: except ValueError as e:
...@@ -134,16 +134,17 @@ class FlatlandSparse(MultiAgentEnv): ...@@ -134,16 +134,17 @@ class FlatlandSparse(MultiAgentEnv):
return obs, all_rewards, done, info return obs, all_rewards, done, info
def reset(self): def reset(self):
if self._env_config['render']: if self._env_config.get('render', None):
_cur_date = datetime.now().strftime('%d-%b-%Y (%H:%M:%S.%f') env_name="flatland"
folder = f"video_{_cur_date}" monitor.FILE_PREFIX = env_name
Monitor._after_step =_after_step folder = self._env_config.get('video_dir',env_name)
self._env = Monitor(self._env, folder, resume=True) monitor.Monitor._after_step =_after_step
self._env = monitor.Monitor(self._env, folder, resume=True)
self.video_folders.append(folder) self.video_folders.append(folder)
return self._env.reset() return self._env.reset()
def render(self,mode='human'): def render(self,mode='human'):
return self._env.render(self._env_config['render']) return self._env.render(self._env_config.get('render'))
def close(self): def close(self):
self._env.close() self._env.close()
...@@ -2,7 +2,7 @@ flatland-render-test: ...@@ -2,7 +2,7 @@ flatland-render-test:
run: PPO run: PPO
env: flatland_sparse env: flatland_sparse
stop: stop:
training_iteration: 3 training_iteration: 6
# timesteps_total: 5000 # 1e7 # timesteps_total: 5000 # 1e7
checkpoint_freq: 10 checkpoint_freq: 10
checkpoint_at_end: True checkpoint_at_end: True
...@@ -30,6 +30,9 @@ flatland-render-test: ...@@ -30,6 +30,9 @@ flatland-render-test:
env_config: env_config:
observation: tree observation: tree
render: human render: human
# Save videos to a custom folder and upload them to wandb.
# If video_dir is not specified, the folder defaults to "flatland".
video_dir: small_tree_video
observation_config: observation_config:
max_depth: 2 max_depth: 2
shortest_path_max_depth: 30 shortest_path_max_depth: 30
...@@ -42,9 +45,10 @@ flatland-render-test: ...@@ -42,9 +45,10 @@ flatland-render-test:
entity: nilabha2007 entity: nilabha2007
tags: ["small_v0", "tree_obs"] # TODO should be set programmatically tags: ["small_v0", "tree_obs"] # TODO should be set programmatically
reinit: True reinit: True
# monitor_gym: True # monitor_gym: True # Wandb video doesn't seem to work
model: model:
fcnet_activation: relu fcnet_activation: relu
fcnet_hiddens: [256, 256] fcnet_hiddens: [256, 256]
vf_share_layers: True # False vf_share_layers: True # False
...@@ -4,6 +4,8 @@ import wandb ...@@ -4,6 +4,8 @@ import wandb
from ray import tune from ray import tune
from datetime import datetime from datetime import datetime
import os, fnmatch import os, fnmatch
import shutil
def find(pattern, path): def find(pattern, path):
result = [] result = []
...@@ -31,6 +33,33 @@ class WandbLogger(tune.logger.Logger): ...@@ -31,6 +33,33 @@ class WandbLogger(tune.logger.Logger):
def _init(self): def _init(self):
self._config = None self._config = None
wandb.init(**self.config.get("env_config", {}).get("wandb", {})) wandb.init(**self.config.get("env_config", {}).get("wandb", {}))
self.reset_state()
if self.config.get("env_config", {}).get("render", None):
'''
Cleans or Resume folder based on resume
To avoid overwriting current video folder
Add `resume: True` under `env_config`
'''
resume = self.config.get("env_config", {}).get("resume", False)
self.clear_folders(resume)
def clear_folders(self, resume):
    """Resolve the video folder and, unless resuming, delete it.

    The folder name is read from ``env_config["video_dir"]`` in
    ``self.config`` (falling back to ``"flatland"``) and stored on
    ``self._save_folder``.

    Args:
        resume: When truthy the folder is left untouched; when falsy an
            existing folder is removed together with its contents.
    """
    default_dir = "flatland"
    env_config = self.config.get("env_config", {})
    self._save_folder = env_config.get("video_dir", default_dir)

    # Nothing to delete when resuming or when the folder does not exist.
    if resume or not os.path.exists(self._save_folder):
        return

    try:
        shutil.rmtree(self._save_folder)
    except OSError as err:
        # Best effort: report the failure and carry on.
        print("Error: %s - %s." % (err.filename, err.strerror))
def reset_state(self):
    """Reset the video bookkeeping attributes to their initial, empty values."""
    # Folder the videos are looked up in; None until clear_folders() sets it.
    self._save_folder = None
    # Maps each original video file path to the unique name (built from the
    # env state) under which that video is logged.
    self._file_map = {}
    # Video files already uploaded/moved, so that they are not uploaded again.
    self._upload_files = []
def on_result(self, result): def on_result(self, result):
config = result.get("config") config = result.get("config")
...@@ -50,16 +79,57 @@ class WandbLogger(tune.logger.Logger): ...@@ -50,16 +79,57 @@ class WandbLogger(tune.logger.Logger):
continue continue
metrics[key] = value metrics[key] = value
wandb.log(metrics) wandb.log(metrics)
# _current_pid = result['pid']
# _video_file = f'openaigym.video.0.{_current_pid}.video*.mp4' if self.config.get("env_config", {}).get("render", None):
_video_file = f'*.mp4' # uploading relevant videos to wandb
print("Current PID:",os.getpid()) # we do this step to ensure any config changes done during training is incorporated
_found_videos = find(_video_file,".") # resume is set to True as we don't want to delete any older videos
for _found_video in _found_videos: self.clear_folders(resume= True)
_splits = _found_video.split(os.sep)
if "video_" in _splits[-2]: if self._save_folder:
_video_file_name = _splits[-1] iterations = result['training_iteration']
wandb.log({_video_file_name: wandb.Video(_found_video, format="mp4")}) steps = result['timesteps_total']
perc_comp_mean = result['custom_metrics'].get('percentage_complete_mean',0)*100
# We ignore *1.mp4 videos, which contain only the last frame
_video_file = f'*0.mp4'
_found_videos = find(_video_file,self._save_folder)
_found_videos = list(set(_found_videos) - set(self._upload_files))
# Sort by create time for uploading to wandb
_found_videos.sort(key=os.path.getctime)
for _found_video in _found_videos:
_splits = _found_video.split(os.sep)
_check_file = os.stat(_found_video)
_video_file = _splits[-1]
_file_split = _video_file.split('.')
_video_file_name = _file_split[0] + "-" + str(iterations)
_original_name = ".".join(_file_split[2:-1])
_video_file_name = ".".join([str(_video_file_name),str(steps),str(int(perc_comp_mean)),_original_name,str(_check_file.st_ctime)])
_key = _found_video # Use the video file path as key to identify the video_name
if not self._file_map.get(_key):
# Assign the steps, iteration and completion rate from the first result in
# which this video file was seen; names computed on later results are discarded.
# TODO: The exact env details under which the video was created cannot be matched,
# so we default to the env details from when the video was first seen.
# To identify videos reliably, the video file should be recorded together with the env iteration and/or steps.
# Using the env details from when the video was created may be useful when recording video during evaluation,
# where we are more interested in the current training state.
self._file_map[_key]=_video_file_name
# We only move videos that have been flushed out.
# This is done by checking against a threshold size of 1000 bytes
if _check_file.st_size > 1000:
_video_file_name = self._file_map.get(_key,"Unknown")
wandb.log({_video_file_name: wandb.Video(_found_video, format="mp4")})
try:
# Move upload videos and their meta data once done to the logdir
src = _found_video
dst = os.path.join(self.logdir,_video_file)
shutil.move(src, dst)
shutil.move(src.replace("mp4","meta.json"), dst.replace("mp4","meta.json"))
except OSError as e:
print ("Error: %s - %s." % (e.filename, e.strerror))
self._upload_files.append(_found_video)
def close(self): def close(self):
wandb.join() wandb.join()
self.reset_state()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment