Commit 551cd808 authored by MasterScrat's avatar MasterScrat

Generator config registry

parent 43f619fb
import importlib
import os
import humps
import yaml
GENERATOR_CONFIG_REGISTRY = {}
config_folder = os.path.join(os.path.dirname(__file__), "generator_configs")
for file in os.listdir(config_folder):
if file.endswith('.yaml') and not file.startswith('_'):
basename = os.path.basename(file)
filename = basename.replace(".yaml", "")
with open(os.path.join(config_folder, file)) as f:
GENERATOR_CONFIG_REGISTRY[filename] = yaml.safe_load(f)
print("- Successfully Loaded Generator Config {} from {}".format(
filename, basename
))
......@@ -49,5 +49,5 @@ for file in os.listdir(os.path.dirname(__file__)):
module = importlib.import_module(f'.{file[:-3]}', package=__name__)
print("- Successfully Loaded Observation class {} from {}".format(
class_name, os.path.basename(basename)
class_name, basename
))
......@@ -15,22 +15,9 @@ from envs.flatland.utils.rllib_wrapper import FlatlandRllibWrapper
class FlatlandSparse(MultiAgentEnv):
def __init__(self, env_config) -> None:
super().__init__()
self._observation = make_obs(env_config['observation'], env_config.get('observation_config'))
self._config = env_config
print("#" * 50)
print(self._config)
print("#" * 50)
print(os.getcwd())
loaded_config = {}
with open(env_config['generator_config']) as f:
loaded_config = yaml.safe_load(f)
print(loaded_config)
self._config = loaded_config
print("#" * 50)
self._config = yaml.safe_load(f)
self._env = FlatlandRllibWrapper(
rail_env=self._launch(),
render=env_config['render'],
......
......@@ -90,7 +90,7 @@ def load_envs(local_dir="."):
))
# Finally Register Env in Tune
registry.register_env(env_name, lambda config: env(config))
print("- Successfully Loaded class {} from {}".format(
print("- Successfully Loaded Environment class {} from {}".format(
class_name, os.path.basename(_file_path)
))
......@@ -126,6 +126,6 @@ def load_models(local_dir="."):
))
# Finally Register Model in ModelCatalog
ModelCatalog.register_custom_model(model_name, CustomModel)
print("- Successfully Loaded custom Model class {} from {}".format(
print("- Successfully Loaded Model class {} from {}".format(
class_name, os.path.basename(_file_path)
))
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment