Commit 43f619fb authored by MasterScrat's avatar MasterScrat

Implementing generator config registry, WIP

parent 0b50abd8
width: 32
height: 32
number_of_agents: 10
max_num_cities: 3
grid_mode: False
max_rails_between_cities: 2
max_rails_in_city: 3
malfunction_rate: 8000
malfunction_min_duration: 15
malfunction_max_duration: 50
# speed ratio map: keys must be strings but will be converted to float later!
speed_ratio_map:
'1.0': 0.25
'0.5': 0.25
'0.3333333': 0.25
'0.25': 0.25
seed: 0
\ No newline at end of file
......@@ -3,6 +3,7 @@ import os
from abc import ABC, abstractmethod
import gym
import humps
from flatland.core.env_observation_builder import ObservationBuilder
......@@ -42,4 +43,11 @@ def make_obs(name: str, config, *args, **kwargs) -> Observation:
# automatically import any Python files in the obs/ directory
for file in os.listdir(os.path.dirname(__file__)):
if file.endswith('.py') and not file.startswith('_'):
module = importlib.import_module(f'.{file[:-3]}', __name__)
basename = os.path.basename(file)
filename = basename.replace(".py", "")
class_name = humps.pascalize(filename)
module = importlib.import_module(f'.{file[:-3]}', package=__name__)
print("- Successfully Loaded Observation class {} from {}".format(
class_name, os.path.basename(basename)
))
......@@ -3,9 +3,9 @@ import random
import gym
from ray.rllib import MultiAgentEnv
from envs.flatland.env_generators import random_sparse_env_small
from envs.flatland.utils.env_generators import random_sparse_env_small
from envs.flatland.observations import make_obs
from envs.flatland.rllib_wrapper import FlatlandRllibWrapper
from envs.flatland.utils.rllib_wrapper import FlatlandRllibWrapper
class FlatlandRandomSparseSmall(MultiAgentEnv):
......
import os
import gym
import yaml
from flatland.envs.malfunction_generators import malfunction_from_params
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
......@@ -6,17 +9,34 @@ from flatland.envs.schedule_generators import sparse_schedule_generator
from ray.rllib import MultiAgentEnv
from envs.flatland.observations import make_obs
from envs.flatland.rllib_wrapper import FlatlandRllibWrapper
from envs.flatland.utils.rllib_wrapper import FlatlandRllibWrapper
class FlatlandSparse(MultiAgentEnv):
def __init__(self, env_config) -> None:
super().__init__()
self._config = env_config
self._observation = make_obs(env_config['observation'], env_config.get('observation_config'))
self._env = FlatlandRllibWrapper(rail_env=self._launch(), render=env_config['render'],
regenerate_rail_on_reset=env_config['regenerate_rail_on_reset'],
regenerate_schedule_on_reset=env_config['regenerate_schedule_on_reset'])
self._config = env_config
print("#" * 50)
print(self._config)
print("#" * 50)
print(os.getcwd())
loaded_config = {}
with open(env_config['generator_config']) as f:
loaded_config = yaml.safe_load(f)
print(loaded_config)
self._config = loaded_config
print("#" * 50)
self._env = FlatlandRllibWrapper(
rail_env=self._launch(),
render=env_config['render'],
regenerate_rail_on_reset=env_config['regenerate_rail_on_reset'],
regenerate_schedule_on_reset=env_config['regenerate_schedule_on_reset']
)
@property
def observation_space(self) -> gym.spaces.Space:
......@@ -27,14 +47,19 @@ class FlatlandSparse(MultiAgentEnv):
return self._env.action_space
def _launch(self):
rail_generator = sparse_rail_generator(seed=self._config['seed'], max_num_cities=self._config['max_num_cities'],
grid_mode=self._config['grid_mode'],
max_rails_between_cities=self._config['max_rails_between_cities'],
max_rails_in_city=self._config['max_rails_in_city'])
stochastic_data = {'malfunction_rate': self._config['malfunction_rate'],
'min_duration': self._config['malfunction_min_duration'],
'max_duration': self._config['malfunction_max_duration']}
rail_generator = sparse_rail_generator(
seed=self._config['seed'],
max_num_cities=self._config['max_num_cities'],
grid_mode=self._config['grid_mode'],
max_rails_between_cities=self._config['max_rails_between_cities'],
max_rails_in_city=self._config['max_rails_in_city']
)
stochastic_data = {
'malfunction_rate': self._config['malfunction_rate'],
'min_duration': self._config['malfunction_min_duration'],
'max_duration': self._config['malfunction_max_duration']
}
schedule_generator = sparse_schedule_generator({float(k): float(v)
for k, v in self._config['speed_ratio_map'].items()})
......
#!/usr/bin/env python
import os
from pathlib import Path
import ray
import yaml
from pathlib import Path
from ray.cluster_utils import Cluster
from ray.rllib.evaluation import MultiAgentEpisode
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.tune import tune, run_experiments
from ray.tune import run_experiments
from ray.tune.logger import TBXLogger
from ray.tune.resources import resources_to_json
from ray.tune.tune import _make_scheduler
from argparser import create_parser
from utils.loader import load_envs, load_models
# Custom wandb logger with hotfix to allow custom callbacks
from wandblogger import WandbLogger
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment