Commit 5e8ec90c authored by maljx

merge commit

parents ca044f83 8e79b68f
@@ -30,9 +30,9 @@ env = RailEnv(width=10,
              rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
              number_of_agents=1)
"""
-env = RailEnv(width=50,
-              height=50,
-              rail_generator=complex_rail_generator(nr_start_goal=50, min_dist=5, max_dist=99999, seed=0),
+env = RailEnv(width=15,
+              height=15,
+              rail_generator=complex_rail_generator(nr_start_goal=5, min_dist=5, max_dist=99999, seed=0),
               number_of_agents=5)
"""
env = RailEnv(width=20,
@@ -45,7 +45,7 @@ env = RailEnv(width=20,
env_renderer = RenderTool(env, gl="QT")
handle = env.get_agent_handles()
-state_size = 105
+state_size = 105 * 2
action_size = 4
n_trials = 15000
eps = 1.
@@ -55,13 +55,16 @@ action_dict = dict()
final_action_dict = dict()
scores_window = deque(maxlen=100)
done_window = deque(maxlen=100)
+time_obs = deque(maxlen=2)
scores = []
dones_list = []
action_prob = [0] * 4
+agent_obs = [None] * env.get_num_agents()
+agent_next_obs = [None] * env.get_num_agents()
agent = Agent(state_size, action_size, "FC", 0)
-agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint1500.pth'))
+# agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint15000.pth'))
-demo = True
+demo = False
def max_lt(seq, val):
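
The heart of this change is temporal stacking: time_obs keeps the two most recent observations, and each agent's network input is the concatenation of both frames, which is why state_size doubles from 105 to 105 * 2 above. A minimal sketch of the mechanism, assuming per-agent observations are flat arrays of length 105 keyed by agent handle:

from collections import deque
import numpy as np

time_obs = deque(maxlen=2)       # appending a third frame evicts the oldest

obs_prev = {0: np.zeros(105)}    # hypothetical observation dicts, one key per agent
obs_curr = {0: np.ones(105)}
time_obs.append(obs_prev)
time_obs.append(obs_curr)

# Network input for agent 0: previous and current frame side by side.
stacked = np.concatenate((time_obs[0][0], time_obs[1][0]))
assert stacked.shape == (210,)   # matches state_size = 105 * 2

At reset the script pushes the same observation twice (the for i in range(2) loop further down), so the stacked input is already well-defined before the first step.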
@@ -103,11 +106,11 @@ def norm_obs_clip(obs, clip_min=-1, clip_max=1):
    max_obs = max(1, max_lt(obs, 1000))
    min_obs = max(0, min_lt(obs, 0))
    if max_obs == min_obs:
-        return np.clip(np.array(obs)/ max_obs, clip_min, clip_max)
+        return np.clip(np.array(obs) / max_obs, clip_min, clip_max)
    norm = np.abs(max_obs - min_obs)
    if norm == 0:
        norm = 1.
-    return np.clip((np.array(obs)-min_obs)/ norm, clip_min, clip_max)
+    return np.clip((np.array(obs) - min_obs) / norm, clip_min, clip_max)
for trials in range(1, n_trials + 1):
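
Only whitespace changes in norm_obs_clip, but since it carries all of the observation scaling it is worth seeing end to end. A runnable sketch with stand-ins for max_lt and min_lt, whose real definitions sit in an elided part of this file; here they are assumed to return the largest finite entry below a cutoff and the smallest non-negative entry, respectively:

import numpy as np

def max_lt(seq, val):
    # Assumed behaviour: largest finite entry strictly below val, else 0.
    vals = [x for x in seq if 0 <= x < val]
    return max(vals) if vals else 0

def min_lt(seq, val):
    # Assumed behaviour: smallest entry at or above val, else 0.
    vals = [x for x in seq if val <= x < np.inf]
    return min(vals) if vals else 0

def norm_obs_clip(obs, clip_min=-1, clip_max=1):
    # Rescale features into [clip_min, clip_max]; infinities and sentinel
    # values beyond 1000 saturate at clip_max.
    max_obs = max(1, max_lt(obs, 1000))
    min_obs = max(0, min_lt(obs, 0))
    if max_obs == min_obs:
        return np.clip(np.array(obs) / max_obs, clip_min, clip_max)
    norm = np.abs(max_obs - min_obs)
    if norm == 0:
        norm = 1.
    return np.clip((np.array(obs) - min_obs) / norm, clip_min, clip_max)

print(norm_obs_clip([0., 5., 10., np.inf]))  # -> [0.  0.5 1.  1. ]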
@@ -115,13 +118,18 @@ for trials in range(1, n_trials + 1):
    # Reset environment
    obs = env.reset()
    final_obs = obs.copy()
    final_obs_next = obs.copy()
    for a in range(env.get_num_agents()):
        data, distance = env.obs_builder.split_tree(tree=np.array(obs[a]), num_features_per_node=5, current_depth=0)
        data = norm_obs_clip(data)
        distance = norm_obs_clip(distance)
        obs[a] = np.concatenate((data, distance))
+    for i in range(2):
+        time_obs.append(obs)
    # env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5)
+    for a in range(env.get_num_agents()):
+        agent_obs[a] = np.concatenate((time_obs[0][a], time_obs[1][a]))
    score = 0
    env_done = 0
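
Before stacking, each raw tree observation is split into per-node feature data and a distance stream (the split_tree call above, with num_features_per_node=5), and the two parts are normalised separately before being concatenated into one frame. A sketch with a hypothetical stand-in for split_tree, whose real implementation lives on env.obs_builder and is not shown in this diff:

import numpy as np

def split_tree_stub(tree, num_features_per_node=5, current_depth=0):
    # Hypothetical stand-in: treat each group of six values as one node,
    # five feature values followed by one distance value.
    tree = np.asarray(tree, dtype=float).reshape(-1, num_features_per_node + 1)
    return tree[:, :num_features_per_node].ravel(), tree[:, -1].ravel()

def norm01(v):
    # Minimal stand-in for norm_obs_clip from the hunk above.
    v = np.asarray(v, dtype=float)
    span = v.max() - v.min()
    return (v - v.min()) / span if span else v

raw = np.arange(18, dtype=float)       # three hypothetical tree nodes
data, distance = split_tree_stub(raw)
frame = np.concatenate((norm01(data), norm01(distance)))  # one agent, one frame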
@@ -134,7 +142,8 @@ for trials in range(1, n_trials + 1):
    for a in range(env.get_num_agents()):
        if demo:
            eps = 0
-        action = agent.act(np.array(obs[a]), eps=eps)
+        # action = agent.act(np.array(obs[a]), eps=eps)
+        action = agent.act(agent_obs[a])
        action_prob[action] += 1
        action_dict.update({a: action})
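
With demo mode forcing eps to 0, action selection is purely greedy; the commented line preserves the old single-frame call for reference. The Agent's internals are not part of this diff, but an epsilon-greedy act() conventionally reduces to the following sketch, with q_values standing in for the network's output:

import random
import numpy as np

def act(q_values, eps=0.):
    # With probability eps explore uniformly over the 4 actions,
    # otherwise exploit the highest value estimate.
    if random.random() > eps:
        return int(np.argmax(q_values))
    return random.randrange(len(q_values))

q = np.array([0.1, 0.7, 0.3, -0.2])  # hypothetical estimates for one state
print(act(q, eps=0.0))               # always 1: greedy, as in demo mode
print(act(q, eps=1.0))               # uniform random action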
@@ -148,17 +157,21 @@ for trials in range(1, n_trials + 1):
        distance = norm_obs_clip(distance)
        next_obs[a] = np.concatenate((data, distance))
+    time_obs.append(next_obs)
    # Update replay buffer and train agent
    for a in range(env.get_num_agents()):
+        agent_next_obs[a] = np.concatenate((time_obs[0][a], time_obs[1][a]))
        if done[a]:
-            final_obs[a] = obs[a].copy()
-            final_obs_next[a] = next_obs[a].copy()
+            final_obs[a] = agent_obs[a].copy()
+            final_obs_next[a] = agent_next_obs[a].copy()
            final_action_dict.update({a: action_dict[a]})
        if not demo and not done[a]:
-            agent.step(obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a])
+            agent.step(agent_obs[a], action_dict[a], all_rewards[a], agent_next_obs[a], done[a])
        score += all_rewards[a]
    obs = next_obs.copy()
+    agent_obs = agent_next_obs.copy()
    if done['__all__']:
        env_done = 1
        for a in range(env.get_num_agents()):
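
agent.step() now receives the stacked observations, so the replay buffer and the network see inputs consistent with the doubled state_size; transitions of agents that have finished are parked in final_obs, final_obs_next and final_action_dict rather than stepped immediately. The baselines Agent is not shown here; a generic sketch of the store-then-sample hand-off a DQN-style step() typically performs:

import random
from collections import deque, namedtuple

Transition = namedtuple("Transition", "state action reward next_state done")
replay = deque(maxlen=100000)  # bounded experience buffer

def step(state, action, reward, next_state, done, batch_size=64):
    # Store the transition; once enough experience has accumulated, sample
    # a minibatch for a Q-network update (the update itself lives inside
    # the real Agent and is omitted here).
    replay.append(Transition(state, action, reward, next_state, done))
    if len(replay) >= batch_size:
        return random.sample(replay, batch_size)  # would feed a learn() routine
    return None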
@@ -19,3 +19,4 @@ PyQt5==5.12
Pillow==5.4.1
msgpack==0.6.1
+svgutils==0.3.1