diff --git a/examples/temporary_example.py b/examples/temporary_example.py index c015f6140617c31fa020bc9a73dcdb3c9c55cc3e..97d41b536ee9f71f233c8608f937b279e5763cea 100644 --- a/examples/temporary_example.py +++ b/examples/temporary_example.py @@ -6,8 +6,8 @@ from flatland.envs.rail_env import * from flatland.core.env_observation_builder import TreeObsForRailEnv from flatland.utils.rendertools import * -random.seed(1) -np.random.seed(1) +random.seed(0) +np.random.seed(0) """ transition_probability = [1.0, # empty cell - Case 0 @@ -27,7 +27,7 @@ transition_probability = [1.0, # empty cell - Case 0 0.1, # Case 5 - double slip 0.2, # Case 6 - symmetrical 0.01] # Case 7 - dead end - +""" # Example generate a random rail env = RailEnv(width=20, height=20, @@ -38,7 +38,6 @@ env.reset() env_renderer = RenderTool(env) env_renderer.renderEnv(show=True) - # Example generate a rail given a manual specification, # a map of tuples (cell_type, rotation) specs = [[(0, 0), (0, 0), (0, 0), (0, 0), (7, 0), (0, 0)], @@ -51,19 +50,25 @@ env = RailEnv(width=6, obs_builder_object=TreeObsForRailEnv(max_depth=2)) handle = env.get_agent_handles() - env.agents_position[0] = [1, 4] env.agents_target[0] = [1, 1] env.agents_direction[0] = 1 # TODO: watch out: if these variables are overridden, the obs_builder object has to be reset, too! env.obs_builder.reset() +""" + +env = RailEnv(width=7, + height=7, + rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability), + number_of_agents=1) # TODO: delete next line -#for i in range(4): -# print(env.obs_builder.distance_map[0, :, :, i]) +for i in range(4): + print(env.obs_builder.distance_map[0, :, :, i]) obs, all_rewards, done, _ = env.step({0:0}) -env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5) +for i in range(env.number_of_agents): + env.obs_builder.util_print_obs_subtree(tree=obs[i], num_elements_per_node=5) env_renderer = RenderTool(env) env_renderer.renderEnv(show=True) diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py index cd8a53094471a8e48e158c301f45d148eebdecce..a0def07ebbf7e598ba04de6094ba08b93ff06590 100644 --- a/flatland/core/env_observation_builder.py +++ b/flatland/core/env_observation_builder.py @@ -103,7 +103,6 @@ class TreeObsForRailEnv(ObservationBuilder): node = nodes_queue.popleft() node_id = (node[0], node[1], node[2]) - if node_id not in visited: visited.add(node_id) @@ -126,58 +125,50 @@ class TreeObsForRailEnv(ObservationBuilder): """ neighbors = [] - for direction in range(4): - new_cell = self._new_position(position, (direction+2) % 4) + possible_directions = [0, 1, 2, 3] + if enforce_target_direction >= 0: + # The agent must land into the current cell with orientation `enforce_target_direction'. + # This is only possible if the agent has arrived from the cell in the opposite direction! + possible_directions = [(enforce_target_direction+2) % 4] + + for neigh_direction in possible_directions: + new_cell = self._new_position(position, neigh_direction) if new_cell[0] >= 0 and new_cell[0] < self.env.height and \ new_cell[1] >= 0 and new_cell[1] < self.env.width: - # Check if the two cells are connected by a valid transition - transitionValid = False - for orientation in range(4): - moves = self.env.rail.get_transitions((new_cell[0], new_cell[1], orientation)) - if moves[direction]: - transitionValid = True - break - - if not transitionValid: - continue - - # Check if a transition in direction node[2] is possible if an agent lands in the current - # cell with orientation `direction'; this only applies to cells that are not dead-ends! - directionMatch = True - if enforce_target_direction >= 0: - directionMatch = self.env.rail.get_transition((new_cell[0], new_cell[1], direction), - enforce_target_direction) - - # If transition is found to invalid, check if perhaps it is a dead-end, in which case the - # direction of movement is rotated 180 degrees (moving forward turns the agents and makes - # it step in the previous cell) - if not directionMatch: - # If cell is a dead-end, append previous node with reversed - # orientation! - nbits = 0 - tmp = self.env.rail.get_transitions((new_cell[0], new_cell[1])) - while tmp > 0: - nbits += (tmp & 1) - tmp = tmp >> 1 - if nbits == 1: - # Dead-end! - # Check if transition is possible in new_cell with orientation - # (direction+2)%4 in direction `direction' - directionMatch = directionMatch or \ - self.env.rail.get_transition((new_cell[0], new_cell[1], (direction+2) % 4), - direction) - - if transitionValid and directionMatch: - # Append all possible orientations in new_cell that allow a transition to direction! - for orientation in range(4): - moves = self.env.rail.get_transitions((new_cell[0], new_cell[1], orientation)) - if moves[direction]: - new_distance = min(self.distance_map[target_nr, new_cell[0], new_cell[1], orientation], - current_distance+1) - neighbors.append((new_cell[0], new_cell[1], orientation, new_distance)) - self.distance_map[target_nr, new_cell[0], new_cell[1], orientation] = new_distance + desired_movement_from_new_cell = (neigh_direction+2) % 4 + + """ + # Is the next cell a dead-end? + isNextCellDeadEnd = False + nbits = 0 + tmp = self.env.rail.get_transitions((new_cell[0], new_cell[1])) + while tmp > 0: + nbits += (tmp & 1) + tmp = tmp >> 1 + if nbits == 1: + # Dead-end! + isNextCellDeadEnd = True + """ + + # Check all possible transitions in new_cell + for agent_orientation in range(4): + # Is a transition along movement `desired_movement_from_new_cell' to the current cell possible? + isValid = self.env.rail.get_transition((new_cell[0], new_cell[1], agent_orientation), + desired_movement_from_new_cell) + + if isValid: + """ + # TODO: check that it works with deadends! -- still bugged! + movement = desired_movement_from_new_cell + if isNextCellDeadEnd: + movement = (desired_movement_from_new_cell+2) % 4 + """ + new_distance = min(self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation], + current_distance+1) + neighbors.append((new_cell[0], new_cell[1], agent_orientation, new_distance)) + self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation] = new_distance return neighbors @@ -309,16 +300,24 @@ class TreeObsForRailEnv(ObservationBuilder): exploring = False if num_transitions == 1: # Check if dead-end, or if we can go forward along direction - if cell_transitions[direction]: - position = self._new_position(position, direction) + nbits = 0 + tmp = self.env.rail.get_transitions((position[0], position[1])) + while tmp > 0: + nbits += (tmp & 1) + tmp = tmp >> 1 + if nbits == 1: + # Dead-end! + last_isDeadEnd = True + if not last_isDeadEnd: # Keep walking through the tree along `direction' exploring = True - else: - # If a dead-end is reached, pick that as node. Also, no further branching is possible. - last_isDeadEnd = True - break + for i in range(4): + if cell_transitions[i]: + position = self._new_position(position, i) + direction = i + break elif num_transitions > 0: # Switch detected @@ -352,8 +351,6 @@ class TreeObsForRailEnv(ObservationBuilder): 0, self.distance_map[handle, position[0], position[1], direction]] - # TODO: - # ############################# # ############################# diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 31c91b571353d1b0f05826985e49ec26151d59c7..8083f1fa1459001cf577a3bd1fb5b0aec6a57e2d 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -218,7 +218,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8): rot = 90 rail[row][col] = t_utils.rotate_transition( - int('0000000000100000', 2), rot) + int('0010000000000000', 2), rot) num_insertions += 1 break @@ -299,7 +299,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8): max_bit = max_bit | (neigh_trans_from_direction & 1) if max_bit: rail[r][0] = t_utils.rotate_transition( - int('0000000000100000', 2), 270) + int('0010000000000000', 2), 270) else: rail[r][0] = int('0000000000000000', 2) @@ -312,7 +312,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8): & (2**4-1) max_bit = max_bit | (neigh_trans_from_direction & (1 << 2)) if max_bit: - rail[r][-1] = t_utils.rotate_transition(int('0000000000100000', 2), + rail[r][-1] = t_utils.rotate_transition(int('0010000000000000', 2), 90) else: rail[r][-1] = int('0000000000000000', 2) @@ -327,7 +327,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8): & (2**4-1) max_bit = max_bit | (neigh_trans_from_direction & (1 << 3)) if max_bit: - rail[0][c] = int('0000000000100000', 2) + rail[0][c] = int('0010000000000000', 2) else: rail[0][c] = int('0000000000000000', 2) @@ -341,7 +341,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8): max_bit = max_bit | (neigh_trans_from_direction & (1 << 1)) if max_bit: rail[-1][c] = t_utils.rotate_transition( - int('0000000000100000', 2), 180) + int('0010000000000000', 2), 180) else: rail[-1][c] = int('0000000000000000', 2) @@ -352,6 +352,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8): rail[r][c] = int('0000000000000000', 2) tmp_rail = np.asarray(rail, dtype=np.uint16) + return_rail = GridTransitionMap(width=width, height=height, transitions=t_utils) return_rail.grid = tmp_rail return return_rail