diff --git a/examples/training_navigation.py b/examples/training_navigation.py
index b1032511ec00cbf23a4cfe7b8bca4bca370f5180..0e0f43964c8e6fbb9bba9325cae3b29f6c32f0e1 100644
--- a/examples/training_navigation.py
+++ b/examples/training_navigation.py
@@ -13,15 +13,15 @@ np.random.seed(1)
 transition_probability = [0.5,  # empty cell - Case 0
                           1.0,  # Case 1 - straight
                           1.0,  # Case 2 - simple switch
-                          0.3,  # Case 3 - diamond drossing
+                          0.3,  # Case 3 - diamond crossing
                           0.5,  # Case 4 - single slip
                           0.5,  # Case 5 - double slip
                           0.2,  # Case 6 - symmetrical
                           0.0]  # Case 7 - dead end
 
 # Example generate a random rail
-env = RailEnv(width=7,
-              height=7,
+env = RailEnv(width=20,
+              height=20,
               rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
               number_of_agents=1)
 env_renderer = RenderTool(env)
@@ -29,7 +29,7 @@ handle = env.get_agent_handles()
 
 state_size = 105
 action_size = 4
-n_trials = 9999
+n_trials = 15000
 eps = 1.
 eps_end = 0.005
 eps_decay = 0.998
@@ -40,19 +40,34 @@ scores = []
 dones_list = []
 action_prob = [0]*4
 agent = Agent(state_size, action_size, "FC", 0)
-agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint9900.pth'))
+agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint15000.pth'))
+
+demo = True
 
 def max_lt(seq, val):
     """
     Return greatest item in seq for which item < val applies.
     None is returned if seq was empty or all items in seq were >= val.
     """
+    max = 0
+    idx = len(seq)-1
+    while idx >= 0:
+        if seq[idx] < val and seq[idx] >= 0 and seq[idx] > max:
+            max = seq[idx]
+        idx -= 1
+    return max
+def min_lt(seq, val):
+    """
+    Return smallest item in seq for which item > val applies.
+    np.inf is returned if seq was empty or all items in seq were <= val.
+    """
+    min = np.inf
     idx = len(seq)-1
     while idx >= 0:
-        if seq[idx] < val and seq[idx] >= 0:
-            return seq[idx]
+        if seq[idx] > val and seq[idx] < min:
+            min = seq[idx]
         idx -= 1
-    return None
+    return min
 
 
 for trials in range(1, n_trials + 1):
@@ -69,12 +84,14 @@ for trials in range(1, n_trials + 1):
 
     # Run episode
     for step in range(50):
-        #if trials > 114:
-        env_renderer.renderEnv(show=True)
+        if demo:
+            env_renderer.renderEnv(show=True)
         #print(step)
         # Action
         for a in range(env.number_of_agents):
-            action = agent.act(np.array(obs[a]), eps=0)
+            if demo:
+                eps = 0
+            action = agent.act(np.array(obs[a]), eps=eps)
             action_prob[action] += 1
             action_dict.update({a: action})
 
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index cf206033d7aab65571b18de0c95d729f5d09c65c..9fd85855b094b07b2c2f643c3e39e67de6ff6a32 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -649,7 +649,8 @@ class RailEnv(Environment):
 
             # if agent is not in target position, add step penalty
             if self.agents_position[i][0] == self.agents_target[i][0] and \
-                    self.agents_position[i][1] == self.agents_target[i][1]:
+                    self.agents_position[i][1] == self.agents_target[i][1] and \
+                    action_dict[handle] == 0:
                 self.dones[handle] = True
             else:
                 self.rewards_dict[handle] += step_penalty