From 2ac3759688a9928619d5f10410aafcd214042f46 Mon Sep 17 00:00:00 2001 From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch> Date: Wed, 6 Jan 2021 16:05:45 +0100 Subject: [PATCH] Tensorboard support added --- reinforcement_learning/rl_agent_test.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/reinforcement_learning/rl_agent_test.py b/reinforcement_learning/rl_agent_test.py index 5295971..8687563 100644 --- a/reinforcement_learning/rl_agent_test.py +++ b/reinforcement_learning/rl_agent_test.py @@ -3,6 +3,7 @@ from collections import namedtuple import gym import numpy as np +from torch.utils.tensorboard import SummaryWriter from reinforcement_learning.dddqn_policy import DDDQNPolicy from reinforcement_learning.ppo_agent import PPOPolicy @@ -36,6 +37,9 @@ def cartpole(use_dddqn=False): episode = 0 checkpoint_interval = 20 scores_window = deque(maxlen=100) + + writer = SummaryWriter() + while True: episode += 1 state = env.reset() @@ -79,6 +83,10 @@ def cartpole(use_dddqn=False): policy.memory)), end=" ") + writer.add_scalar("CartPole/value", tot_reward, episode) + writer.add_scalar("CartPole/smoothed_value", np.mean(scores_window), episode) + writer.flush() + if __name__ == "__main__": cartpole() -- GitLab