Commit 92b9e365 authored by nilabha

Changes to include a Flatland base class for the Flatland env files

parent 6d6c5b4f
@@ -3,7 +3,8 @@ import random
from typing import NamedTuple
from flatland.envs.malfunction_generators import malfunction_from_params
from flatland.envs.rail_env import RailEnv
# from flatland.envs.rail_env import RailEnv
from envs.flatland.utils.gym_env_wrappers import FlatlandRenderWrapper as RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.schedule_generators import sparse_schedule_generator
...
from gym.wrappers import monitor
from ray.rllib import MultiAgentEnv
def _after_step(self, observation, reward, done, info):
    if not self.enabled:
        return done
    if type(done) == dict:
        _done_check = done['__all__']
    else:
        _done_check = done
    if _done_check and self.env_semantics_autoreset:
        # For envs with BlockingReset wrapping VNCEnv, this observation will be the first one of the new episode
        self.reset_video_recorder()
        self.episode_id += 1
        self._flush()

    # Record stats - Disabled as it causes error in multi-agent set up
    # self.stats_recorder.after_step(observation, reward, done, info)
    # Record video
    self.video_recorder.capture_frame()
    return done
class FlatlandBase(MultiAgentEnv):
    reward_range = (-float('inf'), float('inf'))
    spec = None
    metadata = {
        'render.modes': ['human', 'rgb_array'],
        'video.frames_per_second': 10,
        'semantics.autoreset': True
    }

    def step(self, action_dict):
        obs, all_rewards, done, info = self._env.step(action_dict)
        if done['__all__']:
            self.close()
        return obs, all_rewards, done, info

    def reset(self, *args, **kwargs):
        if self._env_config.get('render', None):
            env_name = "flatland"
            monitor.FILE_PREFIX = env_name
            folder = self._env_config.get('video_dir', env_name)
            monitor.Monitor._after_step = _after_step
            self._env = monitor.Monitor(self._env, folder, resume=True)
        return self._env.reset(*args, **kwargs)

    def render(self, mode='human'):
        return self._env.render(self._env_config.get('render'))

    def close(self):
        self._env.close()
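For orientation, a minimal sketch (not part of this commit) of how a concrete env is expected to build on FlatlandBase: the subclass stores self._env_config and constructs self._env as a FlatlandGymEnv, while step/reset/render/close are inherited from the base class. The class name MyFlatlandEnv and the _build_rail_env helper are hypothetical; the FlatlandGymEnv keyword arguments mirror the ones used by the two subclasses below.

from envs.flatland.utils.gym_env import FlatlandGymEnv
from envs.flatland_base import FlatlandBase

class MyFlatlandEnv(FlatlandBase):                    # hypothetical subclass, for illustration only
    def __init__(self, env_config) -> None:
        super().__init__()
        self._env_config = env_config                 # read by FlatlandBase.reset()/render()
        self._observation = ...                       # observation setup elided (see make_obs usage below)
        self._env = FlatlandGymEnv(
            rail_env=self._build_rail_env(),          # hypothetical helper returning a RailEnv
            observation_space=self._observation.observation_space(),
            render=env_config.get('render'),
            regenerate_rail_on_reset=env_config['regenerate_rail_on_reset'],
            regenerate_schedule_on_reset=env_config['regenerate_schedule_on_reset']
        )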
import random
import gym
from ray.rllib import MultiAgentEnv
from envs.flatland.utils.env_generators import random_sparse_env_small
from envs.flatland.observations import make_obs
from envs.flatland.utils.gym_env import FlatlandGymEnv
from envs.flatland.utils.gym_env_wrappers import SkipNoChoiceCellsWrapper, AvailableActionsWrapper
from envs.flatland_base import FlatlandBase
class FlatlandRandomSparseSmall(FlatlandBase):
class FlatlandRandomSparseSmall(MultiAgentEnv):
    def __init__(self, env_config) -> None:
        super().__init__()
        self._env_config = env_config
@@ -27,7 +27,7 @@ class FlatlandRandomSparseSmall(MultiAgentEnv):
        self._env = FlatlandGymEnv(
            rail_env=self._launch(),
            observation_space=self._observation.observation_space(),
            # render=env_config['render'],  # TODO need to fix gl compatibility first
            render=env_config.get('render'),
            regenerate_rail_on_reset=env_config['regenerate_rail_on_reset'],
            regenerate_schedule_on_reset=env_config['regenerate_schedule_on_reset']
        )
@@ -67,9 +67,6 @@ class FlatlandRandomSparseSmall(MultiAgentEnv):
            raise RuntimeError(f"Unable to launch env within {max_tries} tries.")
        return env

    def step(self, action_dict):
        return self._env.step(action_dict)

    def reset(self):
        if self._test or (
                self._env_config['reset_env_freq'] is not None
@@ -78,4 +75,4 @@ class FlatlandRandomSparseSmall(MultiAgentEnv):
        ):
            self._env.env = self._launch()
            self._num_resets += 1
        return self._env.reset(random_seed=self._next_test_seed if self._test else self._generate_random_seed())
        return super().reset(random_seed=self._next_test_seed if self._test else self._generate_random_seed())
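For reference, the config keys read by the code in this diff, gathered into one sketch of an env_config dict. The values are placeholders chosen for illustration; only the key names come from the code above and below.

example_env_config = {
    'render': 'human',                     # a truthy value switches on the Monitor wrapping in FlatlandBase.reset()
    'video_dir': '/tmp/flatland-videos',   # folder handed to gym.wrappers.monitor.Monitor
    'regenerate_rail_on_reset': True,
    'regenerate_schedule_on_reset': True,
    'reset_env_freq': 10,                  # FlatlandRandomSparseSmall relaunches the RailEnv every N resets (None disables this)
}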
@@ -8,47 +8,14 @@ from envs.flatland.utils.gym_env_wrappers import FlatlandRenderWrapper as RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.schedule_generators import sparse_schedule_generator
from ray.rllib import MultiAgentEnv
from envs.flatland import get_generator_config
from envs.flatland.observations import make_obs
from envs.flatland.utils.gym_env import FlatlandGymEnv
from envs.flatland.utils.gym_env_wrappers import AvailableActionsWrapper, SkipNoChoiceCellsWrapper
from gym.wrappers import monitor
from datetime import datetime
import time
import os
def _after_step(self, observation, reward, done, info):
    if not self.enabled:
        return done
    if type(done) == dict:
        _done_check = done['__all__']
    else:
        _done_check = done
    if _done_check and self.env_semantics_autoreset:
        # For envs with BlockingReset wrapping VNCEnv, this observation will be the first one of the new episode
        self.reset_video_recorder()
        self.episode_id += 1
        self._flush()

    # Record stats - Disabled as it causes error in multi-agent set up
    # self.stats_recorder.after_step(observation, reward, done, info)
    # Record video
    self.video_recorder.capture_frame()
    return done
class FlatlandSparse(MultiAgentEnv):
    reward_range = (-float('inf'), float('inf'))
    spec = None
    metadata = {
        'render.modes': ['human', 'rgb_array'],
        'video.frames_per_second': 10,
        'semantics.autoreset': True
    }
from envs.flatland_base import FlatlandBase
class FlatlandSparse(FlatlandBase):
    def __init__(self, env_config) -> None:
        super().__init__()
@@ -68,7 +35,7 @@ class FlatlandSparse(MultiAgentEnv):
        self._env = FlatlandGymEnv(
            rail_env=self._launch(),
            observation_space=self._observation.observation_space(),
            render=env_config.get('render'),  # TODO need to fix gl compatibility first
            render=env_config.get('render'),
            regenerate_rail_on_reset=self._config['regenerate_rail_on_reset'],
            regenerate_schedule_on_reset=self._config['regenerate_schedule_on_reset']
        )
@@ -135,24 +102,3 @@ class FlatlandSparse(MultiAgentEnv):
            logging.error("=" * 50)
        return env

    def step(self, action_dict):
        obs, all_rewards, done, info = self._env.step(action_dict)
        if done['__all__']:
            self.close()
        return obs, all_rewards, done, info

    def reset(self):
        if self._env_config.get('render', None):
            env_name = "flatland"
            monitor.FILE_PREFIX = env_name
            folder = self._env_config.get('video_dir', env_name)
            monitor.Monitor._after_step = _after_step
            self._env = monitor.Monitor(self._env, folder, resume=True)
        return self._env.reset()

    def render(self, mode='human'):
        return self._env.render(self._env_config.get('render'))

    def close(self):
        self._env.close()
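Since FlatlandSparse (like FlatlandRandomSparseSmall) remains an RLlib MultiAgentEnv through FlatlandBase, it can be registered with Ray in the usual way. A minimal sketch; the env name "flatland_sparse" and the import path are assumptions for this example.

from ray.tune.registry import register_env

from envs.flatland_sparse import FlatlandSparse   # import path is an assumption

# Register the env so an RLlib trainer config can refer to it by name.
register_env("flatland_sparse", lambda env_config: FlatlandSparse(env_config))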