From 164be2699eb6f71c629b9b62ca7137c6af006be7 Mon Sep 17 00:00:00 2001 From: spiglerg <spiglerg@gmail.com> Date: Fri, 19 Apr 2019 15:36:08 +0200 Subject: [PATCH] added possibility to specify cell types relative proportions when generating random rails --- examples/temporary_example.py | 27 +- flatland/core/env_observation_builder.py | 10 +- flatland/envs/rail_env.py | 446 ++++++++++++----------- 3 files changed, 257 insertions(+), 226 deletions(-) diff --git a/examples/temporary_example.py b/examples/temporary_example.py index 03b5ebd0..52927160 100644 --- a/examples/temporary_example.py +++ b/examples/temporary_example.py @@ -10,13 +10,34 @@ random.seed(1) np.random.seed(1) """ +transition_probability = [1.0, # empty cell - Case 0 + 3.0, # Case 1 - straight + 1.0, # Case 2 - simple switch + 3.0, # Case 3 - diamond drossing + 2.0, # Case 4 - single slip + 1.0, # Case 5 - double slip + 1.0, # Case 6 - symmetrical + 1.0] # Case 7 - dead end +""" +transition_probability = [1.0, # empty cell - Case 0 + 1.0, # Case 1 - straight + 1.0, # Case 2 - simple switch + 1.0, # Case 3 - diamond drossing + 1.0, # Case 4 - single slip + 1.0, # Case 5 - double slip + 1.0, # Case 6 - symmetrical + 1.0] # Case 7 - dead end + # Example generate a random rail -env = RailEnv(width=20, height=20, rail_generator=random_rail_generator, number_of_agents=10) +env = RailEnv(width=20, + height=20, + rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability), + number_of_agents=10) env.reset() env_renderer = RenderTool(env) env_renderer.renderEnv(show=True) -""" + # Example generate a rail given a manual specification, # a map of tuples (cell_type, rotation) @@ -26,7 +47,7 @@ specs = [[(0, 0), (0, 0), (0, 0), (0, 0), (7, 0), (0, 0)], env = RailEnv(width=6, height=2, rail_generator=rail_from_manual_specifications_generator(specs), - number_of_agents=2, + number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2)) handle = env.get_agent_handles() diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py index 24f213ab..cd8a5309 100644 --- a/flatland/core/env_observation_builder.py +++ b/flatland/core/env_observation_builder.py @@ -72,7 +72,7 @@ class TreeObsForRailEnv(ObservationBuilder): # Update local lookup table for all agents' target locations self.location_has_target = {} for loc in self.env.agents_target: - self.location_has_target[(loc[0],loc[1])] = 1 + self.location_has_target[(loc[0], loc[1])] = 1 def _distance_map_walker(self, position, target_nr): """ @@ -292,8 +292,6 @@ class TreeObsForRailEnv(ObservationBuilder): if position in self.location_has_target: other_target_encountered = True - - # ############################# # ############################# @@ -354,10 +352,8 @@ class TreeObsForRailEnv(ObservationBuilder): 0, self.distance_map[handle, position[0], position[1], direction]] - # TODO: - # ############################# # ############################# @@ -368,9 +364,9 @@ class TreeObsForRailEnv(ObservationBuilder): (branch_direction+2) % 4): # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes # it back - new_cell = self._new_position(position, (branch_direction+2)%4) + new_cell = self._new_position(position, (branch_direction+2) % 4) - branch_observation = self._explore_branch(handle, new_cell, (branch_direction+2)%4, depth+1) + branch_observation = self._explore_branch(handle, new_cell, (branch_direction+2) % 4, depth+1) observation = observation + branch_observation elif last_isSwitch and self.env.rail.get_transition((position[0], position[1], direction), diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 8b34bf3d..1664c315 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -84,7 +84,7 @@ def generate_rail_from_list_of_manual_specifications(list_of_specifications) """ -def random_rail_generator(width, height, num_resets=0): +def random_rail_generator(cell_type_relative_proportion=[1.0]*8): """ Dummy random level generator: - fill in cells at random in [width-2, height-2] @@ -116,234 +116,248 @@ def random_rail_generator(width, height, num_resets=0): The matrix with the correct 16-bit bitmaps for each cell. """ - t_utils = RailEnvTransitions() - - transitions_templates_ = [] - for i in range(len(t_utils.transitions)-1): # don't include dead-ends - all_transitions = 0 - for dir_ in range(4): - trans = t_utils.get_transitions(t_utils.transitions[i], dir_) - all_transitions |= (trans[0] << 3) | \ - (trans[1] << 2) | \ - (trans[2] << 1) | \ - (trans[3]) - - template = [int(x) for x in bin(all_transitions)[2:]] - template = [0]*(4-len(template)) + template - - # add all rotations - for rot in [0, 90, 180, 270]: - transitions_templates_.append((template, - t_utils.rotate_transition( - t_utils.transitions[i], - rot))) - template = [template[-1]]+template[:-1] - - def get_matching_templates(template): - ret = [] - for i in range(len(transitions_templates_)): - is_match = True - for j in range(4): - if template[j] >= 0 and \ - template[j] != transitions_templates_[i][0][j]: - is_match = False - break - if is_match: - ret.append(transitions_templates_[i][1]) - return ret - - MAX_INSERTIONS = (width-2) * (height-2) * 10 - MAX_ATTEMPTS_FROM_SCRATCH = 10 - - attempt_number = 0 - while attempt_number < MAX_ATTEMPTS_FROM_SCRATCH: - cells_to_fill = [] - rail = [] - for r in range(height): - rail.append([None]*width) - if r > 0 and r < height-1: - cells_to_fill = cells_to_fill \ - + [(r, c) for c in range(1, width-1)] - - num_insertions = 0 - while num_insertions < MAX_INSERTIONS and len(cells_to_fill) > 0: - cell = random.sample(cells_to_fill, 1)[0] - cells_to_fill.remove(cell) - row = cell[0] - col = cell[1] - - # look at its neighbors and see what are the possible transitions - # that can be chosen from, if any. - valid_template = [-1, -1, -1, -1] - - for el in [(0, 2, (-1, 0)), - (1, 3, (0, 1)), - (2, 0, (1, 0)), - (3, 1, (0, -1))]: # N, E, S, W - neigh_trans = rail[row+el[2][0]][col+el[2][1]] - if neigh_trans is not None: - # select transition coming from facing direction el[1] and - # moving to direction el[1] - max_bit = 0 - for k in range(4): - max_bit |= \ - t_utils.get_transition(neigh_trans, k, el[1]) - - if max_bit: - valid_template[el[0]] = 1 - else: - valid_template[el[0]] = 0 - - possible_cell_transitions = get_matching_templates(valid_template) - - if len(possible_cell_transitions) == 0: # NO VALID TRANSITIONS - # no cell can be filled in without violating some transitions - # can a dead-end solve the problem? - if valid_template.count(1) == 1: - for k in range(4): - if valid_template[k] == 1: - rot = 0 - if k == 0: - rot = 180 - elif k == 1: - rot = 270 - elif k == 2: + def generator(width, height, num_resets=0): + t_utils = RailEnvTransitions() + + transition_probability = cell_type_relative_proportion + + transitions_templates_ = [] + transition_probabilities = [] + for i in range(len(t_utils.transitions)-1): # don't include dead-ends + all_transitions = 0 + for dir_ in range(4): + trans = t_utils.get_transitions(t_utils.transitions[i], dir_) + all_transitions |= (trans[0] << 3) | \ + (trans[1] << 2) | \ + (trans[2] << 1) | \ + (trans[3]) + + template = [int(x) for x in bin(all_transitions)[2:]] + template = [0]*(4-len(template)) + template + + # add all rotations + for rot in [0, 90, 180, 270]: + transitions_templates_.append((template, + t_utils.rotate_transition( + t_utils.transitions[i], + rot))) + transition_probabilities.append(transition_probability[i]) + template = [template[-1]]+template[:-1] + + def get_matching_templates(template): + ret = [] + for i in range(len(transitions_templates_)): + is_match = True + for j in range(4): + if template[j] >= 0 and \ + template[j] != transitions_templates_[i][0][j]: + is_match = False + break + if is_match: + ret.append((transitions_templates_[i][1], transition_probabilities[i])) + return ret + + MAX_INSERTIONS = (width-2) * (height-2) * 10 + MAX_ATTEMPTS_FROM_SCRATCH = 10 + + attempt_number = 0 + while attempt_number < MAX_ATTEMPTS_FROM_SCRATCH: + cells_to_fill = [] + rail = [] + for r in range(height): + rail.append([None]*width) + if r > 0 and r < height-1: + cells_to_fill = cells_to_fill \ + + [(r, c) for c in range(1, width-1)] + + num_insertions = 0 + while num_insertions < MAX_INSERTIONS and len(cells_to_fill) > 0: + cell = random.sample(cells_to_fill, 1)[0] + cells_to_fill.remove(cell) + row = cell[0] + col = cell[1] + + # look at its neighbors and see what are the possible transitions + # that can be chosen from, if any. + valid_template = [-1, -1, -1, -1] + + for el in [(0, 2, (-1, 0)), + (1, 3, (0, 1)), + (2, 0, (1, 0)), + (3, 1, (0, -1))]: # N, E, S, W + neigh_trans = rail[row+el[2][0]][col+el[2][1]] + if neigh_trans is not None: + # select transition coming from facing direction el[1] and + # moving to direction el[1] + max_bit = 0 + for k in range(4): + max_bit |= \ + t_utils.get_transition(neigh_trans, k, el[1]) + + if max_bit: + valid_template[el[0]] = 1 + else: + valid_template[el[0]] = 0 + + possible_cell_transitions = get_matching_templates(valid_template) + + if len(possible_cell_transitions) == 0: # NO VALID TRANSITIONS + # no cell can be filled in without violating some transitions + # can a dead-end solve the problem? + if valid_template.count(1) == 1: + for k in range(4): + if valid_template[k] == 1: rot = 0 - elif k == 3: - rot = 90 + if k == 0: + rot = 180 + elif k == 1: + rot = 270 + elif k == 2: + rot = 0 + elif k == 3: + rot = 90 - rail[row][col] = t_utils.rotate_transition( - int('0000000000100000', 2), rot) + rail[row][col] = t_utils.rotate_transition( + int('0000000000100000', 2), rot) + num_insertions += 1 + + break + + else: + # can I get valid transitions by removing a single + # neighboring cell? + bestk = -1 + besttrans = [] + for k in range(4): + tmp_template = valid_template[:] + tmp_template[k] = -1 + possible_cell_transitions = get_matching_templates( + tmp_template) + if len(possible_cell_transitions) > len(besttrans): + besttrans = possible_cell_transitions + bestk = k + + if bestk >= 0: + # Replace the corresponding cell with None, append it + # to cells to fill, fill in a transition in the current + # cell. + replace_row = row - 1 + replace_col = col + if bestk == 1: + replace_row = row + replace_col = col + 1 + elif bestk == 2: + replace_row = row + 1 + replace_col = col + elif bestk == 3: + replace_row = row + replace_col = col - 1 + + cells_to_fill.append((replace_row, replace_col)) + rail[replace_row][replace_col] = None + + possible_transitions, possible_probabilities = zip(*besttrans) + possible_probabilities = \ + np.exp(possible_probabilities) / sum(np.exp(possible_probabilities)) + + rail[row][col] = np.random.choice(possible_transitions, + p=possible_probabilities) num_insertions += 1 - break + else: + print('WARNING: still nothing!') + rail[row][col] = int('0000000000000000', 2) + num_insertions += 1 + pass else: - # can I get valid transitions by removing a single - # neighboring cell? - bestk = -1 - besttrans = [] - for k in range(4): - tmp_template = valid_template[:] - tmp_template[k] = -1 - possible_cell_transitions = get_matching_templates( - tmp_template) - if len(possible_cell_transitions) > len(besttrans): - besttrans = possible_cell_transitions - bestk = k - - if bestk >= 0: - # Replace the corresponding cell with None, append it - # to cells to fill, fill in a transition in the current - # cell. - replace_row = row - 1 - replace_col = col - if bestk == 1: - replace_row = row - replace_col = col + 1 - elif bestk == 2: - replace_row = row + 1 - replace_col = col - elif bestk == 3: - replace_row = row - replace_col = col - 1 + possible_transitions, possible_probabilities = zip(*possible_cell_transitions) + possible_probabilities = np.exp(possible_probabilities) / sum(np.exp(possible_probabilities)) - cells_to_fill.append((replace_row, replace_col)) - rail[replace_row][replace_col] = None + rail[row][col] = np.random.choice(possible_transitions, + p=possible_probabilities) + num_insertions += 1 - rail[row][col] = random.sample( - besttrans, 1)[0] - num_insertions += 1 + if num_insertions == MAX_INSERTIONS: + # Failed to generate a valid level; try again for a number of times + attempt_number += 1 + else: + break - else: - print('WARNING: still nothing!') - rail[row][col] = int('0000000000000000', 2) - num_insertions += 1 - pass + if attempt_number == MAX_ATTEMPTS_FROM_SCRATCH: + print('ERROR: failed to generate level') + # Finally pad the border of the map with dead-ends to avoid border issues; + # at most 1 transition in the neigh cell + for r in range(height): + # Check for transitions coming from [r][1] to WEST + max_bit = 0 + neigh_trans = rail[r][1] + if neigh_trans is not None: + for k in range(4): + neigh_trans_from_direction = (neigh_trans >> ((3-k) * 4)) \ + & (2**4-1) + max_bit = max_bit | (neigh_trans_from_direction & 1) + if max_bit: + rail[r][0] = t_utils.rotate_transition( + int('0000000000100000', 2), 270) else: - rail[row][col] = random.sample( - possible_cell_transitions, 1)[0] - num_insertions += 1 - - if num_insertions == MAX_INSERTIONS: - # Failed to generate a valid level; try again for a number of times - attempt_number += 1 - else: - break - - if attempt_number == MAX_ATTEMPTS_FROM_SCRATCH: - print('ERROR: failed to generate level') - - # Finally pad the border of the map with dead-ends to avoid border issues; - # at most 1 transition in the neigh cell - for r in range(height): - # Check for transitions coming from [r][1] to WEST - max_bit = 0 - neigh_trans = rail[r][1] - if neigh_trans is not None: - for k in range(4): - neigh_trans_from_direction = (neigh_trans >> ((3-k) * 4)) \ - & (2**4-1) - max_bit = max_bit | (neigh_trans_from_direction & 1) - if max_bit: - rail[r][0] = t_utils.rotate_transition( - int('0000000000100000', 2), 270) - else: - rail[r][0] = int('0000000000000000', 2) - - # Check for transitions coming from [r][-2] to EAST - max_bit = 0 - neigh_trans = rail[r][-2] - if neigh_trans is not None: - for k in range(4): - neigh_trans_from_direction = (neigh_trans >> ((3-k) * 4)) \ - & (2**4-1) - max_bit = max_bit | (neigh_trans_from_direction & (1 << 2)) - if max_bit: - rail[r][-1] = t_utils.rotate_transition(int('0000000000100000', 2), - 90) - else: - rail[r][-1] = int('0000000000000000', 2) - - for c in range(width): - # Check for transitions coming from [1][c] to NORTH - max_bit = 0 - neigh_trans = rail[1][c] - if neigh_trans is not None: - for k in range(4): - neigh_trans_from_direction = (neigh_trans >> ((3-k) * 4)) \ - & (2**4-1) - max_bit = max_bit | (neigh_trans_from_direction & (1 << 3)) - if max_bit: - rail[0][c] = int('0000000000100000', 2) - else: - rail[0][c] = int('0000000000000000', 2) - - # Check for transitions coming from [-2][c] to SOUTH - max_bit = 0 - neigh_trans = rail[-2][c] - if neigh_trans is not None: - for k in range(4): - neigh_trans_from_direction = (neigh_trans >> ((3-k) * 4)) \ - & (2**4-1) - max_bit = max_bit | (neigh_trans_from_direction & (1 << 1)) - if max_bit: - rail[-1][c] = t_utils.rotate_transition( - int('0000000000100000', 2), 180) - else: - rail[-1][c] = int('0000000000000000', 2) - - # For display only, wrong levels - for r in range(height): + rail[r][0] = int('0000000000000000', 2) + + # Check for transitions coming from [r][-2] to EAST + max_bit = 0 + neigh_trans = rail[r][-2] + if neigh_trans is not None: + for k in range(4): + neigh_trans_from_direction = (neigh_trans >> ((3-k) * 4)) \ + & (2**4-1) + max_bit = max_bit | (neigh_trans_from_direction & (1 << 2)) + if max_bit: + rail[r][-1] = t_utils.rotate_transition(int('0000000000100000', 2), + 90) + else: + rail[r][-1] = int('0000000000000000', 2) + for c in range(width): - if rail[r][c] is None: - rail[r][c] = int('0000000000000000', 2) + # Check for transitions coming from [1][c] to NORTH + max_bit = 0 + neigh_trans = rail[1][c] + if neigh_trans is not None: + for k in range(4): + neigh_trans_from_direction = (neigh_trans >> ((3-k) * 4)) \ + & (2**4-1) + max_bit = max_bit | (neigh_trans_from_direction & (1 << 3)) + if max_bit: + rail[0][c] = int('0000000000100000', 2) + else: + rail[0][c] = int('0000000000000000', 2) + + # Check for transitions coming from [-2][c] to SOUTH + max_bit = 0 + neigh_trans = rail[-2][c] + if neigh_trans is not None: + for k in range(4): + neigh_trans_from_direction = (neigh_trans >> ((3-k) * 4)) \ + & (2**4-1) + max_bit = max_bit | (neigh_trans_from_direction & (1 << 1)) + if max_bit: + rail[-1][c] = t_utils.rotate_transition( + int('0000000000100000', 2), 180) + else: + rail[-1][c] = int('0000000000000000', 2) + + # For display only, wrong levels + for r in range(height): + for c in range(width): + if rail[r][c] is None: + rail[r][c] = int('0000000000000000', 2) + + tmp_rail = np.asarray(rail, dtype=np.uint16) + return_rail = GridTransitionMap(width=width, height=height, transitions=t_utils) + return_rail.grid = tmp_rail + return return_rail - tmp_rail = np.asarray(rail, dtype=np.uint16) - return_rail = GridTransitionMap(width=width, height=height, transitions=t_utils) - return_rail.grid = tmp_rail - return return_rail + return generator class RailEnv(Environment): -- GitLab