diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py index 2a4af22ad148814f397dd32b7e96f3f6d666c70f..1d12b53ab7d1ee248601c6aeaa782b75c6b5b15d 100644 --- a/torch_training/training_navigation.py +++ b/torch_training/training_navigation.py @@ -43,13 +43,18 @@ env = RailEnv(width=15, env = RailEnv(width=10, height=20) env.load("./railway/complex_scene.pkl") + +env = RailEnv(width=15, + height=15, + rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=10, min_dist=10, max_dist=99999, seed=0), + number_of_agents=1) env.reset(False, False) env_renderer = RenderTool(env, gl="PILSVG") handle = env.get_agent_handles() -state_size = 105 * 2 -action_size = 4 +state_size = 147 * 2 +action_size = 5 n_trials = 15000 eps = 1. eps_end = 0.005 @@ -61,13 +66,13 @@ done_window = deque(maxlen=100) time_obs = deque(maxlen=2) scores = [] dones_list = [] -action_prob = [0] * 4 +action_prob = [0] * action_size agent_obs = [None] * env.get_num_agents() agent_next_obs = [None] * env.get_num_agents() agent = Agent(state_size, action_size, "FC", 0) -agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint10400.pth')) +#agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint15000.pth')) -demo = True +demo = False def max_lt(seq, val): @@ -119,18 +124,18 @@ def norm_obs_clip(obs, clip_min=-1, clip_max=1): for trials in range(1, n_trials + 1): # Reset environment - obs = env.reset(False,False) - + obs = env.reset(True, True) + if demo: + env_renderer.set_new_rail() final_obs = obs.copy() final_obs_next = obs.copy() for a in range(env.get_num_agents()): - data, distance = env.obs_builder.split_tree(tree=np.array(obs[a]), num_features_per_node=5, current_depth=0) - + data, distance, agent_data = env.obs_builder.split_tree(tree=np.array(obs[a]), num_features_per_node=7, current_depth=0) data = norm_obs_clip(data) distance = norm_obs_clip(distance) - obs[a] = np.concatenate((data, distance)) + obs[a] = np.concatenate((np.concatenate((data, distance)), agent_data)) for i in range(2): time_obs.append(obs) # env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5) @@ -142,25 +147,26 @@ for trials in range(1, n_trials + 1): # Run episode for step in range(360): if demo: + env_renderer.renderEnv(show=True,show_observations=False) # print(step) # Action for a in range(env.get_num_agents()): if demo: - eps = 0 + eps = 1 # action = agent.act(np.array(obs[a]), eps=eps) - action = agent.act(agent_obs[a]) + action = agent.act(agent_obs[a], eps=eps) action_prob[action] += 1 action_dict.update({a: action}) # Environment step next_obs, all_rewards, done, _ = env.step(action_dict) for a in range(env.get_num_agents()): - data, distance = env.obs_builder.split_tree(tree=np.array(next_obs[a]), num_features_per_node=5, + data, distance, agent_data = env.obs_builder.split_tree(tree=np.array(next_obs[a]), num_features_per_node=7, current_depth=0) data = norm_obs_clip(data) distance = norm_obs_clip(distance) - next_obs[a] = np.concatenate((data, distance)) + next_obs[a] = np.concatenate((np.concatenate((data, distance)),agent_data)) time_obs.append(next_obs)