From a018012d862e7cc7d895237b6c0f80477a38b55d Mon Sep 17 00:00:00 2001
From: Guillaume Mollard <guillaume.mollard2@gmail.com>
Date: Thu, 13 Jun 2019 22:35:02 +0200
Subject: [PATCH] new metric

---
Review notes (this section is not part of the commit):
* NOTE(review): line breaks, blank lines and indentation were reconstructed
  from a whitespace-collapsed copy of this patch; confirm with
  `git apply --check` before applying.
* RailEnvRLLibWrapper.py: the load_env branch now loads a hard-coded absolute
  path under /home/guillaume and comments out config['load_env_path'], so the
  'load_env_path' config entry is no longer read.
* RailEnvRLLibWrapper.py (context line, not changed here):
  `raise(ValueError, f'...')` raises a tuple, which Python 3 rejects with
  TypeError; `raise ValueError(f'...')` is presumably what was intended.
* train_experiment.py: on_episode_step and on_episode_end only bind
  info['episode'] to a local and do nothing else; nothing in this patch
  registers any of the three callbacks.

 RLLib_training/RailEnvRLLibWrapper.py |  7 +++++--
 RLLib_training/train_experiment.py    | 15 +++++++++++++++
 2 files changed, 20 insertions(+), 2 deletions(-)

diff --git a/RLLib_training/RailEnvRLLibWrapper.py b/RLLib_training/RailEnvRLLibWrapper.py
index 68a76e6..1dbdc28 100644
--- a/RLLib_training/RailEnvRLLibWrapper.py
+++ b/RLLib_training/RailEnvRLLibWrapper.py
@@ -6,6 +6,8 @@ from flatland.envs.generators import complex_rail_generator, random_rail_generat
 
 import numpy as np
 
+
+
 class RailEnvRLLibWrapper(MultiAgentEnv):
 
     def __init__(self, config):
@@ -25,6 +27,7 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
             self.rail_generator = random_rail_generator()
         elif config['rail_generator'] == "load_env":
             self.predefined_env = True
+            self.rail_generator = random_rail_generator()
         else:
             raise(ValueError, f'Unknown rail generator: {config["rail_generator"]}')
 
@@ -36,8 +39,8 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
                            prediction_builder_object=config['predictor'])
 
         if self.predefined_env:
-            self.env.load(config['load_env_path'])
-            # '/home/guillaume/EPFL/Master_Thesis/flatland/baselines/torch_training/railway/complex_scene.pkl')
+            #self.env.load(config['load_env_path'])
+            self.env.load('/home/guillaume/EPFL/Master_Thesis/flatland/baselines/torch_training/railway/complex_scene.pkl')
 
         self.width = self.env.width
         self.height = self.env.height
diff --git a/RLLib_training/train_experiment.py b/RLLib_training/train_experiment.py
index 4d3a2a9..90ce648 100644
--- a/RLLib_training/train_experiment.py
+++ b/RLLib_training/train_experiment.py
@@ -48,6 +48,21 @@ ModelCatalog.register_custom_model("conv_model", ConvModelGlobalObs)
 ray.init()#object_store_memory=150000000000, redis_max_memory=30000000000)
 
 
+def on_episode_start(info):
+    episode = info['episode']
+    map_width = info['env'].envs[0].width
+    map_height = info['env'].envs[0].height
+    episode.horizon = map_width + map_height
+
+
+def on_episode_step(info):
+    episode = info['episode']
+
+
+def on_episode_end(info):
+    episode = info['episode']
+
+
 def train(config, reporter):
     print('Init Env')
 
-- 
GitLab