diff --git a/examples/play_model.py b/examples/play_model.py
index 713d831e4936b0803f7d69f4eae8253db0685ef3..5a33c12610e158fb6b35e8ec84405907d4e83b7b 100644
--- a/examples/play_model.py
+++ b/examples/play_model.py
@@ -7,6 +7,7 @@ import torch
 import random
 import numpy as np
 import matplotlib.pyplot as plt
+import redis
 
 
 def main():
@@ -32,6 +33,8 @@ def main():
                   number_of_agents=1)
     env_renderer = RenderTool(env)
     plt.figure(figsize=(5,5))
+    fRedis = redis.Redis()  # assumes a Redis server reachable on localhost:6379
+
     handle = env.get_agent_handles()
 
     state_size = 105
@@ -98,6 +101,7 @@ def main():
                 score += all_rewards[a]
 
             env_renderer.renderEnv(show=True, frames=True, iEpisode=trials, iStep=step)
+            fRedis.set("RailEnv0", step)  # set() requires key and value; step is a placeholder value
             obs = next_obs.copy()
 
             if done['__all__']:
@@ -115,10 +119,8 @@ def main():
               '\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format(
                 env.number_of_agents,
                 trials,
-                np.mean(
-                    scores_window),
-                100 * np.mean(
-                    done_window),
+                np.mean(scores_window),
+                100 * np.mean(done_window),
                 eps,
                 action_prob/np.sum(action_prob)), end=" ")
         if trials % 100 == 0:
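
Note: the added lines assume a Redis server running on redis-py's defaults (localhost:6379, db 0). A minimal round-trip sketch of the calls used above; the key name "RailEnv0" and the stored value are illustrative, since the original patch did not say what should be stored:

    import redis

    fRedis = redis.Redis()             # connects to localhost:6379, db 0 by default
    fRedis.set("RailEnv0", 42)         # set() takes both a key and a value
    print(fRedis.get("RailEnv0"))      # -> b'42'; values come back as bytes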