Skip to content
Snippets Groups Projects
Commit 482a73e5 authored by Erik Nygren
Browse files

updated curve calculation

parent 2f7ee5c1
No related branches found
No related tags found
No related merge requests found
...@@ -163,9 +163,9 @@ def a_star(rail_trans, rail_array, start, end): ...@@ -163,9 +163,9 @@ def a_star(rail_trans, rail_array, start, end):
for new_pos in [(0, -1), (0, 1), (-1, 0), (1, 0)]: for new_pos in [(0, -1), (0, 1), (-1, 0), (1, 0)]:
node_pos = (current_node.pos[0] + new_pos[0], current_node.pos[1] + new_pos[1]) node_pos = (current_node.pos[0] + new_pos[0], current_node.pos[1] + new_pos[1])
if node_pos[0] >= rail_shape[0] or \ if node_pos[0] >= rail_shape[0] or \
node_pos[0] < 0 or \ node_pos[0] < 0 or \
node_pos[1] >= rail_shape[1] or \ node_pos[1] >= rail_shape[1] or \
node_pos[1] < 0: node_pos[1] < 0:
continue continue
# validate positions # validate positions
...@@ -232,7 +232,7 @@ def connect_rail(rail_trans, rail_array, start, end): ...@@ -232,7 +232,7 @@ def connect_rail(rail_trans, rail_array, start, end):
end_pos = path[-1] end_pos = path[-1]
for index in range(len(path) - 1): for index in range(len(path) - 1):
current_pos = path[index] current_pos = path[index]
new_pos = path[index+1] new_pos = path[index + 1]
new_dir = get_direction(current_pos, new_pos) new_dir = get_direction(current_pos, new_pos)
new_trans = rail_array[current_pos] new_trans = rail_array[current_pos]
...@@ -359,6 +359,7 @@ def complex_rail_generator(nr_start_goal=1, min_dist=2, max_dist=99999, seed=0): ...@@ -359,6 +359,7 @@ def complex_rail_generator(nr_start_goal=1, min_dist=2, max_dist=99999, seed=0):
# print("too close:", dist, sg_new[i], sg[j]) # print("too close:", dist, sg_new[i], sg[j])
return False return False
return True return True
if check_all_dist(sg_new): if check_all_dist(sg_new):
break break
start_goal.append([start, goal]) start_goal.append([start, goal])
...@@ -394,6 +395,7 @@ def rail_from_manual_specifications_generator(rail_spec): ...@@ -394,6 +395,7 @@ def rail_from_manual_specifications_generator(rail_spec):
Generator function that always returns a GridTransitionMap object with Generator function that always returns a GridTransitionMap object with
the matrix of correct 16-bit bitmaps for each cell. the matrix of correct 16-bit bitmaps for each cell.
""" """
def generator(width, height, num_resets=0): def generator(width, height, num_resets=0):
t_utils = RailEnvTransitions() t_utils = RailEnvTransitions()
...@@ -429,6 +431,7 @@ def rail_from_GridTransitionMap_generator(rail_map): ...@@ -429,6 +431,7 @@ def rail_from_GridTransitionMap_generator(rail_map):
function function
Generator function that always returns the given `rail_map' object. Generator function that always returns the given `rail_map' object.
""" """
def generator(width, height, num_resets=0): def generator(width, height, num_resets=0):
return rail_map return rail_map
...@@ -449,6 +452,7 @@ def rail_from_list_of_saved_GridTransitionMap_generator(list_of_filenames): ...@@ -449,6 +452,7 @@ def rail_from_list_of_saved_GridTransitionMap_generator(list_of_filenames):
function function
Generator function that always returns the given `rail_map' object. Generator function that always returns the given `rail_map' object.
""" """
def generator(width, height, num_resets=0): def generator(width, height, num_resets=0):
t_utils = RailEnvTransitions() t_utils = RailEnvTransitions()
rail_map = GridTransitionMap(width=width, height=height, transitions=t_utils) rail_map = GridTransitionMap(width=width, height=height, transitions=t_utils)
...@@ -525,9 +529,9 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8): ...@@ -525,9 +529,9 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8):
# add all rotations # add all rotations
for rot in [0, 90, 180, 270]: for rot in [0, 90, 180, 270]:
transitions_templates_.append((template, transitions_templates_.append((template,
t_utils.rotate_transition( t_utils.rotate_transition(
t_utils.transitions[i], t_utils.transitions[i],
rot))) rot)))
transition_probabilities.append(transition_probability[i]) transition_probabilities.append(transition_probability[i])
template = [template[-1]] + template[:-1] template = [template[-1]] + template[:-1]
...@@ -537,7 +541,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8): ...@@ -537,7 +541,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8):
is_match = True is_match = True
for j in range(4): for j in range(4):
if template[j] >= 0 and \ if template[j] >= 0 and \
template[j] != transitions_templates_[i][0][j]: template[j] != transitions_templates_[i][0][j]:
is_match = False is_match = False
break break
if is_match: if is_match:
...@@ -678,7 +682,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8): ...@@ -678,7 +682,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8):
neigh_trans = rail[r][1] neigh_trans = rail[r][1]
if neigh_trans is not None: if neigh_trans is not None:
for k in range(4): for k in range(4):
neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2**4 - 1) neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2 ** 4 - 1)
max_bit = max_bit | (neigh_trans_from_direction & 1) max_bit = max_bit | (neigh_trans_from_direction & 1)
if max_bit: if max_bit:
rail[r][0] = t_utils.rotate_transition(int('0010000000000000', 2), 270) rail[r][0] = t_utils.rotate_transition(int('0010000000000000', 2), 270)
...@@ -690,7 +694,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8): ...@@ -690,7 +694,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8):
neigh_trans = rail[r][-2] neigh_trans = rail[r][-2]
if neigh_trans is not None: if neigh_trans is not None:
for k in range(4): for k in range(4):
neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2**4 - 1) neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2 ** 4 - 1)
max_bit = max_bit | (neigh_trans_from_direction & (1 << 2)) max_bit = max_bit | (neigh_trans_from_direction & (1 << 2))
if max_bit: if max_bit:
rail[r][-1] = t_utils.rotate_transition(int('0010000000000000', 2), rail[r][-1] = t_utils.rotate_transition(int('0010000000000000', 2),
...@@ -704,7 +708,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8): ...@@ -704,7 +708,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8):
neigh_trans = rail[1][c] neigh_trans = rail[1][c]
if neigh_trans is not None: if neigh_trans is not None:
for k in range(4): for k in range(4):
neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2**4 - 1) neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2 ** 4 - 1)
max_bit = max_bit | (neigh_trans_from_direction & (1 << 3)) max_bit = max_bit | (neigh_trans_from_direction & (1 << 3))
if max_bit: if max_bit:
rail[0][c] = int('0010000000000000', 2) rail[0][c] = int('0010000000000000', 2)
...@@ -716,7 +720,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8): ...@@ -716,7 +720,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8):
neigh_trans = rail[-2][c] neigh_trans = rail[-2][c]
if neigh_trans is not None: if neigh_trans is not None:
for k in range(4): for k in range(4):
neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2**4 - 1) neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2 ** 4 - 1)
max_bit = max_bit | (neigh_trans_from_direction & (1 << 1)) max_bit = max_bit | (neigh_trans_from_direction & (1 << 1))
if max_bit: if max_bit:
rail[-1][c] = t_utils.rotate_transition(int('0010000000000000', 2), 180) rail[-1][c] = t_utils.rotate_transition(int('0010000000000000', 2), 180)
...@@ -840,9 +844,9 @@ class RailEnv(Environment): ...@@ -840,9 +844,9 @@ class RailEnv(Environment):
def check_agent_lists(self): def check_agent_lists(self):
for lAgents, name in zip( for lAgents, name in zip(
[self.agents_handles, self.agents_position, self.agents_direction], [self.agents_handles, self.agents_position, self.agents_direction],
["handles", "positions", "directions"]): ["handles", "positions", "directions"]):
assert self.number_of_agents == len(lAgents), "Inconsistent agent list:"+name assert self.number_of_agents == len(lAgents), "Inconsistent agent list:" + name
def check_agent_locdirpath(self, iAgent): def check_agent_locdirpath(self, iAgent):
valid_movements = [] valid_movements = []
...@@ -857,7 +861,7 @@ class RailEnv(Environment): ...@@ -857,7 +861,7 @@ class RailEnv(Environment):
for m in valid_movements: for m in valid_movements:
new_position = self._new_position(self.agents_position[iAgent], m[1]) new_position = self._new_position(self.agents_position[iAgent], m[1])
if m[0] not in valid_starting_directions and \ if m[0] not in valid_starting_directions and \
self._path_exists(new_position, m[0], self.agents_target[iAgent]): self._path_exists(new_position, m[0], self.agents_target[iAgent]):
valid_starting_directions.append(m[0]) valid_starting_directions.append(m[0])
if len(valid_starting_directions) == 0: if len(valid_starting_directions) == 0:
...@@ -876,7 +880,7 @@ class RailEnv(Environment): ...@@ -876,7 +880,7 @@ class RailEnv(Environment):
for m in valid_movements: for m in valid_movements:
new_position = self._new_position(rcPos, m[1]) new_position = self._new_position(rcPos, m[1])
if m[0] not in valid_starting_directions and \ if m[0] not in valid_starting_directions and \
self._path_exists(new_position, m[0], rcTarget): self._path_exists(new_position, m[0], rcTarget):
valid_starting_directions.append(m[0]) valid_starting_directions.append(m[0])
if len(valid_starting_directions) == 0: if len(valid_starting_directions) == 0:
...@@ -891,7 +895,7 @@ class RailEnv(Environment): ...@@ -891,7 +895,7 @@ class RailEnv(Environment):
rcPos = np.random.choice(len(self.valid_positions)) rcPos = np.random.choice(len(self.valid_positions))
iAgent = self.number_of_agents iAgent = self.number_of_agents
self.agents_position.append(tuple(rcPos)) # ensure it's a tuple not a list self.agents_position.append(tuple(rcPos)) # ensure it's a tuple not a list
self.agents_handles.append(max(self.agents_handles + [-1]) + 1) # max(handles) + 1, starting at 0 self.agents_handles.append(max(self.agents_handles + [-1]) + 1) # max(handles) + 1, starting at 0
...@@ -902,7 +906,7 @@ class RailEnv(Environment): ...@@ -902,7 +906,7 @@ class RailEnv(Environment):
self.number_of_agents += 1 self.number_of_agents += 1
self.check_agent_lists() self.check_agent_lists()
return iAgent return iAgent
def reset(self, regen_rail=True, replace_agents=True): def reset(self, regen_rail=True, replace_agents=True):
if regen_rail or self.rail is None: if regen_rail or self.rail is None:
# TODO: Import not only rail information but also start and goal positions # TODO: Import not only rail information but also start and goal positions
...@@ -961,7 +965,7 @@ class RailEnv(Environment): ...@@ -961,7 +965,7 @@ class RailEnv(Environment):
for m in valid_movements: for m in valid_movements:
new_position = self._new_position(self.agents_position[i], m[1]) new_position = self._new_position(self.agents_position[i], m[1])
if m[0] not in valid_starting_directions and \ if m[0] not in valid_starting_directions and \
self._path_exists(new_position, m[0], self.agents_target[i]): self._path_exists(new_position, m[0], self.agents_target[i]):
valid_starting_directions.append(m[0]) valid_starting_directions.append(m[0])
if len(valid_starting_directions) == 0: if len(valid_starting_directions) == 0:
...@@ -1011,6 +1015,15 @@ class RailEnv(Environment): ...@@ -1011,6 +1015,15 @@ class RailEnv(Environment):
pos = self.agents_position[i] pos = self.agents_position[i]
direction = self.agents_direction[i] direction = self.agents_direction[i]
# compute number of possible transitions in the current
# cell used to check for invalid actions
nbits = 0
tmp = self.rail.get_transitions((pos[0], pos[1]))
while tmp > 0:
nbits += (tmp & 1)
tmp = tmp >> 1
movement = direction movement = direction
if action == 1: if action == 1:
movement = direction - 1 movement = direction - 1
...@@ -1024,14 +1037,6 @@ class RailEnv(Environment): ...@@ -1024,14 +1037,6 @@ class RailEnv(Environment):
is_deadend = False is_deadend = False
if action == 2: if action == 2:
# compute number of possible transitions in the current
# cell
nbits = 0
tmp = self.rail.get_transitions((pos[0], pos[1]))
while tmp > 0:
nbits += (tmp & 1)
tmp = tmp >> 1
if nbits == 1: if nbits == 1:
# dead-end; assuming the rail network is consistent, # dead-end; assuming the rail network is consistent,
# this should match the direction the agent has come # this should match the direction the agent has come
...@@ -1074,9 +1079,9 @@ class RailEnv(Environment): ...@@ -1074,9 +1079,9 @@ class RailEnv(Environment):
# Is it a legal move? 1) transition allows the movement in the # Is it a legal move? 1) transition allows the movement in the
# cell, 2) the new cell is not empty (case 0), 3) the cell is # cell, 2) the new cell is not empty (case 0), 3) the cell is
# free, i.e., no agent is currently in that cell # free, i.e., no agent is currently in that cell
if new_position[1] >= self.width or\ if new_position[1] >= self.width or \
new_position[0] >= self.height or\ new_position[0] >= self.height or \
new_position[0] < 0 or new_position[1] < 0: new_position[0] < 0 or new_position[1] < 0:
new_cell_isValid = False new_cell_isValid = False
elif self.rail.get_transitions((new_position[0], new_position[1])) > 0: elif self.rail.get_transitions((new_position[0], new_position[1])) > 0:
...@@ -1105,7 +1110,7 @@ class RailEnv(Environment): ...@@ -1105,7 +1110,7 @@ class RailEnv(Environment):
# if agent is not in target position, add step penalty # if agent is not in target position, add step penalty
if self.agents_position[i][0] == self.agents_target[i][0] and \ if self.agents_position[i][0] == self.agents_target[i][0] and \
self.agents_position[i][1] == self.agents_target[i][1]: self.agents_position[i][1] == self.agents_target[i][1]:
self.dones[handle] = True self.dones[handle] = True
else: else:
self.rewards_dict[handle] += step_penalty self.rewards_dict[handle] += step_penalty
...@@ -1114,7 +1119,7 @@ class RailEnv(Environment): ...@@ -1114,7 +1119,7 @@ class RailEnv(Environment):
num_agents_in_target_position = 0 num_agents_in_target_position = 0
for i in range(self.number_of_agents): for i in range(self.number_of_agents):
if self.agents_position[i][0] == self.agents_target[i][0] and \ if self.agents_position[i][0] == self.agents_target[i][0] and \
self.agents_position[i][1] == self.agents_target[i][1]: self.agents_position[i][1] == self.agents_target[i][1]:
num_agents_in_target_position += 1 num_agents_in_target_position += 1
if num_agents_in_target_position == self.number_of_agents: if num_agents_in_target_position == self.number_of_agents:
...@@ -1127,7 +1132,7 @@ class RailEnv(Environment): ...@@ -1127,7 +1132,7 @@ class RailEnv(Environment):
return self._get_observations(), self.rewards_dict, self.dones, {} return self._get_observations(), self.rewards_dict, self.dones, {}
def _new_position(self, position, movement): def _new_position(self, position, movement):
if movement == 0: # NORTH if movement == 0: # NORTH
return (position[0] - 1, position[1]) return (position[0] - 1, position[1])
elif movement == 1: # EAST elif movement == 1: # EAST
return (position[0], position[1] + 1) return (position[0], position[1] + 1)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment