Commit 662a63bb authored by Egli Adrian (IT-SCI-API-PFI)

Test 2

parent a8275161
Tag: submission-v0.9
Three changed files have no preview for this file type.
@@ -67,19 +67,9 @@ class Extra(ObservationBuilder):
     def __init__(self, max_depth):
         self.max_depth = max_depth
-        self.observation_dim = 22
+        self.observation_dim = 30
         self.agent = None
 
-    def loadAgent(self):
-        if self.agent is not None:
-            return
-        self.state_size = self.env.obs_builder.observation_dim
-        self.action_size = 5
-        print("action_size: ", self.action_size)
-        print("state_size: ", self.state_size)
-        self.agent = Agent(self.state_size, self.action_size, 0)
-        self.agent.load('./checkpoints/', 0, 1.0)
-
     def build_data(self):
         if self.env is not None:
             self.env.dev_obs_dict = {}
@@ -197,6 +187,9 @@ class Extra(ObservationBuilder):
     def normalize_observation(self, obsData):
         return obsData
 
+    def is_collision(self, obsData):
+        return False
+
     def reset(self):
         self.build_data()
         return
@@ -210,15 +203,16 @@ class Extra(ObservationBuilder):
             return 2
         return 3
 
-    def _explore(self, handle, new_position, new_direction, depth=0):
+    def _explore(self, handle, distance_map, new_position, new_direction, depth=0):
         has_opp_agent = 0
         has_same_agent = 0
         visited = []
+        visited_direction = []
+        visited_min_distance = np.inf
 
         # stop exploring (max_depth reached)
         if depth >= self.max_depth:
-            return has_opp_agent, has_same_agent, visited
+            return has_opp_agent, has_same_agent, visited, visited_direction, visited_min_distance
 
         # max_explore_steps = 100
         cnt = 0
@@ -226,15 +220,22 @@ class Extra(ObservationBuilder):
             cnt += 1
 
             visited.append(new_position)
+            visited_direction.append(new_direction)
+            new_cell_dist = distance_map[handle,
+                                         new_position[0], new_position[1],
+                                         new_direction]
+            visited_min_distance = min(visited_min_distance, new_cell_dist)
+
             opp_a = self.env.agent_positions[new_position]
             if opp_a != -1 and opp_a != handle:
                 if self.env.agents[opp_a].direction != new_direction:
                     # opp agent found
                     has_opp_agent = 1
-                    return has_opp_agent, has_same_agent, visited
+                    return has_opp_agent, has_same_agent, visited, visited_direction, visited_min_distance
                 else:
                     has_same_agent = 1
-                    return has_opp_agent, has_same_agent, visited
+                    return has_opp_agent, has_same_agent, visited, visited_direction, visited_min_distance
 
             # convert one-hot encoding to 0,1,2,3
             possible_transitions = self.env.rail.get_transitions(*new_position, new_direction)
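The visited_min_distance bookkeeping added in the hunk above relies on Flatland's distance map, which is indexed as distance_map[handle, row, column, heading] and has shape (n_agents, height, width, 4). A minimal standalone sketch of the same update pattern, with made-up map contents and a hypothetical path:

import numpy as np

# Sketch of the min-distance tracking added to _explore, assuming a
# Flatland-style distance map of shape (n_agents, height, width, 4).
n_agents, height, width = 2, 30, 30
distance_map = np.random.randint(0, 100, size=(n_agents, height, width, 4)).astype(float)

handle = 0
path = [((5, 7), 1), ((5, 8), 1), ((6, 8), 2)]  # hypothetical (position, heading) pairs
visited_min_distance = np.inf
for (row, col), direction in path:
    # same lookup the diff performs for each explored cell
    cell_dist = distance_map[handle, row, col, direction]
    visited_min_distance = min(visited_min_distance, cell_dist)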
@@ -243,20 +244,28 @@ class Extra(ObservationBuilder):
             agents_near_to_switch_all = \
                 self.check_agent_descision(new_position, new_direction)
             if agents_near_to_switch:
-                return has_opp_agent, has_same_agent, visited
+                return has_opp_agent, has_same_agent, visited, visited_direction, visited_min_distance
 
             if agents_on_switch:
                 for dir_loop in range(4):
                     if possible_transitions[dir_loop] == 1:
-                        hoa, hsa, v = self._explore(handle, new_position, new_direction, depth + 1)
-                        visited.append(v)
+                        hoa, hsa, v, d, min_dist = self._explore(handle,
+                                                                 distance_map,
+                                                                 get_new_position(new_position, dir_loop),
+                                                                 dir_loop,
+                                                                 depth + 1)
+                        if np.math.isinf(min_dist) == False:
+                            visited_min_distance = min(visited_min_distance, min_dist)
+                        visited = visited + v
+                        visited_direction = visited_direction + d
                         has_opp_agent = 0.5 * (has_opp_agent + hoa)
                         has_same_agent = 0.5 * (has_same_agent + hsa)
-                return has_opp_agent, has_same_agent, visited
+                return has_opp_agent, has_same_agent, visited, visited_direction, visited_min_distance
             else:
                 new_direction = fast_argmax(possible_transitions)
                 new_position = get_new_position(new_position, new_direction)
 
-        return has_opp_agent, has_same_agent, visited
+        return has_opp_agent, has_same_agent, visited, visited_direction, visited_min_distance
 
     def get(self, handle):
         # all values are [0,1]
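Taken together, these hunks change the contract of _explore from a 3-tuple to a 5-tuple: (has_opp_agent, has_same_agent, visited, visited_direction, visited_min_distance). A sketch of how a caller can merge several branch results the way the diff does; the function name and its inputs are hypothetical:

import numpy as np

# Hypothetical helper merging per-branch _explore results as the diff does:
# lists are concatenated, the agent flags are averaged pairwise
# (0.5 * (a + b)), and the minimum distance is only folded in when finite.
def merge_branches(branch_results):
    has_opp_agent, has_same_agent = 0.0, 0.0
    visited, visited_direction = [], []
    visited_min_distance = np.inf
    for hoa, hsa, v, d, min_dist in branch_results:
        has_opp_agent = 0.5 * (has_opp_agent + hoa)
        has_same_agent = 0.5 * (has_same_agent + hsa)
        visited += v
        visited_direction += d
        if not np.isinf(min_dist):
            visited_min_distance = min(visited_min_distance, min_dist)
    return has_opp_agent, has_same_agent, visited, visited_direction, visited_min_distance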
@@ -285,6 +294,7 @@ class Extra(ObservationBuilder):
         observation = np.zeros(self.observation_dim)
         visited = []
+        visited_direction = []
 
         agent = self.env.agents[handle]
         agent_done = False
@@ -301,6 +311,7 @@ class Extra(ObservationBuilder):
         if not agent_done:
             visited.append(agent_virtual_position)
+            visited_direction.append(agent.direction)
             distance_map = self.env.distance_map.get()
             current_cell_dist = distance_map[handle,
                                              agent_virtual_position[0], agent_virtual_position[1],
@@ -319,8 +330,12 @@ class Extra(ObservationBuilder):
                     if not (np.math.isinf(new_cell_dist) and np.math.isinf(current_cell_dist)):
                         observation[dir_loop] = int(new_cell_dist < current_cell_dist)
 
-                    has_opp_agent, has_same_agent, v = self._explore(handle, new_position, branch_direction)
-                    visited.append(v)
+                    has_opp_agent, has_same_agent, vis, dir, min_dist = self._explore(handle,
+                                                                                      distance_map,
+                                                                                      new_position,
+                                                                                      branch_direction)
+                    visited = visited + vis
+                    visited_direction = visited_direction + dir
                     observation[10 + dir_loop] = 1
                     observation[14 + dir_loop] = has_opp_agent
@@ -334,6 +349,16 @@ class Extra(ObservationBuilder):
         observation[8] = int(agents_near_to_switch)
         observation[9] = int(agents_near_to_switch_all)
 
+        observation[22] = int(self.env._elapsed_steps % 4 == 0)
+        observation[23] = int(self.env._elapsed_steps % 4 == 1)
+        observation[24] = int(self.env._elapsed_steps % 4 == 2)
+        observation[25] = int(self.env._elapsed_steps % 4 == 3)
+
+        observation[26] = int(agent.direction % 4 == 0)
+        observation[27] = int(agent.direction % 4 == 1)
+        observation[28] = int(agent.direction % 4 == 2)
+        observation[29] = int(agent.direction % 4 == 3)
+
         self.env.dev_obs_dict.update({handle: visited})
 
         return observation
@@ -349,3 +374,13 @@ class Extra(ObservationBuilder):
                 action_dict[a] = RailEnvActions.DO_NOTHING
 
         return action_dict
+
+    def loadAgent(self):
+        if self.agent is not None:
+            return
+        self.state_size = self.env.obs_builder.observation_dim
+        self.action_size = 5
+        print("action_size: ", self.action_size)
+        print("state_size: ", self.state_size)
+        self.agent = Agent(self.state_size, self.action_size, 0)
+        self.agent.load('./checkpoints/', 0, 1.0)
\ No newline at end of file
@@ -3,7 +3,7 @@ import torch.nn.functional as F
 
 class PolicyNetwork(nn.Module):
-    def __init__(self, state_size, action_size, hidsize1=128, hidsize2=128, hidsize3=32):
+    def __init__(self, state_size, action_size, hidsize1=128, hidsize2=256, hidsize3=32):
         super().__init__()
         self.fc1 = nn.Linear(state_size, hidsize1)
         self.fc2 = nn.Linear(hidsize1, hidsize2)
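The second file widens the policy network's middle hidden layer from 128 to 256 units. The commit shows only __init__; the sketch below fills in a plausible forward pass under the assumption of a plain feed-forward ReLU stack with a softmax head, so the fc3/output wiring and the activations are assumptions, not the author's code:

import torch
import torch.nn as nn
import torch.nn.functional as F

class PolicyNetwork(nn.Module):
    def __init__(self, state_size, action_size, hidsize1=128, hidsize2=256, hidsize3=32):
        super().__init__()
        self.fc1 = nn.Linear(state_size, hidsize1)
        self.fc2 = nn.Linear(hidsize1, hidsize2)
        # assumption: a third hidden layer narrows to hidsize3 before the head
        self.fc3 = nn.Linear(hidsize2, hidsize3)
        self.output = nn.Linear(hidsize3, action_size)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        return F.softmax(self.output(x), dim=-1)

# e.g. a 30-dim observation (the new observation_dim) and 5 actions
net = PolicyNetwork(state_size=30, action_size=5)
probs = net(torch.zeros(1, 30))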