diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py
index c0433720823d2b9f7ecc6851bd0df9eaab5f8691..634db690ec269d37aea0bcdb309095e6dae1ec23 100644
--- a/torch_training/training_navigation.py
+++ b/torch_training/training_navigation.py
@@ -40,21 +40,22 @@
 env = RailEnv(width=15,
               rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=10, min_dist=10, max_dist=99999, seed=0),
               number_of_agents=1)
-
+"""
 env = RailEnv(width=10,
               height=20, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
 env.load("./railway/complex_scene.pkl")
+file_load = True
 """
 
-env = RailEnv(width=12,
-              height=12,
-              rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=5, min_dist=10, max_dist=99999, seed=0),
+env = RailEnv(width=20,
+              height=20,
+              rail_generator=complex_rail_generator(nr_start_goal=20, nr_extra=5, min_dist=10, max_dist=99999, seed=0),
               obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
-              number_of_agents=5)
-
+              number_of_agents=15)
+file_load = False
 env.reset(True, True)
-
-env_renderer = RenderTool(env, gl="PILSVG")
+"""
+env_renderer = RenderTool(env, gl="PILSVG",)
 handle = env.get_agent_handles()
 
 state_size = 168 * 2
@@ -78,6 +79,7 @@
 agent = Agent(state_size, action_size, "FC", 0)
 agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint15000.pth'))
 demo = True
+record_images = False
 
 def max_lt(seq, val):
     """
@@ -129,7 +131,10 @@
 for trials in range(1, n_trials + 1):
 
     # Reset environment
-    obs = env.reset(True, True)
+    if file_load:
+        obs = env.reset(False, False)
+    else:
+        obs = env.reset(True, True)
     if demo:
         env_renderer.set_new_rail()
     final_obs = obs.copy()
@@ -154,13 +159,15 @@ for trials in range(1, n_trials + 1):
     for step in range(max_steps):
         if demo:
             env_renderer.renderEnv(show=True, show_observations=False)
+            if record_images:
+                env_renderer.gl.saveImage("./Images/frame_{:04d}.bmp".format(step))
         # print(step)
         # Action
         for a in range(env.get_num_agents()):
             if demo:
                 eps = 0
             # action = agent.act(np.array(obs[a]), eps=eps)
-            action = agent.act(agent_obs[a], eps=eps)
+            action = 2  # agent.act(agent_obs[a], eps=eps)
             action_prob[action] += 1
             action_dict.update({a: action})
         # Environment step