From c8bd78337bce7864e4badfd765377f20f94844e3 Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Wed, 17 Jul 2019 12:28:45 -0400
Subject: [PATCH] minor updates and new training checkpoints

---
 torch_training/multi_agent_inference.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/torch_training/multi_agent_inference.py b/torch_training/multi_agent_inference.py
index 71ca244..4ee79bb 100644
--- a/torch_training/multi_agent_inference.py
+++ b/torch_training/multi_agent_inference.py
@@ -3,7 +3,7 @@ from collections import deque
 
 import numpy as np
 import torch
-from flatland.envs.generators import complex_rail_generator
+from flatland.envs.generators import rail_from_file
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
@@ -16,7 +16,7 @@ from utils.observation_utils import norm_obs_clip, split_tree
 
 random.seed(3)
 np.random.seed(2)
-"""
+
 file_name = "./railway/flatland.pkl"
 env = RailEnv(width=10,
               height=20,
@@ -41,7 +41,7 @@ env = RailEnv(width=x_dim,
               obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()),
               number_of_agents=n_agents)
 env.reset(True, True)
-
+"""
 tree_depth = 3
 observation_helper = TreeObsForRailEnv(max_depth=tree_depth, predictor=ShortestPathPredictorForRailEnv())
 env_renderer = RenderTool(env, gl="PILSVG", )
@@ -81,7 +81,7 @@ for trials in range(1, n_trials + 1):
     # Reset environment
     obs = env.reset(True, True)
 
-    env_renderer.set_new_rail()
+    env_renderer.reset()
 
     for a in range(env.get_num_agents()):
         data, distance, agent_data = split_tree(tree=np.array(obs[a]), num_features_per_node=num_features_per_node,
-- 
GitLab