diff --git a/reinforcement_learning/rl_agent_test.py b/reinforcement_learning/rl_agent_test.py index 529597171cbe6898aa70f6e163df458dbb4257fc..8687563f1886f549b415abbf8a3ff3b60e8bdb12 100644 --- a/reinforcement_learning/rl_agent_test.py +++ b/reinforcement_learning/rl_agent_test.py @@ -3,6 +3,7 @@ from collections import namedtuple import gym import numpy as np +from torch.utils.tensorboard import SummaryWriter from reinforcement_learning.dddqn_policy import DDDQNPolicy from reinforcement_learning.ppo_agent import PPOPolicy @@ -36,6 +37,9 @@ def cartpole(use_dddqn=False): episode = 0 checkpoint_interval = 20 scores_window = deque(maxlen=100) + + writer = SummaryWriter() + while True: episode += 1 state = env.reset() @@ -79,6 +83,10 @@ def cartpole(use_dddqn=False): policy.memory)), end=" ") + writer.add_scalar("CartPole/value", tot_reward, episode) + writer.add_scalar("CartPole/smoothed_value", np.mean(scores_window), episode) + writer.flush() + if __name__ == "__main__": cartpole()