From a018012d862e7cc7d895237b6c0f80477a38b55d Mon Sep 17 00:00:00 2001
From: Guillaume Mollard <guillaume.mollard2@gmail.com>
Date: Thu, 13 Jun 2019 22:35:02 +0200
Subject: [PATCH] new metric

---
 RLLib_training/RailEnvRLLibWrapper.py |  7 +++++--
 RLLib_training/train_experiment.py    | 15 +++++++++++++++
 2 files changed, 20 insertions(+), 2 deletions(-)

diff --git a/RLLib_training/RailEnvRLLibWrapper.py b/RLLib_training/RailEnvRLLibWrapper.py
index 68a76e6..1dbdc28 100644
--- a/RLLib_training/RailEnvRLLibWrapper.py
+++ b/RLLib_training/RailEnvRLLibWrapper.py
@@ -6,6 +6,8 @@ from flatland.envs.generators import complex_rail_generator, random_rail_generat
 import numpy as np
 
 
+
+
 class RailEnvRLLibWrapper(MultiAgentEnv):
 
     def __init__(self, config):
@@ -25,6 +27,7 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
             self.rail_generator = random_rail_generator()
         elif config['rail_generator'] == "load_env":
             self.predefined_env = True
+            self.rail_generator = random_rail_generator()
 
         else:
             raise(ValueError, f'Unknown rail generator: {config["rail_generator"]}')
@@ -36,8 +39,8 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
                 prediction_builder_object=config['predictor'])
 
         if self.predefined_env:
-            self.env.load(config['load_env_path'])
-                # '/home/guillaume/EPFL/Master_Thesis/flatland/baselines/torch_training/railway/complex_scene.pkl')
+            #self.env.load(config['load_env_path'])
+            self.env.load('/home/guillaume/EPFL/Master_Thesis/flatland/baselines/torch_training/railway/complex_scene.pkl')
 
         self.width = self.env.width
         self.height = self.env.height
diff --git a/RLLib_training/train_experiment.py b/RLLib_training/train_experiment.py
index 4d3a2a9..90ce648 100644
--- a/RLLib_training/train_experiment.py
+++ b/RLLib_training/train_experiment.py
@@ -48,6 +48,21 @@ ModelCatalog.register_custom_model("conv_model", ConvModelGlobalObs)
 ray.init()#object_store_memory=150000000000, redis_max_memory=30000000000)
 
 
+def on_episode_start(info):
+    episode = info['episode']
+    map_width = info['env'].envs[0].width
+    map_height = info['env'].envs[0].height
+    episode.horizon = map_width + map_height
+    
+
+def on_episode_step(info):
+    episode = info['episode']
+
+
+def on_episode_end(info):
+    episode = info['episode']
+
+
 def train(config, reporter):
     print('Init Env')
 
-- 
GitLab