From 482a73e5379ac6c0506b35d951eab47f2322a299 Mon Sep 17 00:00:00 2001 From: Erik Nygren <erik.nygren@sbb.ch> Date: Wed, 1 May 2019 08:19:12 +0200 Subject: [PATCH] updated curve calculation --- flatland/envs/rail_env.py | 73 +++++++++++++++++++++------------------ 1 file changed, 39 insertions(+), 34 deletions(-) diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 552e5f39..55e0e2ed 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -163,9 +163,9 @@ def a_star(rail_trans, rail_array, start, end): for new_pos in [(0, -1), (0, 1), (-1, 0), (1, 0)]: node_pos = (current_node.pos[0] + new_pos[0], current_node.pos[1] + new_pos[1]) if node_pos[0] >= rail_shape[0] or \ - node_pos[0] < 0 or \ - node_pos[1] >= rail_shape[1] or \ - node_pos[1] < 0: + node_pos[0] < 0 or \ + node_pos[1] >= rail_shape[1] or \ + node_pos[1] < 0: continue # validate positions @@ -232,7 +232,7 @@ def connect_rail(rail_trans, rail_array, start, end): end_pos = path[-1] for index in range(len(path) - 1): current_pos = path[index] - new_pos = path[index+1] + new_pos = path[index + 1] new_dir = get_direction(current_pos, new_pos) new_trans = rail_array[current_pos] @@ -359,6 +359,7 @@ def complex_rail_generator(nr_start_goal=1, min_dist=2, max_dist=99999, seed=0): # print("too close:", dist, sg_new[i], sg[j]) return False return True + if check_all_dist(sg_new): break start_goal.append([start, goal]) @@ -394,6 +395,7 @@ def rail_from_manual_specifications_generator(rail_spec): Generator function that always returns a GridTransitionMap object with the matrix of correct 16-bit bitmaps for each cell. """ + def generator(width, height, num_resets=0): t_utils = RailEnvTransitions() @@ -429,6 +431,7 @@ def rail_from_GridTransitionMap_generator(rail_map): function Generator function that always returns the given `rail_map' object. """ + def generator(width, height, num_resets=0): return rail_map @@ -449,6 +452,7 @@ def rail_from_list_of_saved_GridTransitionMap_generator(list_of_filenames): function Generator function that always returns the given `rail_map' object. """ + def generator(width, height, num_resets=0): t_utils = RailEnvTransitions() rail_map = GridTransitionMap(width=width, height=height, transitions=t_utils) @@ -525,9 +529,9 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8): # add all rotations for rot in [0, 90, 180, 270]: transitions_templates_.append((template, - t_utils.rotate_transition( - t_utils.transitions[i], - rot))) + t_utils.rotate_transition( + t_utils.transitions[i], + rot))) transition_probabilities.append(transition_probability[i]) template = [template[-1]] + template[:-1] @@ -537,7 +541,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8): is_match = True for j in range(4): if template[j] >= 0 and \ - template[j] != transitions_templates_[i][0][j]: + template[j] != transitions_templates_[i][0][j]: is_match = False break if is_match: @@ -678,7 +682,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8): neigh_trans = rail[r][1] if neigh_trans is not None: for k in range(4): - neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2**4 - 1) + neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2 ** 4 - 1) max_bit = max_bit | (neigh_trans_from_direction & 1) if max_bit: rail[r][0] = t_utils.rotate_transition(int('0010000000000000', 2), 270) @@ -690,7 +694,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8): neigh_trans = rail[r][-2] if neigh_trans is not None: for k in range(4): - neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2**4 - 1) + neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2 ** 4 - 1) max_bit = max_bit | (neigh_trans_from_direction & (1 << 2)) if max_bit: rail[r][-1] = t_utils.rotate_transition(int('0010000000000000', 2), @@ -704,7 +708,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8): neigh_trans = rail[1][c] if neigh_trans is not None: for k in range(4): - neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2**4 - 1) + neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2 ** 4 - 1) max_bit = max_bit | (neigh_trans_from_direction & (1 << 3)) if max_bit: rail[0][c] = int('0010000000000000', 2) @@ -716,7 +720,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8): neigh_trans = rail[-2][c] if neigh_trans is not None: for k in range(4): - neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2**4 - 1) + neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2 ** 4 - 1) max_bit = max_bit | (neigh_trans_from_direction & (1 << 1)) if max_bit: rail[-1][c] = t_utils.rotate_transition(int('0010000000000000', 2), 180) @@ -840,9 +844,9 @@ class RailEnv(Environment): def check_agent_lists(self): for lAgents, name in zip( - [self.agents_handles, self.agents_position, self.agents_direction], - ["handles", "positions", "directions"]): - assert self.number_of_agents == len(lAgents), "Inconsistent agent list:"+name + [self.agents_handles, self.agents_position, self.agents_direction], + ["handles", "positions", "directions"]): + assert self.number_of_agents == len(lAgents), "Inconsistent agent list:" + name def check_agent_locdirpath(self, iAgent): valid_movements = [] @@ -857,7 +861,7 @@ class RailEnv(Environment): for m in valid_movements: new_position = self._new_position(self.agents_position[iAgent], m[1]) if m[0] not in valid_starting_directions and \ - self._path_exists(new_position, m[0], self.agents_target[iAgent]): + self._path_exists(new_position, m[0], self.agents_target[iAgent]): valid_starting_directions.append(m[0]) if len(valid_starting_directions) == 0: @@ -876,7 +880,7 @@ class RailEnv(Environment): for m in valid_movements: new_position = self._new_position(rcPos, m[1]) if m[0] not in valid_starting_directions and \ - self._path_exists(new_position, m[0], rcTarget): + self._path_exists(new_position, m[0], rcTarget): valid_starting_directions.append(m[0]) if len(valid_starting_directions) == 0: @@ -891,7 +895,7 @@ class RailEnv(Environment): rcPos = np.random.choice(len(self.valid_positions)) iAgent = self.number_of_agents - + self.agents_position.append(tuple(rcPos)) # ensure it's a tuple not a list self.agents_handles.append(max(self.agents_handles + [-1]) + 1) # max(handles) + 1, starting at 0 @@ -902,7 +906,7 @@ class RailEnv(Environment): self.number_of_agents += 1 self.check_agent_lists() return iAgent - + def reset(self, regen_rail=True, replace_agents=True): if regen_rail or self.rail is None: # TODO: Import not only rail information but also start and goal positions @@ -961,7 +965,7 @@ class RailEnv(Environment): for m in valid_movements: new_position = self._new_position(self.agents_position[i], m[1]) if m[0] not in valid_starting_directions and \ - self._path_exists(new_position, m[0], self.agents_target[i]): + self._path_exists(new_position, m[0], self.agents_target[i]): valid_starting_directions.append(m[0]) if len(valid_starting_directions) == 0: @@ -1011,6 +1015,15 @@ class RailEnv(Environment): pos = self.agents_position[i] direction = self.agents_direction[i] + # compute number of possible transitions in the current + # cell used to check for invalid actions + + nbits = 0 + tmp = self.rail.get_transitions((pos[0], pos[1])) + while tmp > 0: + nbits += (tmp & 1) + tmp = tmp >> 1 + movement = direction if action == 1: movement = direction - 1 @@ -1024,14 +1037,6 @@ class RailEnv(Environment): is_deadend = False if action == 2: - - # compute number of possible transitions in the current - # cell - nbits = 0 - tmp = self.rail.get_transitions((pos[0], pos[1])) - while tmp > 0: - nbits += (tmp & 1) - tmp = tmp >> 1 if nbits == 1: # dead-end; assuming the rail network is consistent, # this should match the direction the agent has come @@ -1074,9 +1079,9 @@ class RailEnv(Environment): # Is it a legal move? 1) transition allows the movement in the # cell, 2) the new cell is not empty (case 0), 3) the cell is # free, i.e., no agent is currently in that cell - if new_position[1] >= self.width or\ - new_position[0] >= self.height or\ - new_position[0] < 0 or new_position[1] < 0: + if new_position[1] >= self.width or \ + new_position[0] >= self.height or \ + new_position[0] < 0 or new_position[1] < 0: new_cell_isValid = False elif self.rail.get_transitions((new_position[0], new_position[1])) > 0: @@ -1105,7 +1110,7 @@ class RailEnv(Environment): # if agent is not in target position, add step penalty if self.agents_position[i][0] == self.agents_target[i][0] and \ - self.agents_position[i][1] == self.agents_target[i][1]: + self.agents_position[i][1] == self.agents_target[i][1]: self.dones[handle] = True else: self.rewards_dict[handle] += step_penalty @@ -1114,7 +1119,7 @@ class RailEnv(Environment): num_agents_in_target_position = 0 for i in range(self.number_of_agents): if self.agents_position[i][0] == self.agents_target[i][0] and \ - self.agents_position[i][1] == self.agents_target[i][1]: + self.agents_position[i][1] == self.agents_target[i][1]: num_agents_in_target_position += 1 if num_agents_in_target_position == self.number_of_agents: @@ -1127,7 +1132,7 @@ class RailEnv(Environment): return self._get_observations(), self.rewards_dict, self.dones, {} def _new_position(self, position, movement): - if movement == 0: # NORTH + if movement == 0: # NORTH return (position[0] - 1, position[1]) elif movement == 1: # EAST return (position[0], position[1] + 1) -- GitLab