Commit d5376a20 authored by gmollard

Merge branch 'master' of gitlab.aicrowd.com:flatland/baselines

parents 38d81491 7d37fb3a
@@ -43,13 +43,18 @@ env = RailEnv(width=15,
 env = RailEnv(width=10,
               height=20)
 env.load("./railway/complex_scene.pkl")
+env = RailEnv(width=15,
+              height=15,
+              rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=10, min_dist=10, max_dist=99999, seed=0),
+              number_of_agents=1)
 env.reset(False, False)
 env_renderer = RenderTool(env, gl="PILSVG")
 handle = env.get_agent_handles()
-state_size = 105 * 2
-action_size = 4
+state_size = 147 * 2
+action_size = 5
 n_trials = 15000
 eps = 1.
 eps_end = 0.005
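Note: the jump from state_size = 105 * 2 to 147 * 2 follows from the observation change later in this diff (num_features_per_node goes from 5 to 7). A hedged sanity check, assuming the depth-2 tree observation these baselines use (1 root + 4 children + 16 grandchildren = 21 nodes; the node count is not stated in the diff itself):

    num_nodes = 1 + 4 + 4 * 4      # 21 nodes in a depth-2 tree observation
    features_per_node = 7          # raised from 5 in this merge
    frames = 2                     # time_obs = deque(maxlen=2) stacks two frames
    state_size = num_nodes * features_per_node * frames
    assert state_size == 147 * 2   # old layout: 21 * 5 = 105 per frame

action_size = 5 matches the five Flatland rail actions (do nothing, move left, move forward, move right, stop).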
@@ -61,13 +66,13 @@ done_window = deque(maxlen=100)
 time_obs = deque(maxlen=2)
 scores = []
 dones_list = []
-action_prob = [0] * 4
+action_prob = [0] * action_size
 agent_obs = [None] * env.get_num_agents()
 agent_next_obs = [None] * env.get_num_agents()
 agent = Agent(state_size, action_size, "FC", 0)
-agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint10400.pth'))
-demo = True
+#agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint15000.pth'))
+demo = False
 def max_lt(seq, val):
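Note: action_prob = [0] * action_size drops the hard-coded 4 so the action histogram stays in sync with the enlarged action space, and with the checkpoint load commented out and demo = False the script trains from scratch instead of replaying a saved policy. A minimal sketch of the kind of readout such a counter supports (hypothetical; the actual reporting code is outside this hunk):

    # Fraction of steps on which each of the five actions was chosen.
    total = max(sum(action_prob), 1)
    frequencies = [count / total for count in action_prob]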
@@ -119,18 +124,18 @@ def norm_obs_clip(obs, clip_min=-1, clip_max=1):
 for trials in range(1, n_trials + 1):
     # Reset environment
-    obs = env.reset(False,False)
+    obs = env.reset(True, True)
+    if demo:
+        env_renderer.set_new_rail()
     final_obs = obs.copy()
     final_obs_next = obs.copy()
     for a in range(env.get_num_agents()):
-        data, distance = env.obs_builder.split_tree(tree=np.array(obs[a]), num_features_per_node=5, current_depth=0)
+        data, distance, agent_data = env.obs_builder.split_tree(tree=np.array(obs[a]), num_features_per_node=7, current_depth=0)
         data = norm_obs_clip(data)
         distance = norm_obs_clip(distance)
-        obs[a] = np.concatenate((data, distance))
+        obs[a] = np.concatenate((np.concatenate((data, distance)), agent_data))
     for i in range(2):
         time_obs.append(obs)
     # env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5)
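Note: env.obs_builder.split_tree now returns a third array, agent_data, which is appended unnormalized to the clipped tree features. The identical assembly is repeated in the step loop below, so it could be factored into a helper; a hedged sketch (the helper name is mine, not the repo's):

    def build_agent_obs(obs_builder, raw_obs):
        # Split the tree observation into features, distances and agent state,
        # clip the first two, and append agent_data as-is (mirroring the diff).
        data, distance, agent_data = obs_builder.split_tree(
            tree=np.array(raw_obs), num_features_per_node=7, current_depth=0)
        data = norm_obs_clip(data)
        distance = norm_obs_clip(distance)
        return np.concatenate((np.concatenate((data, distance)), agent_data))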
@@ -142,25 +147,26 @@ for trials in range(1, n_trials + 1):
     # Run episode
     for step in range(360):
         if demo:
             env_renderer.renderEnv(show=True,show_observations=False)
         # print(step)
         # Action
         for a in range(env.get_num_agents()):
             if demo:
-                eps = 0
+                eps = 1
             # action = agent.act(np.array(obs[a]), eps=eps)
-            action = agent.act(agent_obs[a])
+            action = agent.act(agent_obs[a], eps=eps)
             action_prob[action] += 1
             action_dict.update({a: action})
         # Environment step
         next_obs, all_rewards, done, _ = env.step(action_dict)
         for a in range(env.get_num_agents()):
-            data, distance = env.obs_builder.split_tree(tree=np.array(next_obs[a]), num_features_per_node=5,
-                                                        current_depth=0)
+            data, distance, agent_data = env.obs_builder.split_tree(tree=np.array(next_obs[a]), num_features_per_node=7,
+                                                                    current_depth=0)
             data = norm_obs_clip(data)
             distance = norm_obs_clip(distance)
-            next_obs[a] = np.concatenate((data, distance))
+            next_obs[a] = np.concatenate((np.concatenate((data, distance)), agent_data))
         time_obs.append(next_obs)
...
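Note: time_obs buffers the last two per-agent observations, and state_size = 147 * 2 implies the network input is those two frames concatenated. The stacking itself sits in lines elided from this view; an assumed sketch of what they likely do:

    # Assumption: stack the two buffered frames into the per-agent network input.
    for a in range(env.get_num_agents()):
        agent_obs[a] = np.concatenate((time_obs[0][a], time_obs[1][a]))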