Commit 482a73e5 authored by Erik Nygren's avatar Erik Nygren
Browse files

updated curve calculation

parent 2f7ee5c1
Pipeline #458 failed with stage
in 1 minute and 47 seconds
......@@ -163,9 +163,9 @@ def a_star(rail_trans, rail_array, start, end):
for new_pos in [(0, -1), (0, 1), (-1, 0), (1, 0)]:
node_pos = (current_node.pos[0] + new_pos[0], current_node.pos[1] + new_pos[1])
if node_pos[0] >= rail_shape[0] or \
node_pos[0] < 0 or \
node_pos[1] >= rail_shape[1] or \
node_pos[1] < 0:
node_pos[0] < 0 or \
node_pos[1] >= rail_shape[1] or \
node_pos[1] < 0:
continue
# validate positions
......@@ -232,7 +232,7 @@ def connect_rail(rail_trans, rail_array, start, end):
end_pos = path[-1]
for index in range(len(path) - 1):
current_pos = path[index]
new_pos = path[index+1]
new_pos = path[index + 1]
new_dir = get_direction(current_pos, new_pos)
new_trans = rail_array[current_pos]
......@@ -359,6 +359,7 @@ def complex_rail_generator(nr_start_goal=1, min_dist=2, max_dist=99999, seed=0):
# print("too close:", dist, sg_new[i], sg[j])
return False
return True
if check_all_dist(sg_new):
break
start_goal.append([start, goal])
......@@ -394,6 +395,7 @@ def rail_from_manual_specifications_generator(rail_spec):
Generator function that always returns a GridTransitionMap object with
the matrix of correct 16-bit bitmaps for each cell.
"""
def generator(width, height, num_resets=0):
t_utils = RailEnvTransitions()
......@@ -429,6 +431,7 @@ def rail_from_GridTransitionMap_generator(rail_map):
function
Generator function that always returns the given `rail_map' object.
"""
def generator(width, height, num_resets=0):
return rail_map
......@@ -449,6 +452,7 @@ def rail_from_list_of_saved_GridTransitionMap_generator(list_of_filenames):
function
Generator function that always returns the given `rail_map' object.
"""
def generator(width, height, num_resets=0):
t_utils = RailEnvTransitions()
rail_map = GridTransitionMap(width=width, height=height, transitions=t_utils)
......@@ -525,9 +529,9 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8):
# add all rotations
for rot in [0, 90, 180, 270]:
transitions_templates_.append((template,
t_utils.rotate_transition(
t_utils.transitions[i],
rot)))
t_utils.rotate_transition(
t_utils.transitions[i],
rot)))
transition_probabilities.append(transition_probability[i])
template = [template[-1]] + template[:-1]
......@@ -537,7 +541,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8):
is_match = True
for j in range(4):
if template[j] >= 0 and \
template[j] != transitions_templates_[i][0][j]:
template[j] != transitions_templates_[i][0][j]:
is_match = False
break
if is_match:
......@@ -678,7 +682,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8):
neigh_trans = rail[r][1]
if neigh_trans is not None:
for k in range(4):
neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2**4 - 1)
neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2 ** 4 - 1)
max_bit = max_bit | (neigh_trans_from_direction & 1)
if max_bit:
rail[r][0] = t_utils.rotate_transition(int('0010000000000000', 2), 270)
......@@ -690,7 +694,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8):
neigh_trans = rail[r][-2]
if neigh_trans is not None:
for k in range(4):
neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2**4 - 1)
neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2 ** 4 - 1)
max_bit = max_bit | (neigh_trans_from_direction & (1 << 2))
if max_bit:
rail[r][-1] = t_utils.rotate_transition(int('0010000000000000', 2),
......@@ -704,7 +708,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8):
neigh_trans = rail[1][c]
if neigh_trans is not None:
for k in range(4):
neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2**4 - 1)
neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2 ** 4 - 1)
max_bit = max_bit | (neigh_trans_from_direction & (1 << 3))
if max_bit:
rail[0][c] = int('0010000000000000', 2)
......@@ -716,7 +720,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8):
neigh_trans = rail[-2][c]
if neigh_trans is not None:
for k in range(4):
neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2**4 - 1)
neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2 ** 4 - 1)
max_bit = max_bit | (neigh_trans_from_direction & (1 << 1))
if max_bit:
rail[-1][c] = t_utils.rotate_transition(int('0010000000000000', 2), 180)
......@@ -840,9 +844,9 @@ class RailEnv(Environment):
def check_agent_lists(self):
for lAgents, name in zip(
[self.agents_handles, self.agents_position, self.agents_direction],
["handles", "positions", "directions"]):
assert self.number_of_agents == len(lAgents), "Inconsistent agent list:"+name
[self.agents_handles, self.agents_position, self.agents_direction],
["handles", "positions", "directions"]):
assert self.number_of_agents == len(lAgents), "Inconsistent agent list:" + name
def check_agent_locdirpath(self, iAgent):
valid_movements = []
......@@ -857,7 +861,7 @@ class RailEnv(Environment):
for m in valid_movements:
new_position = self._new_position(self.agents_position[iAgent], m[1])
if m[0] not in valid_starting_directions and \
self._path_exists(new_position, m[0], self.agents_target[iAgent]):
self._path_exists(new_position, m[0], self.agents_target[iAgent]):
valid_starting_directions.append(m[0])
if len(valid_starting_directions) == 0:
......@@ -876,7 +880,7 @@ class RailEnv(Environment):
for m in valid_movements:
new_position = self._new_position(rcPos, m[1])
if m[0] not in valid_starting_directions and \
self._path_exists(new_position, m[0], rcTarget):
self._path_exists(new_position, m[0], rcTarget):
valid_starting_directions.append(m[0])
if len(valid_starting_directions) == 0:
......@@ -891,7 +895,7 @@ class RailEnv(Environment):
rcPos = np.random.choice(len(self.valid_positions))
iAgent = self.number_of_agents
self.agents_position.append(tuple(rcPos)) # ensure it's a tuple not a list
self.agents_handles.append(max(self.agents_handles + [-1]) + 1) # max(handles) + 1, starting at 0
......@@ -902,7 +906,7 @@ class RailEnv(Environment):
self.number_of_agents += 1
self.check_agent_lists()
return iAgent
def reset(self, regen_rail=True, replace_agents=True):
if regen_rail or self.rail is None:
# TODO: Import not only rail information but also start and goal positions
......@@ -961,7 +965,7 @@ class RailEnv(Environment):
for m in valid_movements:
new_position = self._new_position(self.agents_position[i], m[1])
if m[0] not in valid_starting_directions and \
self._path_exists(new_position, m[0], self.agents_target[i]):
self._path_exists(new_position, m[0], self.agents_target[i]):
valid_starting_directions.append(m[0])
if len(valid_starting_directions) == 0:
......@@ -1011,6 +1015,15 @@ class RailEnv(Environment):
pos = self.agents_position[i]
direction = self.agents_direction[i]
# compute number of possible transitions in the current
# cell used to check for invalid actions
nbits = 0
tmp = self.rail.get_transitions((pos[0], pos[1]))
while tmp > 0:
nbits += (tmp & 1)
tmp = tmp >> 1
movement = direction
if action == 1:
movement = direction - 1
......@@ -1024,14 +1037,6 @@ class RailEnv(Environment):
is_deadend = False
if action == 2:
# compute number of possible transitions in the current
# cell
nbits = 0
tmp = self.rail.get_transitions((pos[0], pos[1]))
while tmp > 0:
nbits += (tmp & 1)
tmp = tmp >> 1
if nbits == 1:
# dead-end; assuming the rail network is consistent,
# this should match the direction the agent has come
......@@ -1074,9 +1079,9 @@ class RailEnv(Environment):
# Is it a legal move? 1) transition allows the movement in the
# cell, 2) the new cell is not empty (case 0), 3) the cell is
# free, i.e., no agent is currently in that cell
if new_position[1] >= self.width or\
new_position[0] >= self.height or\
new_position[0] < 0 or new_position[1] < 0:
if new_position[1] >= self.width or \
new_position[0] >= self.height or \
new_position[0] < 0 or new_position[1] < 0:
new_cell_isValid = False
elif self.rail.get_transitions((new_position[0], new_position[1])) > 0:
......@@ -1105,7 +1110,7 @@ class RailEnv(Environment):
# if agent is not in target position, add step penalty
if self.agents_position[i][0] == self.agents_target[i][0] and \
self.agents_position[i][1] == self.agents_target[i][1]:
self.agents_position[i][1] == self.agents_target[i][1]:
self.dones[handle] = True
else:
self.rewards_dict[handle] += step_penalty
......@@ -1114,7 +1119,7 @@ class RailEnv(Environment):
num_agents_in_target_position = 0
for i in range(self.number_of_agents):
if self.agents_position[i][0] == self.agents_target[i][0] and \
self.agents_position[i][1] == self.agents_target[i][1]:
self.agents_position[i][1] == self.agents_target[i][1]:
num_agents_in_target_position += 1
if num_agents_in_target_position == self.number_of_agents:
......@@ -1127,7 +1132,7 @@ class RailEnv(Environment):
return self._get_observations(), self.rewards_dict, self.dones, {}
def _new_position(self, position, movement):
if movement == 0: # NORTH
if movement == 0: # NORTH
return (position[0] - 1, position[1])
elif movement == 1: # EAST
return (position[0], position[1] + 1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment