Commit 8b162e9c authored by Egli Adrian (IT-SCI-API-PFI)

0.162

parent 662a63bb
Showing 23 additions and 48 deletions
Files changed:
Deleted: the saved DQN checkpoint files (binary, no preview), including model_checkpoint.meta, model_checkpoint.optimizer and model_checkpoint.policy; the checkpoints directory keeps its placeholder note "DQN checkpoints will be saved here".
Added: dump.rdb (0 → 100644, binary, no preview).
@@ -66,8 +66,10 @@ def fast_count_nonzero(possible_transitions: (int, int, int, int)):
 class Extra(ObservationBuilder):

     def __init__(self, max_depth):
         self.max_depth = max_depth
-        self.observation_dim = 30
+        self.observation_dim = 22
         self.agent = None

     def build_data(self):
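Note on the dimension change: dropping observation_dim from 30 to 22 matches the final hunk of this file, which removes the eight one-hot features at indices 22-29 (elapsed-step phase and agent heading). A quick sanity check of the arithmetic; the index map is inferred from this diff, not from the full source:

    # Hedged sketch: old 30-dim layout minus the two removed one-hot groups.
    OLD_DIM = 30
    REMOVED = 4 + 4   # one-hot (elapsed_steps % 4) at 22..25, one-hot (direction % 4) at 26..29
    assert OLD_DIM - REMOVED == 22  # matches the new self.observation_dim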
@@ -188,6 +190,10 @@ class Extra(ObservationBuilder):
         return obsData

+    def is_collision(self, obsData):
+        if np.sum(obsData[10:14]) == 0:
+            return False
+        if np.sum(obsData[10:14]) == np.sum(obsData[14:18]):
+            return True
+        return False

     def reset(self):
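Reviewer note on the new is_collision heuristic: slots 10..13 of the observation mark which branch directions exist (set via observation[10 + dir_loop] = 1 in get below), and slots 14..17 hold the per-branch has_opp_agent score. The sums are equal only when every existing branch reports an opposing agent at exactly 1.0, since the branch scores can be fractional after averaging. A minimal, self-contained sketch of that reading, on a toy vector rather than a real observation:

    import numpy as np

    def is_collision(obsData):
        # no explorable branch -> nothing to collide with
        if np.sum(obsData[10:14]) == 0:
            return False
        # every existing branch reports an opposing agent at full confidence
        return np.sum(obsData[10:14]) == np.sum(obsData[14:18])

    obs = np.zeros(22)
    obs[10] = obs[12] = 1          # two valid branches
    obs[14] = obs[16] = 1.0        # both blocked by opposing agents
    print(is_collision(obs))       # True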
@@ -203,16 +209,14 @@ class Extra(ObservationBuilder):
             return 2
         return 3

-    def _explore(self, handle, distance_map, new_position, new_direction, depth=0):
+    def _explore(self, handle, new_position, new_direction, depth=0):
         has_opp_agent = 0
         has_same_agent = 0
         visited = []
-        visited_direction = []
-        visited_min_distance = np.inf

         # stop exploring (max_depth reached)
         if depth >= self.max_depth:
-            return has_opp_agent, has_same_agent, visited, visited_direction, visited_min_distance
+            return has_opp_agent, has_same_agent, visited

         # max_explore_steps = 100
         cnt = 0
@@ -220,22 +224,15 @@
             cnt += 1
             visited.append(new_position)
-            visited_direction.append(new_direction)
-            new_cell_dist = distance_map[handle,
-                                         new_position[0], new_position[1],
-                                         new_direction]
-            visited_min_distance = min(visited_min_distance, new_cell_dist)

             opp_a = self.env.agent_positions[new_position]
             if opp_a != -1 and opp_a != handle:
                 if self.env.agents[opp_a].direction != new_direction:
                     # opp agent found
                     has_opp_agent = 1
-                    return has_opp_agent, has_same_agent, visited, visited_direction, visited_min_distance
+                    return has_opp_agent, has_same_agent, visited
                 else:
                     has_same_agent = 1
-                    return has_opp_agent, has_same_agent, visited, visited_direction, visited_min_distance
+                    return has_opp_agent, has_same_agent, visited

             # convert one-hot encoding to 0,1,2,3
             possible_transitions = self.env.rail.get_transitions(*new_position, new_direction)
@@ -244,28 +241,23 @@ class Extra(ObservationBuilder):
             agents_near_to_switch_all = \
                 self.check_agent_descision(new_position, new_direction)
             if agents_near_to_switch:
-                return has_opp_agent, has_same_agent, visited, visited_direction, visited_min_distance
+                return has_opp_agent, has_same_agent, visited

             if agents_on_switch:
                 for dir_loop in range(4):
                     if possible_transitions[dir_loop] == 1:
-                        hoa, hsa, v, d, min_dist = self._explore(handle,
-                                                                 distance_map,
-                                                                 get_new_position(new_position, dir_loop),
-                                                                 dir_loop,
-                                                                 depth + 1)
-                        if np.math.isinf(min_dist) == False:
-                            visited_min_distance = min(visited_min_distance, min_dist)
-                        visited = visited + v
-                        visited_direction = visited_direction + d
+                        hoa, hsa, v = self._explore(handle,
+                                                    get_new_position(new_position, dir_loop),
+                                                    dir_loop,
+                                                    depth + 1)
+                        visited.append(v)
                         has_opp_agent = 0.5 * (has_opp_agent + hoa)
                         has_same_agent = 0.5 * (has_same_agent + hsa)
-                return has_opp_agent, has_same_agent, visited, visited_direction, visited_min_distance
+                return has_opp_agent, has_same_agent, visited
             else:
                 new_direction = fast_argmax(possible_transitions)
                 new_position = get_new_position(new_position, new_direction)

-        return has_opp_agent, has_same_agent, visited, visited_direction, visited_min_distance
+        return has_opp_agent, has_same_agent, visited

     def get(self, handle):
         # all values are [0,1]
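Reviewer note: the simplified _explore now returns only (has_opp_agent, has_same_agent, visited); the distance-map bookkeeping is gone, and visited.append(v) nests the child branch's list instead of concatenating it as before. At a switch, each branch result is folded in via has_opp_agent = 0.5 * (has_opp_agent + hoa), a running halving that keeps the score in [0, 1] and weights later branches more heavily. A small sketch of that folding in isolation, with hypothetical branch scores:

    def fold_branches(branch_scores):
        # mirrors the diff's accumulation: score = 0.5 * (score + hoa) per branch
        score = 0.0
        for hoa in branch_scores:
            score = 0.5 * (score + hoa)
        return score

    print(fold_branches([1.0, 0.0]))  # 0.25 -> an early blocked branch decays
    print(fold_branches([0.0, 1.0]))  # 0.5  -> the last branch dominates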
@@ -294,7 +286,6 @@ class Extra(ObservationBuilder):
         observation = np.zeros(self.observation_dim)
         visited = []
-        visited_direction = []
         agent = self.env.agents[handle]

         agent_done = False
@@ -311,7 +302,6 @@ class Extra(ObservationBuilder):
         if not agent_done:
             visited.append(agent_virtual_position)
-            visited_direction.append(agent.direction)

         distance_map = self.env.distance_map.get()
         current_cell_dist = distance_map[handle,
                                          agent_virtual_position[0], agent_virtual_position[1],
@@ -330,12 +320,8 @@ class Extra(ObservationBuilder):
                 if not (np.math.isinf(new_cell_dist) and np.math.isinf(current_cell_dist)):
                     observation[dir_loop] = int(new_cell_dist < current_cell_dist)

-                has_opp_agent, has_same_agent, vis, dir, min_dist = self._explore(handle,
-                                                                                  distance_map,
-                                                                                  new_position,
-                                                                                  branch_direction)
-                visited = visited + vis
-                visited_direction = visited_direction + dir
+                has_opp_agent, has_same_agent, v = self._explore(handle, new_position, branch_direction)
+                visited.append(v)

                 observation[10 + dir_loop] = 1
                 observation[14 + dir_loop] = has_opp_agent
@@ -349,16 +335,6 @@ class Extra(ObservationBuilder):
         observation[8] = int(agents_near_to_switch)
         observation[9] = int(agents_near_to_switch_all)

-        observation[22] = int(self.env._elapsed_steps % 4 == 0)
-        observation[23] = int(self.env._elapsed_steps % 4 == 1)
-        observation[24] = int(self.env._elapsed_steps % 4 == 2)
-        observation[25] = int(self.env._elapsed_steps % 4 == 3)

-        observation[26] = int(agent.direction % 4 == 0)
-        observation[27] = int(agent.direction % 4 == 1)
-        observation[28] = int(agent.direction % 4 == 2)
-        observation[29] = int(agent.direction % 4 == 3)

         self.env.dev_obs_dict.update({handle: visited})
         return observation
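Reviewer note: with the step-phase and heading one-hots gone, the remaining 22 slots break down as follows, as far as this diff shows them; entries marked (assumed) are not visible in these hunks and are inferred, not confirmed:

    # Hedged sketch of the new 22-dim observation layout:
    OBS_LAYOUT = {
        (0, 4):   "per-direction flag: branch shortens distance to target",
        (4, 7):   "(assumed) further decision/switch features not shown here",
        (7, 8):   "(assumed) agents_on_switch",
        (8, 9):   "agents_near_to_switch",
        (9, 10):  "agents_near_to_switch_all",
        (10, 14): "branch exists in direction d",
        (14, 18): "has_opp_agent score per branch",
        (18, 22): "(assumed, by symmetry) has_same_agent score per branch",
    }
    assert sum(b - a for a, b in OBS_LAYOUT) == 22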
@@ -3,7 +3,7 @@ import torch.nn.functional as F

 class PolicyNetwork(nn.Module):
-    def __init__(self, state_size, action_size, hidsize1=128, hidsize2=256, hidsize3=32):
+    def __init__(self, state_size, action_size, hidsize1=128, hidsize2=128, hidsize3=32):
         super().__init__()
         self.fc1 = nn.Linear(state_size, hidsize1)
         self.fc2 = nn.Linear(hidsize1, hidsize2)
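For context, a minimal sketch of how PolicyNetwork plausibly continues with the narrowed second layer (256 -> 128). Only the signature, fc1 and fc2 appear in the diff, so the third layer wiring, the forward pass, and the softmax head are assumptions:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class PolicyNetwork(nn.Module):
        def __init__(self, state_size, action_size, hidsize1=128, hidsize2=128, hidsize3=32):
            super().__init__()
            self.fc1 = nn.Linear(state_size, hidsize1)
            self.fc2 = nn.Linear(hidsize1, hidsize2)
            self.fc3 = nn.Linear(hidsize2, hidsize3)      # assumed from the hidsize3 parameter
            self.out = nn.Linear(hidsize3, action_size)   # assumed output head

        def forward(self, x):
            # assumed forward pass: ReLU stack with a softmax policy head
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            x = F.relu(self.fc3(x))
            return F.softmax(self.out(x), dim=-1)

    # usage sketch: 22-dim observation (see above), 5 Flatland rail actions
    net = PolicyNetwork(state_size=22, action_size=5)
    probs = net(torch.zeros(1, 22))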