Commit 3af1729d authored by hagrid67

fixed lint, commented out some untouched Agent / pytorch refs

parent cb7d732a
@@ -26,9 +26,11 @@ class Player(object):
         self.scores = []
         self.dones_list = []
         self.action_prob = [0]*4
+        # Removing refs to a real agent for now.
         # self.agent = Agent(self.state_size, self.action_size, "FC", 0)
         # self.agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint9900.pth'))
-        #self.agent.qnetwork_local.load_state_dict(torch.load(
+        # self.agent.qnetwork_local.load_state_dict(torch.load(
         #     '../flatland/flatland/baselines/Nets/avoid_checkpoint15000.pth'))
         self.iFrame = 0
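Note on the hunk above: the constructor now keeps the PyTorch/baselines references only as comments. If the real agent were re-enabled, the intended wiring (per those comments) is to build the baselines Agent and restore its local Q-network from a saved checkpoint. A minimal sketch, assuming the baselines Agent class is importable and the checkpoint path from the comments exists (the import path is an assumption, it is not shown in this diff):

    import torch
    from flatland.baselines.dueling_double_dqn import Agent  # assumed import path

    state_size, action_size = 105, 4  # sizes used elsewhere in this file
    agent = Agent(state_size, action_size, "FC", 0)
    agent.qnetwork_local.load_state_dict(
        torch.load('../flatland/baselines/Nets/avoid_checkpoint9900.pth'))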
@@ -56,8 +58,11 @@ class Player(object):
         # Pass the (stored) observation to the agent network and retrieve the action
         for handle in env.get_agent_handles():
+            # Real Agent
             # action = self.agent.act(np.array(self.obs[handle]), eps=self.eps)
+            # Random actions
             action = random.randint(0, 3)
+            # Numpy version uses single random sequence
             # action = np.random.randint(0, 4, size=1)
             self.action_prob[action] += 1
             self.action_dict.update({handle: action})
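The hunk above replaces the trained agent's action choice with a uniform random action over Flatland's 4 discrete actions, keeping the agent call and the NumPy variant as comments. A hedged sketch of the selection logic with both paths in one helper (the function name and the agent=None fallback are illustrative, not part of this file):

    import random
    import numpy as np

    def choose_action(obs, agent=None, eps=0.0):
        # With a real agent, defer to its epsilon-greedy policy (the call taken
        # from the commented-out line above); otherwise pick a random action.
        if agent is not None:
            return agent.act(np.array(obs), eps=eps)
        return random.randint(0, 3)  # one of the 4 discrete Flatland actions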
@@ -65,7 +70,6 @@ class Player(object):
         # Environment step - pass the agent actions to the environment,
         # retrieve the response - observations, rewards, dones
         next_obs, all_rewards, done, _ = self.env.step(self.action_dict)
-        next_obs = next_obs
         for handle in env.get_agent_handles():
             norm = max(1, max_lt(next_obs[handle], np.inf))
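Observations are rescaled before use: divide by the largest finite entry and clip to [-1, 1]. max_lt(seq, val) is the helper defined later in this file; with val=np.inf it returns the largest entry strictly below infinity, so infinite padding values do not blow up the normalisation. A self-contained sketch of the same idea (the helper body here is a plausible reconstruction, not copied from the file):

    import numpy as np

    def max_lt(seq, val):
        # Largest entry of seq strictly below val; 0 if there is none.
        best = 0
        for x in seq:
            if x < val and x > best:
                best = x
        return best

    def normalise_obs(obs):
        norm = max(1, max_lt(obs, np.inf))           # never divide by 0 or inf
        return np.clip(np.array(obs) / norm, -1, 1)  # squash into [-1, 1]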
@@ -117,7 +121,7 @@ def main(render=True, delay=0.0, n_trials=3, n_steps=50, sGL="QT"):
     for trials in range(1, n_trials + 1):
-        # Reset environment8
+        # Reset environment
         oPlayer.reset()
         env_renderer.set_new_rail()
@@ -156,8 +160,6 @@ def main_old(render=True, delay=0.0):
     env_renderer = RenderTool(env, gl="QTSVG")
     # env_renderer = RenderTool(env, gl="QT")
-    state_size = 105
-    action_size = 4
     n_trials = 9999
     eps = 1.
     eps_end = 0.005
@@ -168,7 +170,11 @@ def main_old(render=True, delay=0.0):
     scores = []
     dones_list = []
     action_prob = [0]*4
-    agent = Agent(state_size, action_size, "FC", 0)
+    # Real Agent
+    # state_size = 105
+    # action_size = 4
+    # agent = Agent(state_size, action_size, "FC", 0)
     # agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint9900.pth'))
     def max_lt(seq, val):
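In this hunk the agent construction moves into comments while the epsilon-greedy bookkeeping stays: eps starts at 1.0 and is floored at eps_end = 0.005. In the baseline examples this exploration rate is usually annealed multiplicatively once per episode; a hedged sketch of that schedule (the 0.998 decay factor is an assumption, it is not visible in this diff):

    eps, eps_end, eps_decay = 1.0, 0.005, 0.998  # decay factor assumed
    n_trials = 10                                # small value for illustration

    for trial in range(1, n_trials + 1):
        # ... run one episode, choosing actions epsilon-greedily ...
        eps = max(eps_end, eps_decay * eps)      # anneal exploration per episode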
@@ -188,12 +194,10 @@ def main_old(render=True, delay=0.0):
     tStart = time.time()
     for trials in range(1, n_trials + 1):
-        # Reset environment8
+        # Reset environment
         obs = env.reset()
         env_renderer.set_new_rail()
-        #obs = obs[0]
         for a in range(env.get_num_agents()):
             norm = max(1, max_lt(obs[a], np.inf))
             obs[a] = np.clip(np.array(obs[a]) / norm, -1, 1)
@@ -210,13 +214,12 @@ def main_old(render=True, delay=0.0):
             # print(step)
             # Action
             for a in range(env.get_num_agents()):
-                action = random.randint(0,3)  # agent.act(np.array(obs[a]), eps=eps)
+                action = random.randint(0, 3)  # agent.act(np.array(obs[a]), eps=eps)
                 action_prob[action] += 1
                 action_dict.update({a: action})
             if render:
                 env_renderer.renderEnv(show=True, frames=True, iEpisode=trials, iStep=step, action_dict=action_dict)
-                #time.sleep(10)
                 if delay > 0:
                     time.sleep(delay)
@@ -224,15 +227,16 @@ def main_old(render=True, delay=0.0):
             # Environment step
             next_obs, all_rewards, done, _ = env.step(action_dict)
-            #next_obs = next_obs[0]
             for a in range(env.get_num_agents()):
                 norm = max(1, max_lt(next_obs[a], np.inf))
                 next_obs[a] = np.clip(np.array(next_obs[a]) / norm, -1, 1)
             # Update replay buffer and train agent
-            for a in range(env.get_num_agents()):
-                agent.step(obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a])
-                score += all_rewards[a]
+            # only needed for "real" agent
+            # for a in range(env.get_num_agents()):
+            #     agent.step(obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a])
+            #     score += all_rewards[a]
             obs = next_obs.copy()
             if done['__all__']:
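The block commented out at the end of this hunk is the usual DQN update: after env.step, push each agent's transition (obs, action, reward, next_obs, done) into the learner and accumulate the score. If a real agent is restored, re-enabling it would look like the original lines, shown here uncommented as a fragment rather than a standalone script (env, agent, obs, action_dict, all_rewards, done and score all come from the surrounding function):

    for a in range(env.get_num_agents()):
        agent.step(obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a])
        score += all_rewards[a]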