From a7ac4c7f20b4aaade5490af028875ae4bbb6a2ac Mon Sep 17 00:00:00 2001 From: Erik Nygren <erik.nygren@sbb.ch> Date: Sun, 1 Sep 2019 18:41:55 -0400 Subject: [PATCH] updated handling of end of episode --- torch_training/training_navigation.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py index 8ec2b08..4f99b48 100644 --- a/torch_training/training_navigation.py +++ b/torch_training/training_navigation.py @@ -47,8 +47,8 @@ def main(argv): TreeObservation = TreeObsForRailEnv(max_depth=2) # Different agent types (trains) with different speeds. - speed_ration_map = {1.: 1., # Fast passenger train - 1. / 2.: 0.0, # Fast freight train + speed_ration_map = {1.: 0., # Fast passenger train + 1. / 2.: 1.0, # Fast freight train 1. / 3.: 0.0, # Slow commuter train 1. / 4.: 0.0} # Slow freight train @@ -153,9 +153,9 @@ def main(argv): # Update replay buffer and train agent for a in range(env.get_num_agents()): if done[a]: - final_obs[a] = agent_obs[a].copy() + final_obs[a] = agent_obs_buffer[a] final_obs_next[a] = agent_next_obs[a].copy() - final_action_dict.update({a: action_dict[a]}) + final_action_dict.update({a: agent_action_buffer[a]}) if not done[a]: if agent_obs_buffer[a] is not None and register_action_state[a]: agent_delayed_next = agent_obs[a].copy() -- GitLab