From a7ac4c7f20b4aaade5490af028875ae4bbb6a2ac Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Sun, 1 Sep 2019 18:41:55 -0400
Subject: [PATCH] updated handling of end of episode

---
 torch_training/training_navigation.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py
index 8ec2b08..4f99b48 100644
--- a/torch_training/training_navigation.py
+++ b/torch_training/training_navigation.py
@@ -47,8 +47,8 @@ def main(argv):
     TreeObservation = TreeObsForRailEnv(max_depth=2)
 
     # Different agent types (trains) with different speeds.
-    speed_ration_map = {1.: 1.,  # Fast passenger train
-                        1. / 2.: 0.0,  # Fast freight train
+    speed_ration_map = {1.: 0.,  # Fast passenger train
+                        1. / 2.: 1.0,  # Fast freight train
                         1. / 3.: 0.0,  # Slow commuter train
                         1. / 4.: 0.0}  # Slow freight train
 
@@ -153,9 +153,9 @@ def main(argv):
             # Update replay buffer and train agent
             for a in range(env.get_num_agents()):
                 if done[a]:
-                    final_obs[a] = agent_obs[a].copy()
+                    final_obs[a] = agent_obs_buffer[a]
                     final_obs_next[a] = agent_next_obs[a].copy()
-                    final_action_dict.update({a: action_dict[a]})
+                    final_action_dict.update({a: agent_action_buffer[a]})
                 if not done[a]:
                     if agent_obs_buffer[a] is not None and register_action_state[a]:
                         agent_delayed_next = agent_obs[a].copy()
-- 
GitLab