From 6ebb521d03acc0d80860e682100bfd92829f53ff Mon Sep 17 00:00:00 2001 From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch> Date: Tue, 5 Jan 2021 10:34:02 +0100 Subject: [PATCH] Policy updated --- reinforcement_learning/multi_agent_training.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py index 6a34939..872f695 100755 --- a/reinforcement_learning/multi_agent_training.py +++ b/reinforcement_learning/multi_agent_training.py @@ -19,9 +19,9 @@ from flatland.utils.rendertools import RenderTool from torch.utils.tensorboard import SummaryWriter from reinforcement_learning.dddqn_policy import DDDQNPolicy -from reinforcement_learning.ppo_agent import PPOPolicy from reinforcement_learning.deadlockavoidance_with_decision_agent import DeadLockAvoidanceWithDecisionAgent from reinforcement_learning.multi_decision_agent import MultiDecisionAgent +from reinforcement_learning.ppo_agent import PPOPolicy from utils.agent_action_config import get_flatland_full_action_size, get_action_size, map_actions, map_action from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent @@ -189,7 +189,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params): policy = DDDQNPolicy(state_size, get_action_size(), train_params) # Load existing policy - if train_params.load_policy is not "": + if train_params.load_policy is not '': policy.load(train_params.load_policy) # Loads existing replay buffer -- GitLab