Commit 18e682e9 authored by nilabha's avatar nilabha

Added changes for updating videos via wandblogger

parent 2872d5b9
......@@ -16,14 +16,16 @@ class FlatlandRenderWrapper(RailEnv,gym.Env):
self.renderer = None
self.metadata = {
'render.modes': ['human', 'rgb_array'],
'video.frames_per_second': 10
'video.frames_per_second': 10,
'semantics.autoreset': True
}
if self.use_renderer:
self.initialize_renderer()
def reset(self, *args, **kwargs):
if self.use_renderer:
self.renderer.reset()
if self.renderer: #TODO: Errors with RLLib with renderer as None.
self.renderer.reset()
return super().reset(*args, **kwargs)
def render(self, mode='human'):
......@@ -59,10 +61,14 @@ class FlatlandRenderWrapper(RailEnv,gym.Env):
self.initialize_renderer(mode=self.use_renderer)
def close(self):
super().close()
if self.renderer:
try:
self.renderer.close_window()
self.renderer = None
except Exception as e:
# TODO: This causes an error with RLLib
# This is since the last step(Due to a stopping criteria) is skipped by rllib
# Due to this done is not true and the env does not close
# Finally the env is closed when RLLib exits but at that time there is no window
# and hence the error
print("Could Not close window due to:",e)
......@@ -7,6 +7,16 @@ from flatland.envs.rail_env import RailEnvActions
from envs.flatland.utils.flatland_render_wrapper import FlatlandRenderWrapper as RailEnv
import os, fnmatch
def find(pattern, path):
result = []
for root, dirs, files in os.walk(path):
for name in files:
if fnmatch.fnmatch(name, pattern):
result.append(os.path.join(root, name))
return result
class StepOutput(NamedTuple):
obs: Dict[int, Any] # depends on observation builder
reward: Dict[int, float]
......@@ -24,7 +34,8 @@ class FlatlandRllibWrapper(object):
metadata = {
'render.modes': ['human', 'rgb_array'],
'video.frames_per_second': 10
'video.frames_per_second': 10,
'semantics.autoreset': True
}
def __init__(self, rail_env: RailEnv, render = False, regenerate_rail_on_reset: bool = True,
......@@ -107,3 +118,17 @@ class FlatlandRllibWrapper(object):
def close(self):
self._env.close()
def save_video(self, video_folders=[]):
all_saved_videos = []
if len(video_folders) > 0:
for _video_folder in video_folders:
print("Finding videos from:",_video_folder)
_video_files = find('*.mp4', _video_folder)
print("Found videos:",_video_files)
if len(_video_files) > 0:
for _video_file in _video_files:
print("Saving video file:",_video_file)
all_saved_videos.append(_video_file)
return all_saved_videos
......@@ -12,8 +12,30 @@ from ray.rllib import MultiAgentEnv
from envs.flatland import get_generator_config
from envs.flatland.observations import make_obs
from envs.flatland.utils.rllib_wrapper import FlatlandRllibWrapper
from gym import wrappers
from gym.wrappers.monitor import Monitor
from datetime import datetime
import time
import os
def _after_step(self, observation, reward, done, info):
if not self.enabled: return done
if type(done)== dict:
_done_check = done['__all__']
else:
_done_check = done
if _done_check and self.env_semantics_autoreset:
# For envs with BlockingReset wrapping VNCEnv, this observation will be the first one of the new episode
self.reset_video_recorder()
self.episode_id += 1
self._flush()
# Record stats - Disabled as it causes error in multi-agent set up
# self.stats_recorder.after_step(observation, reward, done, info)
# Record video
self.video_recorder.capture_frame()
return done
class FlatlandSparse(MultiAgentEnv):
......@@ -23,11 +45,13 @@ class FlatlandSparse(MultiAgentEnv):
spec = None
metadata = {
'render.modes': ['human', 'rgb_array'],
'video.frames_per_second': 10
'video.frames_per_second': 10,
'semantics.autoreset': True
}
def __init__(self, env_config) -> None:
super().__init__()
self.video_folders = []
# TODO implement other generators
assert env_config['generator'] == 'sparse_rail_generator'
......@@ -39,9 +63,6 @@ class FlatlandSparse(MultiAgentEnv):
if env_config.worker_index == 0 and env_config.vector_index == 0:
print("=" * 50)
pprint(self._config)
print("=" * 50)
pprint(self._env_config)
print("=" * 50)
self._env = FlatlandRllibWrapper(
rail_env=self._launch(),
render=env_config['render'], # TODO need to fix gl compatibility first
......@@ -94,9 +115,9 @@ class FlatlandSparse(MultiAgentEnv):
obs_builder_object=self._observation.builder(),
remove_agents_at_target=False,
random_seed=self._config['seed'],
# Commented below line as here the env tries different configs,
# hence opening it is wasteful, morever the render has to be closed
# use_renderer=self._env_config['render']
# Should Below line be commented as here the env tries different configs,
# hence opening it can be wasteful, morever the render has to be closed
use_renderer=self._env_config['render']
)
env.reset()
except ValueError as e:
......@@ -107,12 +128,20 @@ class FlatlandSparse(MultiAgentEnv):
return env
def step(self, action_dict):
return self._env.step(action_dict)
obs, all_rewards, done, info = self._env.step(action_dict)
if done['__all__']:
uploaded_videos = self._env.save_video(self.video_folders)
self.close()
return obs, all_rewards, done, info
def reset(self):
if self._env_config['render']:
folder = "video_"+ datetime.now().strftime('%d-%b-%Y (%H:%M:%S.%f)')
self._env = wrappers.Monitor(self._env, folder, resume=True)
_cur_date = datetime.now().strftime('%d-%b-%Y (%H:%M:%S.%f')
folder = f"video_{_cur_date}"
Monitor._after_step =_after_step
self._env = Monitor(self._env, folder, resume=True)
self.video_folders.append(folder)
print("All Video folders created:",self.video_folders)
return self._env.reset()
def render(self,mode='human'):
......
......@@ -2,6 +2,16 @@ import numbers
import wandb
from ray import tune
from datetime import datetime
import os, fnmatch
def find(pattern, path):
result = []
for root, dirs, files in os.walk(path):
for name in files:
if fnmatch.fnmatch(name, pattern):
result.append(os.path.join(root, name))
return result
# ray 0.8.1 reorganized ray.tune.util -> ray.tune.utils
try:
......@@ -40,6 +50,16 @@ class WandbLogger(tune.logger.Logger):
continue
metrics[key] = value
wandb.log(metrics)
# _current_pid = result['pid']
# _video_file = f'openaigym.video.0.{_current_pid}.video*.mp4'
_video_file = f'*.mp4'
print("Current PID:",os.getpid())
_found_videos = find(_video_file,".")
for _found_video in _found_videos:
_splits = _found_video.split(os.sep)
if "video_" in _splits[-2]:
_video_file_name = _splits[-1]
wandb.log({_video_file_name: wandb.Video(_found_video, format="mp4")})
def close(self):
wandb.join()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment