From df0b0ef116980c363e4f442eed7c376efbd8af96 Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Sun, 6 Oct 2019 11:20:57 -0400
Subject: [PATCH] updated training and rendering scripts

---
 torch_training/render_agent_behavior.py |  8 ++++----
 torch_training/training_navigation.py   | 25 +++++++++++--------------
 2 files changed, 15 insertions(+), 18 deletions(-)

diff --git a/torch_training/render_agent_behavior.py b/torch_training/render_agent_behavior.py
index 6802a22..93f9f12 100644
--- a/torch_training/render_agent_behavior.py
+++ b/torch_training/render_agent_behavior.py
@@ -28,8 +28,8 @@ y_dim = env.height
 """
 
 # Parameters for the Environment
-x_dim = 20
-y_dim = 20
+x_dim = 25
+y_dim = 25
 n_agents = 1
 n_goals = 5
 min_dist = 5
@@ -48,9 +48,9 @@ stochastic_data = {'prop_malfunction': 0.0,  # Percentage of defective agents
 TreeObservation = TreeObsForRailEnv(max_depth=2)
 
 # Different agent types (trains) with different speeds.
-speed_ration_map = {1.: 1.,  # Fast passenger train
+speed_ration_map = {1.: 0.,  # Fast passenger train
                     1. / 2.: 0.0,  # Fast freight train
-                    1. / 3.: 0.0,  # Slow commuter train
+                    1. / 3.: 1.0,  # Slow commuter train
                     1. / 4.: 0.0}  # Slow freight train
 
 env = RailEnv(width=x_dim,
diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py
index 3220b94..229c804 100644
--- a/torch_training/training_navigation.py
+++ b/torch_training/training_navigation.py
@@ -35,8 +35,8 @@ def main(argv):
     np.random.seed(1)
 
     # Parameters for the Environment
-    x_dim = 20
-    y_dim = 20
+    x_dim = 30
+    y_dim = 30
     n_agents = 1
 
 
@@ -63,7 +63,7 @@ def main(argv):
                                                        seed=1,  # Random seed
                                                        grid_mode=False,
                                                        max_rails_between_cities=2,
-                                                       max_rails_in_city=2),
+                                                       max_rails_in_city=3),
                   schedule_generator=sparse_schedule_generator(speed_ration_map),
                   number_of_agents=n_agents,
                   stochastic_data=stochastic_data,  # Malfunction data generator
@@ -105,7 +105,8 @@ def main(argv):
     agent_obs = [None] * env.get_num_agents()
     agent_next_obs = [None] * env.get_num_agents()
     agent_obs_buffer = [None] * env.get_num_agents()
-    agent_action_buffer = [None] * env.get_num_agents()
+    agent_action_buffer = [2] * env.get_num_agents()  # default action 2 (move forward)
+    agent_done_buffer = [False] * env.get_num_agents()  # done flags recorded with the buffered observations
     cummulated_reward = np.zeros(env.get_num_agents())
 
     # Now we load a Double dueling DQN agent
@@ -115,7 +116,7 @@ def main(argv):
 
         # Reset environment
         obs, info = env.reset(True, True)
-
+        env_renderer.reset()
         # Build agent specific observations
         for a in range(env.get_num_agents()):
             agent_obs[a] = normalize_observation(obs[a], tree_depth, observation_radius=10)
@@ -132,36 +133,32 @@ def main(argv):
                 if info['action_required'][a]:
                     action = agent.act(agent_obs[a], eps=eps)
                     action_prob[action] += 1
-                    if step == 0:
-                        agent_action_buffer[a] = action
                 else:
                     action = 0
                 action_dict.update({a: action})
 
             # Environment step
             next_obs, all_rewards, done, info = env.step(action_dict)
-
             # Build agent specific observations and normalize
             for a in range(env.get_num_agents()):
                 # Penalize waiting in order to get agent to move
                 if env.agents[a].status == 0:
                     all_rewards[a] -= 1
-
-                agent_next_obs[a] = normalize_observation(next_obs[a], tree_depth, observation_radius=10)
+                if info['action_required'][a]:
+                    agent_next_obs[a] = normalize_observation(next_obs[a], tree_depth, observation_radius=10)
                 cummulated_reward[a] += all_rewards[a]
 
             # Update replay buffer and train agent
             for a in range(env.get_num_agents()):
-                if (agent_obs_buffer[a] is not None and info['action_required'][a] and env.agents[a].status != 3) or \
-                        env.agents[a].status == 2:
+                if (info['action_required'][a] and env.agents[a].status != 3) or env.agents[a].status == 2:  # status 2: done, 3: done and removed
 
-                    agent_delayed_next = agent_obs[a].copy()
                     agent.step(agent_obs_buffer[a], agent_action_buffer[a], all_rewards[a],
-                               agent_delayed_next, done[a])
+                               agent_obs[a], agent_done_buffer[a])
                     cummulated_reward[a] = 0.
                 if info['action_required'][a]:
                     agent_obs_buffer[a] = agent_obs[a].copy()
                     agent_action_buffer[a] = action_dict[a]
+                    agent_done_buffer[a] = done[a]
 
                 score += all_rewards[a] / env.get_num_agents()
 
-- 
GitLab