From 44844ffa3d34c93520cae3632edf2f6126254a1d Mon Sep 17 00:00:00 2001
From: MLErik <baerenjesus@gmail.com>
Date: Mon, 7 Oct 2019 16:27:24 -0400
Subject: [PATCH] added shortest path predictor to multi agent training and
 inference

---
 torch_training/multi_agent_inference.py | 14 +++++++-------
 torch_training/multi_agent_training.py  |  4 +++-
 torch_training/render_agent_behavior.py |  2 --
 torch_training/training_navigation.py   |  4 ++--
 4 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/torch_training/multi_agent_inference.py b/torch_training/multi_agent_inference.py
index f2458c2..1866443 100644
--- a/torch_training/multi_agent_inference.py
+++ b/torch_training/multi_agent_inference.py
@@ -30,7 +30,7 @@ y_dim = env.height
 # Parameters for the Environment
 x_dim = 25
 y_dim = 25
-n_agents = 1
+n_agents = 10
 
 # We are training an Agent using the Tree Observation with depth 2
 observation_builder = TreeObsForRailEnv(max_depth=2)
@@ -43,13 +43,13 @@ stochastic_data = {'prop_malfunction': 0.0,  # Percentage of defective agents
                    }
 
 # Custom observation builder
-TreeObservation = TreeObsForRailEnv(max_depth=2)
+TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
 
 # Different agent types (trains) with different speeds.
-speed_ration_map = {1.: 1.,  # Fast passenger train
-                    1. / 2.: 0.0,  # Fast freight train
-                    1. / 3.: 0.0,  # Slow commuter train
-                    1. / 4.: 0.0}  # Slow freight train
+speed_ration_map = {1.: 0.25,  # Fast passenger train
+                    1. / 2.: 0.25,  # Fast freight train
+                    1. / 3.: 0.25,  # Slow commuter train
+                    1. / 4.: 0.25}  # Slow freight train
 
 env = RailEnv(width=x_dim,
               height=y_dim,
@@ -93,7 +93,7 @@ action_prob = [0] * action_size
 agent_obs = [None] * env.get_num_agents()
 agent_next_obs = [None] * env.get_num_agents()
 agent = Agent(state_size, action_size)
-with path(torch_training.Nets, "avoider_checkpoint1000.pth") as file_in:
+with path(torch_training.Nets, "avoider_checkpoint100.pth") as file_in:
     agent.qnetwork_local.load_state_dict(torch.load(file_in))
 
 record_images = False
diff --git a/torch_training/multi_agent_training.py b/torch_training/multi_agent_training.py
index ed20ea6..0dacc9c 100644
--- a/torch_training/multi_agent_training.py
+++ b/torch_training/multi_agent_training.py
@@ -14,6 +14,8 @@ import torch
 from torch_training.dueling_double_dqn import Agent
 
 from flatland.envs.observations import TreeObsForRailEnv
+from flatland.envs.predictions import ShortestPathPredictorForRailEnv
+
 from flatland.envs.rail_env import RailEnv
 from flatland.envs.rail_generators import sparse_rail_generator
 from flatland.envs.schedule_generators import sparse_schedule_generator
@@ -48,7 +50,7 @@ def main(argv):
                        }
 
     # Custom observation builder
-    TreeObservation = TreeObsForRailEnv(max_depth=2)
+    TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
 
     # Different agent types (trains) with different speeds.
     speed_ration_map = {1.: 0.25,  # Fast passenger train
diff --git a/torch_training/render_agent_behavior.py b/torch_training/render_agent_behavior.py
index d599bcf..62cfa49 100644
--- a/torch_training/render_agent_behavior.py
+++ b/torch_training/render_agent_behavior.py
@@ -4,7 +4,6 @@ from collections import deque
 import numpy as np
 import torch
 from flatland.envs.observations import TreeObsForRailEnv
-from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.envs.rail_generators import sparse_rail_generator
 from flatland.envs.schedule_generators import sparse_schedule_generator
@@ -67,7 +66,6 @@ env = RailEnv(width=x_dim,
               obs_builder_object=TreeObservation)
 env.reset(True, True)
 
-observation_helper = TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv())
 env_renderer = RenderTool(env, gl="PILSVG", )
 num_features_per_node = env.obs_builder.observation_dim
 
diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py
index 8e46796..b0942ee 100644
--- a/torch_training/training_navigation.py
+++ b/torch_training/training_navigation.py
@@ -13,8 +13,8 @@ import numpy as np
 import torch
 from torch_training.dueling_double_dqn import Agent
 
-from flatland.envs.observations import TreeObsForRailEnv
-from flatland.envs.rail_env import RailEnv
+flatland.envs.rail_env
+import RailEnv
 from flatland.envs.rail_generators import sparse_rail_generator
 from flatland.envs.schedule_generators import sparse_schedule_generator
 from flatland.utils.rendertools import RenderTool
-- 
GitLab