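"""Example script: drive a random-action Player through a Flatland RailEnv.

The learning agent has been removed for now, so actions are sampled at random
and the replay-buffer update in Player.step() is kept only as a disabled stub.
"""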
import random
import time
from collections import deque

import numpy as np

from flatland.envs.generators import complex_rail_generator
from flatland.envs.rail_env import RailEnv
from flatland.utils.rendertools import RenderTool


class Player(object):
def __init__(self, env):
self.env = env
self.handle = env.get_agent_handles()
self.state_size = 105
self.action_size = 4
self.n_trials = 9999
self.eps = 1.
self.eps_end = 0.005
self.eps_decay = 0.998
self.action_dict = dict()
self.scores_window = deque(maxlen=100)
self.done_window = deque(maxlen=100)
self.scores = []
self.dones_list = []
self.action_prob = [0] * 4
        # References to a real learning agent have been removed for now.
self.iFrame = 0
self.tStart = time.time()
# Reset environment
self.env.obs_builder.reset()
self.obs = self.env._get_observations()
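        # Normalise each agent's observation: scale by its largest finite
        # entry (at least 1) and clip into [-1, 1].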
for envAgent in range(self.env.get_num_agents()):
norm = max(1, max_lt(self.obs[envAgent], np.inf))
self.obs[envAgent] = np.clip(np.array(self.obs[envAgent]) / norm, -1, 1)
self.score = 0
self.env_done = 0
def reset(self):
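        # Note: unlike __init__, the fresh observations are returned without
        # re-normalisation.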
self.obs = self.env.reset()
return self.obs
def step(self):
env = self.env
# Pass the (stored) observation to the agent network and retrieve the action
for handle in env.get_agent_handles():
            # Random actions, weighted towards action 2; np.random is used so
            # that a single seeded sequence drives all sampling.
            action = np.random.choice([0, 1, 2, 3], 1, p=[0.2, 0.1, 0.6, 0.1])[0]
self.action_prob[action] += 1
self.action_dict.update({handle: action})
# Environment step - pass the agent actions to the environment,
# retrieve the response - observations, rewards, dones
next_obs, all_rewards, done, _ = self.env.step(self.action_dict)
for handle in env.get_agent_handles():
norm = max(1, max_lt(next_obs[handle], np.inf))
next_obs[handle] = np.clip(np.array(next_obs[handle]) / norm, -1, 1)
        # Update replay buffer and train agent.
        # Disabled: self.agent was removed above; re-enable once a real agent
        # is wired back in.
        if False:
for handle in self.env.get_agent_handles():
self.agent.step(self.obs[handle], self.action_dict[handle],
all_rewards[handle], next_obs[handle], done[handle],
train=False)
self.score += all_rewards[handle]
self.iFrame += 1
self.obs = next_obs.copy()
if done['__all__']:
self.env_done = 1

def max_lt(seq, val):
    """
    Return the greatest non-negative item in seq that is strictly less than val.
    None is returned if seq is empty or no item qualifies.
    """
    best = None
    for item in seq:
        if 0 <= item < val and (best is None or item > best):
            best = item
    return best
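
# Example: max_lt([0.4, 9.0, 2.5], 5) returns 2.5, the largest non-negative
# entry below the cutoff; with val=np.inf it picks out the largest finite
# entry, which is how the observation normalisation above uses it.
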
def main(render=True, delay=0.0, n_trials=3, n_steps=50, sGL="PILSVG"):
random.seed(1)
np.random.seed(1)
    # Example: generate a random rail network
env = RailEnv(width=15, height=15,
rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=20, min_dist=12),
number_of_agents=5)
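    # complex_rail_generator builds nr_start_goal start/target connections plus
    # nr_extra additional rail pieces, keeping starts and goals at least
    # min_dist cells apart (parameter meanings per this flatland version).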
if render:
env_renderer = RenderTool(env, gl=sGL)
oPlayer = Player(env)
for trials in range(1, n_trials + 1):
# Reset environment
oPlayer.reset()
        if render:
            env_renderer.set_new_rail()
# Run episode
for step in range(n_steps):
oPlayer.step()
if render:
env_renderer.renderEnv(show=True, frames=True, iEpisode=trials, iStep=step)
if delay > 0:
time.sleep(delay)
    if render:
        env_renderer.gl.close_window()
if __name__ == "__main__":
main(render=True, delay=0)
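# A headless run is also possible, e.g. main(render=False, n_trials=1).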