Skip to content
Snippets Groups Projects
Commit a018012d authored by gmollard's avatar gmollard
Browse files

new metric

parent 2f1e8af1
No related branches found
No related tags found
No related merge requests found
...@@ -6,6 +6,8 @@ from flatland.envs.generators import complex_rail_generator, random_rail_generat ...@@ -6,6 +6,8 @@ from flatland.envs.generators import complex_rail_generator, random_rail_generat
import numpy as np import numpy as np
class RailEnvRLLibWrapper(MultiAgentEnv): class RailEnvRLLibWrapper(MultiAgentEnv):
def __init__(self, config): def __init__(self, config):
...@@ -25,6 +27,7 @@ class RailEnvRLLibWrapper(MultiAgentEnv): ...@@ -25,6 +27,7 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
self.rail_generator = random_rail_generator() self.rail_generator = random_rail_generator()
elif config['rail_generator'] == "load_env": elif config['rail_generator'] == "load_env":
self.predefined_env = True self.predefined_env = True
self.rail_generator = random_rail_generator()
else: else:
raise(ValueError, f'Unknown rail generator: {config["rail_generator"]}') raise(ValueError, f'Unknown rail generator: {config["rail_generator"]}')
...@@ -36,8 +39,8 @@ class RailEnvRLLibWrapper(MultiAgentEnv): ...@@ -36,8 +39,8 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
prediction_builder_object=config['predictor']) prediction_builder_object=config['predictor'])
if self.predefined_env: if self.predefined_env:
self.env.load(config['load_env_path']) #self.env.load(config['load_env_path'])
# '/home/guillaume/EPFL/Master_Thesis/flatland/baselines/torch_training/railway/complex_scene.pkl') self.env.load('/home/guillaume/EPFL/Master_Thesis/flatland/baselines/torch_training/railway/complex_scene.pkl')
self.width = self.env.width self.width = self.env.width
self.height = self.env.height self.height = self.env.height
......
...@@ -48,6 +48,21 @@ ModelCatalog.register_custom_model("conv_model", ConvModelGlobalObs) ...@@ -48,6 +48,21 @@ ModelCatalog.register_custom_model("conv_model", ConvModelGlobalObs)
ray.init()#object_store_memory=150000000000, redis_max_memory=30000000000) ray.init()#object_store_memory=150000000000, redis_max_memory=30000000000)
def on_episode_start(info):
episode = info['episode']
map_width = info['env'].envs[0].width
map_height = info['env'].envs[0].height
episode.horizon = map_width + map_height
def on_episode_step(info):
episode = info['episode']
def on_episode_end(info):
episode = info['episode']
def train(config, reporter): def train(config, reporter):
print('Init Env') print('Init Env')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment