From 9fd05a66e9cc29d0821d5f4c21f8de48a7e67887 Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Thu, 18 Jul 2019 16:38:43 -0400
Subject: [PATCH] Adjust sequential agent test setup and OrderedAgent tie-breaking

sequential_agent/run_test.py: load the complex scene from
../torch_training/railway, use a tree observation depth of 1, allow
100 * (height + width) steps per trial, disable image recording and
print the acting agent index whenever an agent finishes.

sequential_agent/simple_order_agent.py: when several directions share
the minimal distance, pick the last one instead of the first.

torch_training/multi_agent_inference.py: remove a stray blank line.

---
 sequential_agent/run_test.py            | 9 +++++----
 sequential_agent/simple_order_agent.py  | 2 +-
 torch_training/multi_agent_inference.py | 1 -
 3 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/sequential_agent/run_test.py b/sequential_agent/run_test.py
index 970d6aa..d0b9ce7 100644
--- a/sequential_agent/run_test.py
+++ b/sequential_agent/run_test.py
@@ -8,11 +8,11 @@ import numpy as np
 
 np.random.seed(2)
 """
-file_name = "./railway/complex_scene.pkl"
+file_name = "../torch_training/railway/complex_scene.pkl"
 env = RailEnv(width=10,
               height=20,
               rail_generator=rail_from_file(file_name),
-              obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()))
+              obs_builder_object=TreeObsForRailEnv(max_depth=1, predictor=ShortestPathPredictorForRailEnv()))
 x_dim = env.width
 y_dim = env.height
 
@@ -38,8 +38,8 @@ observation_helper = TreeObsForRailEnv(max_depth=tree_depth, predictor=ShortestP
 env_renderer = RenderTool(env, gl="PILSVG", )
 handle = env.get_agent_handles()
 n_trials = 1
-max_steps = 3 * (env.height + env.width)
-record_images = True
+max_steps = 100 * (env.height + env.width)
+record_images = False
 agent = OrderedAgent()
 action_dict = dict()
 
@@ -63,6 +63,7 @@ for trials in range(1, n_trials + 1):
         for a in range(env.get_num_agents()):
             if done[a]:
                 acting_agent += 1
+                print(acting_agent)  # log the index of the agent that acts next
             if a == acting_agent:
                 action = agent.act(obs[a], eps=0)
             else:
diff --git a/sequential_agent/simple_order_agent.py b/sequential_agent/simple_order_agent.py
index 6e888c5..3feff35 100644
--- a/sequential_agent/simple_order_agent.py
+++ b/sequential_agent/simple_order_agent.py
@@ -18,7 +18,7 @@ class OrderedAgent:
         min_dist = min_lt(distance, 0)
         min_direction = np.where(distance == min_dist)
         if len(min_direction[0]) > 1:
-            return min_direction[0][0] + 1
+            return min_direction[0][-1] + 1  # on ties, take the last minimum-distance direction
         return min_direction[0] + 1
 
     def step(self, memories):
diff --git a/torch_training/multi_agent_inference.py b/torch_training/multi_agent_inference.py
index e8fd6d4..84a0846 100644
--- a/torch_training/multi_agent_inference.py
+++ b/torch_training/multi_agent_inference.py
@@ -98,7 +98,6 @@ for trials in range(1, n_trials + 1):
         for a in range(env.get_num_agents()):
             action = agent.act(agent_obs[a], eps=0)
             action_dict.update({a: action})
-
         # Environment step
 
         next_obs, all_rewards, done, _ = env.step(action_dict)
-- 
GitLab