Commit 2872d5b9 authored by nilabha

Added support for video recording in RLLib. Initial Commit

parent 27e8d3bc
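For context, the recording path added in this commit is driven by the existing env_config 'render' key: a truthy value makes FlatlandRllibWrapper switch on the new render wrapper, and makes FlatlandSparse.reset() wrap the environment in gym.wrappers.Monitor. A minimal sketch of how it is meant to be driven (the FlatlandSparse import path and the config values are assumptions, only the keys appear in the diff below):

# Sketch only; the import path and config values are assumptions.
from envs.flatland.flatland_sparse import FlatlandSparse   # assumed module path

env = FlatlandSparse({
    'generator': 'sparse_rail_generator',
    'generator_config': 'small_v0',          # assumed generator config name
    'observation': 'tree',                   # assumed observation name
    'observation_config': {'max_depth': 2},  # assumed observation settings
    'render': 'rgb_array',                   # truthy -> renderer on + video recording
})
obs = env.reset()   # reset() re-wraps the env in wrappers.Monitor under a video_<timestamp> folder
env.close()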
@@ -2,7 +2,8 @@ import logging
import random
from flatland.envs.malfunction_generators import malfunction_from_params
from flatland.envs.rail_env import RailEnv
# from flatland.envs.rail_env import RailEnv
from envs.flatland.utils.flatland_render_wrapper import FlatlandRenderWrapper as RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.schedule_generators import sparse_schedule_generator
new file: envs/flatland/utils/flatland_render_wrapper.py
from flatland.envs.rail_env import RailEnv
from flatland.utils.rendertools import RenderTool, AgentRenderVariant
import gym
class FlatlandRenderWrapper(RailEnv,gym.Env):
# reward_range = (-float('inf'), float('inf'))
# spec = None
# # Set these in ALL subclasses
# observation_space = None
def __init__(self, use_renderer=False, *args, **kwargs):
super().__init__(*args, **kwargs)
self.use_renderer = use_renderer
self.renderer = None
self.metadata = {
'render.modes': ['human', 'rgb_array'],
'video.frames_per_second': 10
}
if self.use_renderer:
self.initialize_renderer()
def reset(self, *args, **kwargs):
if self.use_renderer:
self.renderer.reset()
return super().reset(*args, **kwargs)
def render(self, mode='human'):
"""
This method provides the option to render the
environment's behavior to a window which should be
readable to the human eye if mode is set to 'human'.
"""
if not self.use_renderer:
return
if not self.renderer:
self.initialize_renderer(mode=mode)
return self.update_renderer(mode=mode)
def initialize_renderer(self, mode="human"):
# Initiate the renderer
self.renderer = RenderTool(self, gl="PGL", # gl="TKPILSVG",
agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
show_debug=False,
screen_height=600, # Adjust these parameters to fit your resolution
screen_width=800) # Adjust these parameters to fit your resolution
def update_renderer(self, mode='human'):
image = self.renderer.render_env(show=True, show_observations=False, show_predictions=False,
return_image=True)
return image[:,:,:3]
def set_renderer(self, renderer):
self.use_renderer = renderer
if self.use_renderer:
self.initialize_renderer(mode=self.use_renderer)
def close(self):
if self.renderer:
try:
self.renderer.close_window()
self.renderer = None
except Exception as e:
# TODO: This causes an error with RLLib
print("Could Not close window due to:",e)
@@ -2,8 +2,10 @@ from collections import defaultdict
from typing import Dict, NamedTuple, Any, Optional
import gym
from flatland.envs.rail_env import RailEnv, RailEnvActions
from gym import wrappers
from flatland.envs.rail_env import RailEnvActions
from envs.flatland.utils.flatland_render_wrapper import FlatlandRenderWrapper as RailEnv
class StepOutput(NamedTuple):
obs: Dict[int, Any] # depends on observation builder
@@ -14,7 +16,18 @@ class StepOutput(NamedTuple):
class FlatlandRllibWrapper(object):
def __init__(self, rail_env: RailEnv, render: bool = False, regenerate_rail_on_reset: bool = True,
reward_range = (-float('inf'), float('inf'))
spec = None
# Set these in ALL subclasses
observation_space = None
metadata = {
'render.modes': ['human', 'rgb_array'],
'video.frames_per_second': 10
}
def __init__(self, rail_env: RailEnv, render = False, regenerate_rail_on_reset: bool = True,
regenerate_schedule_on_reset: bool = True) -> None:
super().__init__()
self._env = rail_env
@@ -25,10 +38,7 @@ class FlatlandRllibWrapper(object):
self._regenerate_schedule_on_reset = regenerate_schedule_on_reset
self._action_space = gym.spaces.Discrete(5)
if render:
from flatland.utils.rendertools import RenderTool
self.renderer = RenderTool(self._env, gl="PILSVG")
else:
self.renderer = None
self._env.set_renderer(render)
@property
def action_space(self) -> gym.spaces.Discrete:
@@ -42,9 +52,6 @@ class FlatlandRllibWrapper(object):
# The observation is `None` if an agent is done or malfunctioning.
obs, rewards, dones, infos = self._env.step(action_dict)
if self.renderer is not None:
self.renderer.render_env(show=True, show_predictions=True, show_observations=False)
d, r, o = dict(), dict(), dict()
for agent, done in dones.items():
if agent != '__all__' and not agent in obs:
@@ -69,13 +76,22 @@ class FlatlandRllibWrapper(object):
assert all([x is not None for x in (d, r, o)])
return StepOutput(obs=o, reward=r, done=d, info={agent: {
'max_episode_steps': self._env._max_episode_steps,
'num_agents': self._env.get_num_agents(),
'agent_done': d[agent] and agent not in self._env.active_agents,
'agent_score': self._agent_scores[agent],
'agent_step': self._agent_steps[agent],
} for agent in o.keys()})
if isinstance(self._env, wrappers.Monitor):
return StepOutput(obs=o, reward=r, done=d, info={agent: {
'max_episode_steps': self._env.env._max_episode_steps,
'num_agents': self._env.env.get_num_agents(),
'agent_done': d[agent] and agent not in self._env.env.active_agents,
'agent_score': self._agent_scores[agent],
'agent_step': self._agent_steps[agent],
} for agent in o.keys()})
else:
return StepOutput(obs=o, reward=r, done=d, info={agent: {
'max_episode_steps': self._env._max_episode_steps,
'num_agents': self._env.get_num_agents(),
'agent_done': d[agent] and agent not in self._env.active_agents,
'agent_score': self._agent_scores[agent],
'agent_step': self._agent_steps[agent],
} for agent in o.keys()})
def reset(self, random_seed: Optional[int] = None) -> Dict[int, Any]:
self._agents_done = []
@@ -84,6 +100,10 @@ class FlatlandRllibWrapper(object):
obs, infos = self._env.reset(regenerate_rail=self._regenerate_rail_on_reset,
regenerate_schedule=self._regenerate_schedule_on_reset,
random_seed=random_seed)
if self.renderer is not None:
self.renderer.reset()
return {k: o for k, o in obs.items() if not k == '__all__'}
def render(self,mode='human'):
return self._env.render(mode)
def close(self):
self._env.close()
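Note that the two StepOutput branches in step() above differ only in whether the underlying RailEnv is reached directly or through the Monitor's .env attribute. An equivalent, smaller formulation would unwrap once; purely illustrative, not part of the commit:

from gym import wrappers

def unwrap_monitor(env):
    # Return the wrapped env if a gym Monitor is present, otherwise the env itself.
    return env.env if isinstance(env, wrappers.Monitor) else env

step() could then read _max_episode_steps, get_num_agents() and active_agents from unwrap_monitor(self._env) and keep a single dict comprehension.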
@@ -3,7 +3,8 @@ from pprint import pprint
import gym
from flatland.envs.malfunction_generators import malfunction_from_params, no_malfunction_generator
from flatland.envs.rail_env import RailEnv
# from flatland.envs.rail_env import RailEnv
from envs.flatland.utils.flatland_render_wrapper import FlatlandRenderWrapper as RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.schedule_generators import sparse_schedule_generator
from ray.rllib import MultiAgentEnv
@@ -11,14 +12,26 @@ from ray.rllib import MultiAgentEnv
from envs.flatland import get_generator_config
from envs.flatland.observations import make_obs
from envs.flatland.utils.rllib_wrapper import FlatlandRllibWrapper
from gym import wrappers
from datetime import datetime
class FlatlandSparse(MultiAgentEnv):
reward_range = (-float('inf'), float('inf'))
spec = None
metadata = {
'render.modes': ['human', 'rgb_array'],
'video.frames_per_second': 10
}
def __init__(self, env_config) -> None:
super().__init__()
# TODO implement other generators
assert env_config['generator'] == 'sparse_rail_generator'
self._env_config = env_config
self._observation = make_obs(env_config['observation'], env_config.get('observation_config'))
self._config = get_generator_config(env_config['generator_config'])
@@ -27,10 +40,11 @@ class FlatlandSparse(MultiAgentEnv):
print("=" * 50)
pprint(self._config)
print("=" * 50)
pprint(self._env_config)
print("=" * 50)
self._env = FlatlandRllibWrapper(
rail_env=self._launch(),
# render=env_config['render'], # TODO need to fix gl compatibility first
render=env_config['render'], # TODO need to fix gl compatibility first
regenerate_rail_on_reset=self._config['regenerate_rail_on_reset'],
regenerate_schedule_on_reset=self._config['regenerate_schedule_on_reset']
)
@@ -79,9 +93,11 @@ class FlatlandSparse(MultiAgentEnv):
malfunction_generator_and_process_data=malfunction_generator,
obs_builder_object=self._observation.builder(),
remove_agents_at_target=False,
random_seed=self._config['seed']
random_seed=self._config['seed'],
# The line below is commented out because the env tries different configs here,
# so opening a renderer is wasteful; moreover, the renderer would then have to be closed.
# use_renderer=self._env_config['render']
)
env.reset()
except ValueError as e:
logging.error("=" * 50)
@@ -94,4 +110,13 @@ class FlatlandSparse(MultiAgentEnv):
return self._env.step(action_dict)
def reset(self):
if self._env_config['render']:
folder = "video_"+ datetime.now().strftime('%d-%b-%Y (%H:%M:%S.%f)')
self._env = wrappers.Monitor(self._env, folder, resume=True)
return self._env.reset()
def render(self,mode='human'):
return self._env.render(self._env_config['render'])
def close(self):
self._env.close()
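On the RLlib side, recording is switched on through the same env_config that reaches FlatlandSparse. A sketch of registering the env and enabling video during training (the env id, module path and config values are illustrative; only the config keys appear in this diff):

from ray import tune
from ray.tune.registry import register_env
from envs.flatland.flatland_sparse import FlatlandSparse   # assumed module path

register_env("flatland_sparse", lambda cfg: FlatlandSparse(cfg))

tune.run(
    "PPO",
    stop={"timesteps_total": 100000},
    config={
        "env": "flatland_sparse",
        "env_config": {
            "generator": "sparse_rail_generator",
            "generator_config": "small_v0",       # assumed generator config name
            "observation": "tree",                # assumed observation name
            "observation_config": {"max_depth": 2},
            "render": "rgb_array",                # triggers the Monitor wrapping in reset()
        },
        "num_workers": 0,                         # keep rollouts in the driver so the renderer can open
    },
)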