diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py
index a6ee6134ffd5bd8006fe5fbc0d0ace76e4d13511..0e5ad18128c115cd86b844f1eb7f0489947e8c37 100644
--- a/torch_training/training_navigation.py
+++ b/torch_training/training_navigation.py
@@ -43,7 +43,7 @@ env = RailEnv(width=15,
 """
 env = RailEnv(width=10,
               height=20,
               obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
-env.load("./railway/flatland.pkl")
+env.load("./railway/complex_scene.pkl")
 file_load = True
 """
@@ -80,7 +80,7 @@ agent = Agent(state_size, action_size, "FC", 0)
 agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint15000.pth'))
 
 demo = True
-record_images = True
+record_images = False
@@ -129,15 +129,15 @@ for trials in range(1, n_trials + 1):
             if demo:
                 eps = 0
             # action = agent.act(np.array(obs[a]), eps=eps)
-            action = 2 #agent.act(agent_obs[a], eps=eps)
+            action = agent.act(agent_obs[a], eps=eps)
             action_prob[action] += 1
             action_dict.update({a: action})
 
         # Environment step
         next_obs, all_rewards, done, _ = env.step(action_dict)
         for a in range(env.get_num_agents()):
-            data, distance, agent_data = env.obs_builder.split_tree(tree=np.array(next_obs[a]), num_features_per_node=8,
-                                                                    current_depth=0)
+            data, distance, agent_data = split_tree(tree=np.array(next_obs[a]), num_features_per_node=8,
+                                                    current_depth=0)
             data = norm_obs_clip(data)
             distance = norm_obs_clip(distance)
             agent_data = np.clip(agent_data, -1, 1)
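
Note on the changes above: since demo = True makes the loop set eps = 0, restoring
action = agent.act(agent_obs[a], eps=eps) means each agent acts greedily from the loaded
checkpoint instead of always returning the hard-coded action 2, and the observation
handling switches from the env.obs_builder.split_tree method to a standalone split_tree
helper. For context, the per-agent preprocessing that the last hunk performs amounts to
roughly the sketch below; the import path (utils.observation_utils) and the final
concatenation order are assumptions based on the surrounding baseline code, not
confirmed by this diff:

    import numpy as np
    # Assumed module path for the standalone helpers; adjust to the project layout.
    from utils.observation_utils import split_tree, norm_obs_clip

    def preprocess_tree_obs(obs, num_features_per_node=8):
        # Split the flattened tree observation into node data, distance data,
        # and agent-specific data, as the changed lines do after env.step().
        data, distance, agent_data = split_tree(tree=np.array(obs),
                                                num_features_per_node=num_features_per_node,
                                                current_depth=0)
        data = norm_obs_clip(data)          # clip-normalize node features to [-1, 1]
        distance = norm_obs_clip(distance)  # clip-normalize distance features
        agent_data = np.clip(agent_data, -1, 1)
        # Concatenating into a single network input vector is an assumption about
        # code outside this diff; the hunk only shows the split and normalization.
        return np.concatenate((np.concatenate((data, distance)), agent_data))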