diff --git a/RLLib_training/RailEnvRLLibWrapper.py b/RLLib_training/RailEnvRLLibWrapper.py index 68a76e6102a8fd5897379e00f07da3c26a774cc3..1dbdc28727c744b65c9e8591e34cb14ef7e42199 100644 --- a/RLLib_training/RailEnvRLLibWrapper.py +++ b/RLLib_training/RailEnvRLLibWrapper.py @@ -6,6 +6,8 @@ from flatland.envs.generators import complex_rail_generator, random_rail_generat import numpy as np + + class RailEnvRLLibWrapper(MultiAgentEnv): def __init__(self, config): @@ -25,6 +27,7 @@ class RailEnvRLLibWrapper(MultiAgentEnv): self.rail_generator = random_rail_generator() elif config['rail_generator'] == "load_env": self.predefined_env = True + self.rail_generator = random_rail_generator() else: raise(ValueError, f'Unknown rail generator: {config["rail_generator"]}') @@ -36,8 +39,8 @@ class RailEnvRLLibWrapper(MultiAgentEnv): prediction_builder_object=config['predictor']) if self.predefined_env: - self.env.load(config['load_env_path']) - # '/home/guillaume/EPFL/Master_Thesis/flatland/baselines/torch_training/railway/complex_scene.pkl') + #self.env.load(config['load_env_path']) + self.env.load('/home/guillaume/EPFL/Master_Thesis/flatland/baselines/torch_training/railway/complex_scene.pkl') self.width = self.env.width self.height = self.env.height diff --git a/RLLib_training/train_experiment.py b/RLLib_training/train_experiment.py index 4d3a2a92a33d888f7268f9686277b3f108149c7c..90ce6484513a4b3e0ea34a80cf3ccdca7c99b32c 100644 --- a/RLLib_training/train_experiment.py +++ b/RLLib_training/train_experiment.py @@ -48,6 +48,21 @@ ModelCatalog.register_custom_model("conv_model", ConvModelGlobalObs) ray.init()#object_store_memory=150000000000, redis_max_memory=30000000000) +def on_episode_start(info): + episode = info['episode'] + map_width = info['env'].envs[0].width + map_height = info['env'].envs[0].height + episode.horizon = map_width + map_height + + +def on_episode_step(info): + episode = info['episode'] + + +def on_episode_end(info): + episode = info['episode'] + + def train(config, reporter): print('Init Env')