Skip to content
Snippets Groups Projects
Commit 73919673 authored by Erik Nygren's avatar Erik Nygren :bullettrain_front:
Browse files

minor test of multi speed implementation

parent e1775e9f
No related branches found
No related tags found
No related merge requests found
...@@ -40,21 +40,22 @@ env = RailEnv(width=15, ...@@ -40,21 +40,22 @@ env = RailEnv(width=15,
rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=10, min_dist=10, max_dist=99999, seed=0), rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=10, min_dist=10, max_dist=99999, seed=0),
number_of_agents=1) number_of_agents=1)
"""
env = RailEnv(width=10, env = RailEnv(width=10,
height=20, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())) height=20, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
env.load("./railway/complex_scene.pkl") env.load("./railway/complex_scene.pkl")
file_load = True
""" """
env = RailEnv(width=12, env = RailEnv(width=20,
height=12, height=20,
rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=5, min_dist=10, max_dist=99999, seed=0), rail_generator=complex_rail_generator(nr_start_goal=20, nr_extra=5, min_dist=10, max_dist=99999, seed=0),
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
number_of_agents=5) number_of_agents=15)
file_load = False
env.reset(True, True) env.reset(True, True)
"""
env_renderer = RenderTool(env, gl="PILSVG") env_renderer = RenderTool(env, gl="PILSVG",)
handle = env.get_agent_handles() handle = env.get_agent_handles()
state_size = 168 * 2 state_size = 168 * 2
...@@ -78,6 +79,7 @@ agent = Agent(state_size, action_size, "FC", 0) ...@@ -78,6 +79,7 @@ agent = Agent(state_size, action_size, "FC", 0)
agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint15000.pth')) agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint15000.pth'))
demo = True demo = True
record_images = False
def max_lt(seq, val): def max_lt(seq, val):
""" """
...@@ -129,7 +131,10 @@ def norm_obs_clip(obs, clip_min=-1, clip_max=1): ...@@ -129,7 +131,10 @@ def norm_obs_clip(obs, clip_min=-1, clip_max=1):
for trials in range(1, n_trials + 1): for trials in range(1, n_trials + 1):
# Reset environment # Reset environment
obs = env.reset(True, True) if file_load :
obs = env.reset(False, False)
else:
obs = env.reset(True, True)
if demo: if demo:
env_renderer.set_new_rail() env_renderer.set_new_rail()
final_obs = obs.copy() final_obs = obs.copy()
...@@ -154,13 +159,15 @@ for trials in range(1, n_trials + 1): ...@@ -154,13 +159,15 @@ for trials in range(1, n_trials + 1):
for step in range(max_steps): for step in range(max_steps):
if demo: if demo:
env_renderer.renderEnv(show=True, show_observations=False) env_renderer.renderEnv(show=True, show_observations=False)
if record_images:
env_renderer.gl.saveImage("./Images/frame_{:04d}.bmp".format(step))
# print(step) # print(step)
# Action # Action
for a in range(env.get_num_agents()): for a in range(env.get_num_agents()):
if demo: if demo:
eps = 0 eps = 0
# action = agent.act(np.array(obs[a]), eps=eps) # action = agent.act(np.array(obs[a]), eps=eps)
action = agent.act(agent_obs[a], eps=eps) action = 2 #agent.act(agent_obs[a], eps=eps)
action_prob[action] += 1 action_prob[action] += 1
action_dict.update({a: action}) action_dict.update({a: action})
# Environment step # Environment step
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment