Skip to content
Snippets Groups Projects
Commit 2ac37596 authored by Egli Adrian (IT-SCI-API-PFI)
Browse files

Tensorboard support added

parent 7365be1a
No related branches found
No related tags found
No related merge requests found
...@@ -3,6 +3,7 @@ from collections import namedtuple ...@@ -3,6 +3,7 @@ from collections import namedtuple
import gym import gym
import numpy as np import numpy as np
from torch.utils.tensorboard import SummaryWriter
from reinforcement_learning.dddqn_policy import DDDQNPolicy from reinforcement_learning.dddqn_policy import DDDQNPolicy
from reinforcement_learning.ppo_agent import PPOPolicy from reinforcement_learning.ppo_agent import PPOPolicy
...@@ -36,6 +37,9 @@ def cartpole(use_dddqn=False): ...@@ -36,6 +37,9 @@ def cartpole(use_dddqn=False):
episode = 0 episode = 0
checkpoint_interval = 20 checkpoint_interval = 20
scores_window = deque(maxlen=100) scores_window = deque(maxlen=100)
writer = SummaryWriter()
while True: while True:
episode += 1 episode += 1
state = env.reset() state = env.reset()
...@@ -79,6 +83,10 @@ def cartpole(use_dddqn=False): ...@@ -79,6 +83,10 @@ def cartpole(use_dddqn=False):
policy.memory)), policy.memory)),
end=" ") end=" ")
writer.add_scalar("CartPole/value", tot_reward, episode)
writer.add_scalar("CartPole/smoothed_value", np.mean(scores_window), episode)
writer.flush()
if __name__ == "__main__": if __name__ == "__main__":
cartpole() cartpole()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment