diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py index 1d12b53ab7d1ee248601c6aeaa782b75c6b5b15d..1720d7bc67b55f8e4a8a5a88501cdad2368b603f 100644 --- a/torch_training/training_navigation.py +++ b/torch_training/training_navigation.py @@ -44,11 +44,11 @@ env = RailEnv(width=10, height=20) env.load("./railway/complex_scene.pkl") -env = RailEnv(width=15, - height=15, - rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=10, min_dist=10, max_dist=99999, seed=0), - number_of_agents=1) -env.reset(False, False) +env = RailEnv(width=8, + height=8, + rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=5, min_dist=5, max_dist=99999, seed=0), + number_of_agents=2) +env.reset(True, True) env_renderer = RenderTool(env, gl="PILSVG") handle = env.get_agent_handles() @@ -70,11 +70,10 @@ action_prob = [0] * action_size agent_obs = [None] * env.get_num_agents() agent_next_obs = [None] * env.get_num_agents() agent = Agent(state_size, action_size, "FC", 0) -#agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint15000.pth')) +#agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint1500.pth')) demo = False - def max_lt(seq, val): """ Return greatest item in seq for which item < val applies. @@ -145,35 +144,34 @@ for trials in range(1, n_trials + 1): score = 0 env_done = 0 # Run episode - for step in range(360): + for step in range(100): if demo: - env_renderer.renderEnv(show=True,show_observations=False) + env_renderer.renderEnv(show=True, show_observations=True) # print(step) # Action for a in range(env.get_num_agents()): if demo: - eps = 1 + eps = 0 # action = agent.act(np.array(obs[a]), eps=eps) action = agent.act(agent_obs[a], eps=eps) action_prob[action] += 1 action_dict.update({a: action}) - # Environment step + next_obs, all_rewards, done, _ = env.step(action_dict) for a in range(env.get_num_agents()): data, distance, agent_data = env.obs_builder.split_tree(tree=np.array(next_obs[a]), num_features_per_node=7, current_depth=0) data = norm_obs_clip(data) distance = norm_obs_clip(distance) - next_obs[a] = np.concatenate((np.concatenate((data, distance)),agent_data)) + next_obs[a] = np.concatenate((np.concatenate((data, distance)), agent_data)) time_obs.append(next_obs) # Update replay buffer and train agent for a in range(env.get_num_agents()): agent_next_obs[a] = np.concatenate((time_obs[0][a], time_obs[1][a])) - if done[a]: final_obs[a] = agent_obs[a].copy() final_obs_next[a] = agent_next_obs[a].copy() @@ -214,4 +212,4 @@ for trials in range(1, n_trials + 1): action_prob / np.sum(action_prob))) torch.save(agent.qnetwork_local.state_dict(), './Nets/avoid_checkpoint' + str(trials) + '.pth') - action_prob = [1] * 4 + action_prob = [1] * action_size