From 2ac3759688a9928619d5f10410aafcd214042f46 Mon Sep 17 00:00:00 2001
From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch>
Date: Wed, 6 Jan 2021 16:05:45 +0100
Subject: [PATCH] Tensorboard support added

---
 reinforcement_learning/rl_agent_test.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/reinforcement_learning/rl_agent_test.py b/reinforcement_learning/rl_agent_test.py
index 5295971..8687563 100644
--- a/reinforcement_learning/rl_agent_test.py
+++ b/reinforcement_learning/rl_agent_test.py
@@ -3,6 +3,7 @@ from collections import namedtuple
 
 import gym
 import numpy as np
+from torch.utils.tensorboard import SummaryWriter
 
 from reinforcement_learning.dddqn_policy import DDDQNPolicy
 from reinforcement_learning.ppo_agent import PPOPolicy
@@ -36,6 +37,9 @@ def cartpole(use_dddqn=False):
     episode = 0
     checkpoint_interval = 20
     scores_window = deque(maxlen=100)
+
+    writer = SummaryWriter()
+
     while True:
         episode += 1
         state = env.reset()
@@ -79,6 +83,10 @@ def cartpole(use_dddqn=False):
                                                                                                                   policy.memory)),
                   end=" ")
 
+        writer.add_scalar("CartPole/value", tot_reward, episode)
+        writer.add_scalar("CartPole/smoothed_value", np.mean(scores_window), episode)
+        writer.flush()
+
 
 if __name__ == "__main__":
     cartpole()
-- 
GitLab