diff --git a/torch_training/railway/complex_scene_2.pkl b/torch_training/railway/complex_scene_2.pkl new file mode 100644 index 0000000000000000000000000000000000000000..5ebb8ab1715b3c9f01171116ac519128f4d67234 Binary files /dev/null and b/torch_training/railway/complex_scene_2.pkl differ diff --git a/torch_training/railway/complex_scene_3.pkl b/torch_training/railway/complex_scene_3.pkl new file mode 100644 index 0000000000000000000000000000000000000000..fddf5794b7292065e0a7e4bf72e616d6087ee179 Binary files /dev/null and b/torch_training/railway/complex_scene_3.pkl differ diff --git a/torch_training/railway/flatland.pkl b/torch_training/railway/flatland.pkl new file mode 100644 index 0000000000000000000000000000000000000000..a652becf420b1ec592c9de07681a394085a08bc8 Binary files /dev/null and b/torch_training/railway/flatland.pkl differ diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py index de4b792ba150315d2c83e79f620407526ae6f215..d085a8b3d1a4d21408afa4e88b57fb92f3a1bab9 100644 --- a/torch_training/training_navigation.py +++ b/torch_training/training_navigation.py @@ -43,7 +43,7 @@ env = RailEnv(width=15, """ env = RailEnv(width=10, height=20, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())) -env.load("./railway/complex_scene.pkl") +env.load("./railway/flatland.pkl") file_load = True """ @@ -79,7 +79,7 @@ agent = Agent(state_size, action_size, "FC", 0) agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint15000.pth')) demo = True -record_images = False +record_images = True def max_lt(seq, val): """ @@ -140,6 +140,7 @@ for trials in range(1, n_trials + 1): final_obs = obs.copy() final_obs_next = obs.copy() for a in range(env.get_num_agents()): + print(a) data, distance, agent_data = env.obs_builder.split_tree(tree=np.array(obs[a]), num_features_per_node=8, current_depth=0) data = norm_obs_clip(data) @@ -164,14 +165,14 @@ for trials in range(1, n_trials + 1): if demo: env_renderer.renderEnv(show=True, show_observations=False) if record_images: - env_renderer.gl.saveImage("./Images/frame_{:04d}.bmp".format(step)) + env_renderer.gl.saveImage("./Images/flatland_frame_{:04d}.bmp".format(step)) # print(step) # Action for a in range(env.get_num_agents()): if demo: eps = 0 # action = agent.act(np.array(obs[a]), eps=eps) - action = agent.act(agent_obs[a], eps=eps) + action = 2 #agent.act(agent_obs[a], eps=eps) action_prob[action] += 1 action_dict.update({a: action}) # Environment step