Commit 0d8708a0 authored by MasterScrat

Fixed bug where training process would never end

parent 79edf57b
import fnmatch
import multiprocessing
import numbers
from pprint import pprint
import os
import shutil
import wandb
from ray import tune
import os, fnmatch
import shutil
from ray.tune.utils import flatten_dict
def find(pattern, path):
......@@ -16,12 +18,6 @@ def find(pattern, path):
result.append(os.path.join(root, name))
return result
# ray 0.8.1 reorganized ray.tune.util -> ray.tune.utils
try:
from ray.tune.utils import flatten_dict
except ImportError:
from ray.tune.util import flatten_dict
class WandbLogger(tune.logger.Logger):
"""Pass WandbLogger to the loggers argument of tune.run
......@@ -45,14 +41,14 @@ class WandbLogger(tune.logger.Logger):
self.clear_folders(resume)
def clear_folders(self, resume):
    """Resolve the video folder and, unless resuming, delete it.

    The folder is read from ``self.config["env_config"]["video_dir"]`` and
    falls back to ``"flatland"`` when either key is missing. The result is
    stored on ``self._save_folder``.

    Args:
        resume: When truthy, an existing folder and its contents are kept.
            When falsy, the folder is removed recursively if it exists.

    Deletion is best-effort: an ``OSError`` raised while removing the folder
    is reported on stdout and swallowed rather than propagated.
    """
    env_name = "flatland"
    self._save_folder = self.config.get("env_config", {}).get("video_dir", env_name)

    # Resuming must keep previously recorded videos, so nothing is deleted.
    if resume:
        return

    if os.path.exists(self._save_folder):
        try:
            shutil.rmtree(self._save_folder)
        except OSError as e:
            # Report the failure exactly once and carry on.
            print("Error: %s - %s." % (e.filename, e.strerror))
def reset_state(self):
# Holds list of uploaded/moved files so that we don't upload them again
......@@ -94,21 +90,21 @@ class WandbLogger(tune.logger.Logger):
# uploading relevant videos to wandb
# we do this step to ensure any config changes made during training are incorporated
# resume is set to True as we don't want to delete any older videos
self.clear_folders(resume= True)
self.clear_folders(resume=True)
if self._save_folder:
metrics = self.update_video_metrics(result,metrics)
metrics = self.update_video_metrics(result, metrics)
queue.put(metrics)
def update_video_metrics(self, result, metrics):
iterations = result['training_iteration']
steps = result['timesteps_total']
perc_comp_mean = result['custom_metrics'].get('percentage_complete_mean',0)*100
perc_comp_mean = result['custom_metrics'].get('percentage_complete_mean', 0) * 100
# We ignore *1.mp4 videos, which contain only the last frame
_video_file = f'*0.mp4'
_found_videos = find(_video_file,self._save_folder)
#_found_videos = list(set(_found_videos) - set(self._upload_files))
_found_videos = find(_video_file, self._save_folder)
# _found_videos = list(set(_found_videos) - set(self._upload_files))
# Sort by create time for uploading to wandb
_found_videos.sort(key=os.path.getctime)
for _found_video in _found_videos:
......@@ -118,7 +114,7 @@ class WandbLogger(tune.logger.Logger):
_file_split = _video_file.split('.')
_video_file_name = _file_split[0] + "-" + str(iterations)
_original_name = ".".join(_file_split[2:-1])
_video_file_name = ".".join([str(_video_file_name),str(steps),str(int(perc_comp_mean)),_original_name,str(_check_file.st_ctime)])
_video_file_name = ".".join([str(_video_file_name), str(steps), str(int(perc_comp_mean)), _original_name, str(_check_file.st_ctime)])
_key = _found_video # Use the video file path as key to identify the video_name
if not self._file_map.get(_key):
# Allocate steps, iteration, completion rate to the earliest case
......@@ -128,13 +124,13 @@ class WandbLogger(tune.logger.Logger):
# To help identify we must record the video file with the env iteration or/and steps etc.
# Using the env details when the video was created may be useful when recording video during evaluation
# where we are more interested in the current training state
self._file_map[_key]=_video_file_name
self._file_map[_key] = _video_file_name
# We only move videos that have been flushed out.
# This is done by checking against a threshold size of 1000 bytes
# Also check if file has changed from last time
if _check_file.st_size > 1000 and _check_file.st_ctime > self._upload_files.get(_found_video, True):
_video_file_name = self._file_map.get(_key,"Unknown")
if _check_file.st_size > 1000 and _check_file.st_ctime > self._upload_files.get(_found_video, True):
_video_file_name = self._file_map.get(_key, "Unknown")
# wandb.log({_video_file_name: wandb.Video(_found_video, format="mp4")})
metrics[_video_file_name] = wandb.Video(_found_video, format="mp4")
......@@ -143,7 +139,12 @@ class WandbLogger(tune.logger.Logger):
return metrics
def close(self):
# kills logger processes
for queue in self.metrics_queue_dict.values():
metrics = {"KILL": True}
queue.put(metrics)
wandb.join()
all_uploaded_videos = self._upload_files.keys()
for _found_video in all_uploaded_videos:
......@@ -151,11 +152,11 @@ class WandbLogger(tune.logger.Logger):
# Copy upload videos and their meta data once done to the logdir
src = _found_video
_video_file = _found_video.split(os.sep)[-1]
dst = os.path.join(self.logdir,_video_file)
dst = os.path.join(self.logdir, _video_file)
shutil.copy2(src, dst)
shutil.copy2(src.replace("mp4","meta.json"), dst.replace("mp4","meta.json"))
shutil.copy2(src.replace("mp4", "meta.json"), dst.replace("mp4", "meta.json"))
except OSError as e:
print ("Error: %s - %s." % (e.filename, e.strerror))
print("Error: %s - %s." % (e.filename, e.strerror))
self.reset_state()
......@@ -176,4 +177,8 @@ def wandb_process(queue, config):
while True:
metrics = queue.get()
run.log(metrics)
\ No newline at end of file
if "KILL" in metrics:
break
run.log(metrics)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment