Skip to content
Snippets Groups Projects
Commit 73919673 authored by Erik Nygren's avatar Erik Nygren :bullettrain_front:
Browse files

minor test of multi speed implementation

parent e1775e9f
No related branches found
No related tags found
No related merge requests found
......@@ -40,21 +40,22 @@ env = RailEnv(width=15,
rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=10, min_dist=10, max_dist=99999, seed=0),
number_of_agents=1)
"""
env = RailEnv(width=10,
height=20, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
env.load("./railway/complex_scene.pkl")
file_load = True
"""
env = RailEnv(width=12,
height=12,
rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=5, min_dist=10, max_dist=99999, seed=0),
env = RailEnv(width=20,
height=20,
rail_generator=complex_rail_generator(nr_start_goal=20, nr_extra=5, min_dist=10, max_dist=99999, seed=0),
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
number_of_agents=5)
number_of_agents=15)
file_load = False
env.reset(True, True)
env_renderer = RenderTool(env, gl="PILSVG")
"""
env_renderer = RenderTool(env, gl="PILSVG",)
handle = env.get_agent_handles()
state_size = 168 * 2
......@@ -78,6 +79,7 @@ agent = Agent(state_size, action_size, "FC", 0)
agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint15000.pth'))
demo = True
record_images = False
def max_lt(seq, val):
"""
......@@ -129,7 +131,10 @@ def norm_obs_clip(obs, clip_min=-1, clip_max=1):
for trials in range(1, n_trials + 1):
# Reset environment
obs = env.reset(True, True)
if file_load :
obs = env.reset(False, False)
else:
obs = env.reset(True, True)
if demo:
env_renderer.set_new_rail()
final_obs = obs.copy()
......@@ -154,13 +159,15 @@ for trials in range(1, n_trials + 1):
for step in range(max_steps):
if demo:
env_renderer.renderEnv(show=True, show_observations=False)
if record_images:
env_renderer.gl.saveImage("./Images/frame_{:04d}.bmp".format(step))
# print(step)
# Action
for a in range(env.get_num_agents()):
if demo:
eps = 0
# action = agent.act(np.array(obs[a]), eps=eps)
action = agent.act(agent_obs[a], eps=eps)
action = 2 #agent.act(agent_obs[a], eps=eps)
action_prob[action] += 1
action_dict.update({a: action})
# Environment step
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment