diff --git a/examples/play_model.py b/examples/play_model.py index 5a33c12610e158fb6b35e8ec84405907d4e83b7b..6e3335ed41f3af61c0e26e5b5c391d84cf0f6550 100644 --- a/examples/play_model.py +++ b/examples/play_model.py @@ -1,6 +1,7 @@ from flatland.envs.rail_env import RailEnv, random_rail_generator # from flatland.core.env_observation_builder import TreeObsForRailEnv from flatland.utils.rendertools import RenderTool +from flatland.utils.render_qt import QtRailRender from flatland.baselines.dueling_double_dqn import Agent from collections import deque import torch @@ -31,7 +32,8 @@ def main(): height=7, rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability), number_of_agents=1) - env_renderer = RenderTool(env) + # env_renderer = RenderTool(env) + env_renderer = QtRailRender(env) plt.figure(figsize=(5,5)) # fRedis = redis.Redis() @@ -101,7 +103,6 @@ def main(): score += all_rewards[a] env_renderer.renderEnv(show=True, frames=True, iEpisode=trials, iStep=step) - sEnv = fR.set("RailEnv0") obs = next_obs.copy() if done['__all__']: