diff --git a/checkpoints/ppo/model_checkpoint.meta b/checkpoints/ppo/model_checkpoint.meta
index 7617876cf3d7031f066a779fde687404b0a1cc6f..b45d9b8937c572df6febfa2f0ac5a9d4cda4eb0e 100644
Binary files a/checkpoints/ppo/model_checkpoint.meta and b/checkpoints/ppo/model_checkpoint.meta differ
diff --git a/checkpoints/ppo/model_checkpoint.optimizer b/checkpoints/ppo/model_checkpoint.optimizer
index b93d28155360433cbb1574b2e797bf1e293c2f6c..190ef25976343f4c1cca9b751f78fc8fdcadfa28 100644
Binary files a/checkpoints/ppo/model_checkpoint.optimizer and b/checkpoints/ppo/model_checkpoint.optimizer differ
diff --git a/checkpoints/ppo/model_checkpoint.policy b/checkpoints/ppo/model_checkpoint.policy
index bc21bc40897b530a65966b1cbbbaeb41835f7b69..c4492df60aaec91709c87ae729bf71480866b31e 100644
Binary files a/checkpoints/ppo/model_checkpoint.policy and b/checkpoints/ppo/model_checkpoint.policy differ
diff --git a/src/extra.py b/src/extra.py
index 84f20d9def1cad022f90abacadd646ecb34b04dc..025c8f1355f5c4c1122a3b5272049234e53c71da 100644
--- a/src/extra.py
+++ b/src/extra.py
@@ -67,19 +67,9 @@ class Extra(ObservationBuilder):
     def __init__(self, max_depth):
         self.max_depth = max_depth
-        self.observation_dim = 22
+        self.observation_dim = 30
         self.agent = None
 
-    def loadAgent(self):
-        if self.agent is not None:
-            return
-        self.state_size = self.env.obs_builder.observation_dim
-        self.action_size = 5
-        print("action_size: ", self.action_size)
-        print("state_size: ", self.state_size)
-        self.agent = Agent(self.state_size, self.action_size, 0)
-        self.agent.load('./checkpoints/', 0, 1.0)
-
     def build_data(self):
         if self.env is not None:
             self.env.dev_obs_dict = {}
@@ -197,6 +187,9 @@ class Extra(ObservationBuilder):
     def normalize_observation(self, obsData):
         return obsData
 
+    def is_collision(self, obsData):
+        return False
+
     def reset(self):
         self.build_data()
         return
@@ -210,15 +203,16 @@ class Extra(ObservationBuilder):
             return 2
         return 3
 
-    def _explore(self, handle, new_position, new_direction, depth=0):
-
+    def _explore(self, handle, distance_map, new_position, new_direction, depth=0):
         has_opp_agent = 0
         has_same_agent = 0
         visited = []
+        visited_direction = []
+        visited_min_distance = np.inf
 
         # stop exploring (max_depth reached)
        if depth >= self.max_depth:
-            return has_opp_agent, has_same_agent, visited
+            return has_opp_agent, has_same_agent, visited, visited_direction, visited_min_distance
 
         # max_explore_steps = 100
         cnt = 0
@@ -226,15 +220,22 @@ class Extra(ObservationBuilder):
             cnt += 1
 
             visited.append(new_position)
+            visited_direction.append(new_direction)
+
+            new_cell_dist = distance_map[handle,
+                                         new_position[0], new_position[1],
+                                         new_direction]
+            visited_min_distance = min(visited_min_distance, new_cell_dist)
+
             opp_a = self.env.agent_positions[new_position]
             if opp_a != -1 and opp_a != handle:
                 if self.env.agents[opp_a].direction != new_direction:
                     # opp agent found
                     has_opp_agent = 1
-                    return has_opp_agent, has_same_agent, visited
+                    return has_opp_agent, has_same_agent, visited, visited_direction, visited_min_distance
                 else:
                     has_same_agent = 1
-                    return has_opp_agent, has_same_agent, visited
+                    return has_opp_agent, has_same_agent, visited, visited_direction, visited_min_distance
 
             # convert one-hot encoding to 0,1,2,3
             possible_transitions = self.env.rail.get_transitions(*new_position, new_direction)
@@ -243,20 +244,28 @@ class Extra(ObservationBuilder):
                 agents_near_to_switch_all = \
                 self.check_agent_descision(new_position, new_direction)
             if agents_near_to_switch:
-                return has_opp_agent, has_same_agent, visited
+                return has_opp_agent, has_same_agent, visited, visited_direction, visited_min_distance
             if agents_on_switch:
                 for dir_loop in range(4):
                     if possible_transitions[dir_loop] == 1:
-                        hoa, hsa, v = self._explore(handle, new_position, new_direction, depth + 1)
-                        visited.append(v)
+                        hoa, hsa, v, d, min_dist = self._explore(handle,
+                                                                 distance_map,
+                                                                 get_new_position(new_position, dir_loop),
+                                                                 dir_loop,
+                                                                 depth + 1)
+                        if np.math.isinf(min_dist) == False:
+                            visited_min_distance = min(visited_min_distance, min_dist)
+
+                        visited = visited + v
+                        visited_direction = visited_direction + d
                         has_opp_agent = 0.5 * (has_opp_agent + hoa)
                         has_same_agent = 0.5 * (has_same_agent + hsa)
-                return has_opp_agent, has_same_agent, visited
+                return has_opp_agent, has_same_agent, visited, visited_direction, visited_min_distance
             else:
                 new_direction = fast_argmax(possible_transitions)
                 new_position = get_new_position(new_position, new_direction)
 
-        return has_opp_agent, has_same_agent, visited
+        return has_opp_agent, has_same_agent, visited, visited_direction, visited_min_distance
 
     def get(self, handle):
         # all values are [0,1]
@@ -285,6 +294,7 @@ class Extra(ObservationBuilder):
         observation = np.zeros(self.observation_dim)
 
         visited = []
+        visited_direction = []
         agent = self.env.agents[handle]
 
         agent_done = False
@@ -301,6 +311,7 @@ class Extra(ObservationBuilder):
         if not agent_done:
             visited.append(agent_virtual_position)
+            visited_direction.append(agent.direction)
             distance_map = self.env.distance_map.get()
             current_cell_dist = distance_map[handle,
                                              agent_virtual_position[0], agent_virtual_position[1],
@@ -319,8 +330,12 @@ class Extra(ObservationBuilder):
                 if not (np.math.isinf(new_cell_dist) and np.math.isinf(current_cell_dist)):
                     observation[dir_loop] = int(new_cell_dist < current_cell_dist)
 
-                has_opp_agent, has_same_agent, v = self._explore(handle, new_position, branch_direction)
-                visited.append(v)
+                has_opp_agent, has_same_agent, vis, dir, min_dist = self._explore(handle,
+                                                                                  distance_map,
+                                                                                  new_position,
+                                                                                  branch_direction)
+                visited = visited + vis
+                visited_direction = visited_direction + dir
 
                 observation[10 + dir_loop] = 1
                 observation[14 + dir_loop] = has_opp_agent
@@ -334,6 +349,16 @@ class Extra(ObservationBuilder):
         observation[8] = int(agents_near_to_switch)
         observation[9] = int(agents_near_to_switch_all)
 
+        observation[22] = int(self.env._elapsed_steps % 4 == 0)
+        observation[23] = int(self.env._elapsed_steps % 4 == 1)
+        observation[24] = int(self.env._elapsed_steps % 4 == 2)
+        observation[25] = int(self.env._elapsed_steps % 4 == 3)
+
+        observation[26] = int(agent.direction % 4 == 0)
+        observation[27] = int(agent.direction % 4 == 1)
+        observation[28] = int(agent.direction % 4 == 2)
+        observation[29] = int(agent.direction % 4 == 3)
+
         self.env.dev_obs_dict.update({handle: visited})
 
         return observation
@@ -349,3 +374,13 @@ class Extra(ObservationBuilder):
                 action_dict[a] = RailEnvActions.DO_NOTHING
 
         return action_dict
+
+    def loadAgent(self):
+        if self.agent is not None:
+            return
+        self.state_size = self.env.obs_builder.observation_dim
+        self.action_size = 5
+        print("action_size: ", self.action_size)
+        print("state_size: ", self.state_size)
+        self.agent = Agent(self.state_size, self.action_size, 0)
+        self.agent.load('./checkpoints/', 0, 1.0)
\ No newline at end of file
diff --git a/src/ppo/model.py b/src/ppo/model.py
index 51b86ff16691c03f6a754405352bb4cf48e4b914..421423df6739bbc4b4ed94487de7e3dfa9d973a8 100644
--- a/src/ppo/model.py
+++ b/src/ppo/model.py
@@ -3,7 +3,7 @@ import torch.nn.functional as F
 
 
 class PolicyNetwork(nn.Module):
-    def __init__(self, state_size, action_size, hidsize1=128, hidsize2=128, hidsize3=32):
+    def __init__(self, state_size, action_size, hidsize1=128, hidsize2=256, hidsize3=32):
         super().__init__()
         self.fc1 = nn.Linear(state_size, hidsize1)
         self.fc2 = nn.Linear(hidsize1, hidsize2)
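For reference, a minimal sketch (not part of the diff, using assumed example values) of the eight features this change appends to the observation vector: indices 22-25 one-hot encode the step phase (_elapsed_steps % 4) and indices 26-29 one-hot encode the agent's heading (direction % 4).

# Sketch only: illustrates the new observation slots added in src/extra.py above.
import numpy as np

observation = np.zeros(30)   # observation_dim is now 30 (was 22)
elapsed_steps = 7            # assumed example value of env._elapsed_steps
agent_direction = 2          # assumed example value (flatland: 0=N, 1=E, 2=S, 3=W)

for k in range(4):
    observation[22 + k] = int(elapsed_steps % 4 == k)    # step-phase one-hot
    observation[26 + k] = int(agent_direction % 4 == k)  # heading one-hot

print(observation[22:30])    # -> [0. 0. 0. 1. 0. 0. 1. 0.]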