From b42ef74e64bf32df121196208b5bbc95e51994ec Mon Sep 17 00:00:00 2001 From: Erik Nygren <erik.nygren@sbb.ch> Date: Tue, 23 Apr 2019 16:39:10 +0200 Subject: [PATCH] updated training for navigation --- examples/training_navigation.py | 36 +++++++++++++++++++++++---------- 1 file changed, 25 insertions(+), 11 deletions(-) diff --git a/examples/training_navigation.py b/examples/training_navigation.py index eddb907..231d4e9 100644 --- a/examples/training_navigation.py +++ b/examples/training_navigation.py @@ -20,8 +20,8 @@ transition_probability = [10.0, # empty cell - Case 0 0.0] # Case 7 - dead end # Example generate a random rail -env = RailEnv(width=7, - height=7, +env = RailEnv(width=5, + height=5, rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability), number_of_agents=1) env_renderer = RenderTool(env) @@ -29,7 +29,7 @@ handle = env.get_agent_handles() state_size = 105 action_size = 4 -n_trials = 5000 +n_trials = 9999 eps = 1. eps_end = 0.005 eps_decay = 0.998 @@ -40,14 +40,27 @@ scores = [] dones_list = [] action_prob = [0]*4 agent = Agent(state_size, action_size, "FC", 0) +agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint8000.pth')) +def max_lt(seq, val): + """ + Return greatest item in seq for which item < val applies. + None is returned if seq was empty or all items in seq were >= val. + """ + + idx = len(seq)-1 + while idx >= 0: + if seq[idx] < val and seq[idx] > 0: + return seq[idx] + idx -= 1 + return None for trials in range(1, n_trials + 1): # Reset environment obs = env.reset() for a in range(env.number_of_agents): - if np.max(obs[a]) > 0 and np.max(obs[a]) < np.inf: - obs[a] = np.clip(obs[a] / np.max(obs[a]), -1, 1) + norm = max(1, max_lt(obs[a],np.inf)) + obs[a] = np.clip(np.array(obs[a]) / norm, -1, 1) # env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5) @@ -55,21 +68,21 @@ for trials in range(1, n_trials + 1): env_done = 0 # Run episode - for step in range(100): + for step in range(50): #if trials > 114: - # env_renderer.renderEnv(show=True) - + #env_renderer.renderEnv(show=True) + #print(step) # Action for a in range(env.number_of_agents): - action = agent.act(np.array(obs[a]), eps=eps) + action = agent.act(np.array(obs[a]), eps=0) action_prob[action] += 1 action_dict.update({a: action}) # Environment step next_obs, all_rewards, done, _ = env.step(action_dict) for a in range(env.number_of_agents): - if np.max(next_obs[a]) > 0 and np.max(next_obs[a]) < np.inf: - next_obs[a] = np.clip(next_obs[a] / np.max(next_obs[a]), -1, 1) + norm = max(1, max_lt(next_obs[a], np.inf)) + next_obs[a] = np.clip(np.array(next_obs[a]) / norm, -1, 1) # Update replay buffer and train agent for a in range(env.number_of_agents): agent.step(obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a]) @@ -108,3 +121,4 @@ for trials in range(1, n_trials + 1): eps, action_prob / np.sum(action_prob))) torch.save(agent.qnetwork_local.state_dict(), '../flatland/baselines/Nets/avoid_checkpoint' + str(trials) + '.pth') + -- GitLab