From e4c43e71fbb881c3a9df383a9e2bdc46de0d3187 Mon Sep 17 00:00:00 2001 From: Erik Nygren <erik.nygren@sbb.ch> Date: Fri, 26 Jul 2019 12:29:36 -0400 Subject: [PATCH] Added new observation and prediction builders to deviate from standard implementation in flatland --- torch_training/multi_agent_inference.py | 8 +- .../observation_builders/__init__.py | 0 .../observation_builders/observations.py | 567 ++++++++++++++++++ torch_training/predictors/__init__.py | 0 torch_training/predictors/predictions.py | 103 ++++ utils/observation_utils.py | 1 - 6 files changed, 674 insertions(+), 5 deletions(-) create mode 100644 torch_training/observation_builders/__init__.py create mode 100644 torch_training/observation_builders/observations.py create mode 100644 torch_training/predictors/__init__.py create mode 100644 torch_training/predictors/predictions.py diff --git a/torch_training/multi_agent_inference.py b/torch_training/multi_agent_inference.py index 6a9ed8e..5cd8e07 100644 --- a/torch_training/multi_agent_inference.py +++ b/torch_training/multi_agent_inference.py @@ -4,8 +4,8 @@ from collections import deque import numpy as np import torch from flatland.envs.generators import rail_from_file, complex_rail_generator -from flatland.envs.observations import TreeObsForRailEnv -from flatland.envs.predictions import ShortestPathPredictorForRailEnv +from observation_builders.observations import TreeObsForRailEnv +from predictors.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv from flatland.utils.rendertools import RenderTool from importlib_resources import path @@ -17,7 +17,7 @@ from utils.observation_utils import normalize_observation random.seed(3) np.random.seed(2) -file_name = "./railway/testing_stuff.pkl" +file_name = "./railway/simple_avoid.pkl" env = RailEnv(width=10, height=20, rail_generator=rail_from_file(file_name), @@ -94,7 +94,7 @@ for trials in range(1, n_trials + 1): if record_images: env_renderer.gl.save_image("./Images/Avoiding/flatland_frame_{:04d}.bmp".format(frame_step)) frame_step += 1 - time.sleep(1.5) + # time.sleep(1.5) # Action for a in range(env.get_num_agents()): action = agent.act(agent_obs[a], eps=0) diff --git a/torch_training/observation_builders/__init__.py b/torch_training/observation_builders/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/torch_training/observation_builders/observations.py b/torch_training/observation_builders/observations.py new file mode 100644 index 0000000..e3d52d3 --- /dev/null +++ b/torch_training/observation_builders/observations.py @@ -0,0 +1,567 @@ +""" +Collection of environment-specific ObservationBuilder. +""" +import pprint +from collections import deque + +import numpy as np + +from flatland.core.env_observation_builder import ObservationBuilder +from flatland.core.grid.grid4 import Grid4TransitionsEnum +from flatland.core.grid.grid_utils import coordinate_to_position + + +class TreeObsForRailEnv(ObservationBuilder): + """ + TreeObsForRailEnv object. + + This object returns observation vectors for agents in the RailEnv environment. + The information is local to each agent and exploits the graph structure of the rail + network to simplify the representation of the state of the environment for each agent. + + For details about the features in the tree observation see the get() function. + """ + + def __init__(self, max_depth, predictor=None): + super().__init__() + self.max_depth = max_depth + self.observation_dim = 9 + # Compute the size of the returned observation vector + size = 0 + pow4 = 1 + for i in range(self.max_depth + 1): + size += pow4 + pow4 *= 4 + self.observation_dim = 9 + self.observation_space = [size * self.observation_dim] + self.location_has_agent = {} + self.location_has_agent_direction = {} + self.predictor = predictor + self.agents_previous_reset = None + self.tree_explored_actions = [1, 2, 3, 0] + self.tree_explorted_actions_char = ['L', 'F', 'R', 'B'] + self.distance_map = None + self.distance_map_computed = False + + def reset(self): + agents = self.env.agents + nb_agents = len(agents) + compute_distance_map = True + if self.agents_previous_reset is not None and nb_agents == len(self.agents_previous_reset): + compute_distance_map = False + for i in range(nb_agents): + if agents[i].target != self.agents_previous_reset[i].target: + compute_distance_map = True + # Don't compute the distance map if it was loaded + if self.agents_previous_reset is None and self.distance_map is not None: + self.location_has_target = {tuple(agent.target): 1 for agent in agents} + compute_distance_map = False + + if compute_distance_map: + self._compute_distance_map() + + self.agents_previous_reset = agents + + def _compute_distance_map(self): + agents = self.env.agents + # For testing only --> To assert if a distance map need to be recomputed. + self.distance_map_computed = True + nb_agents = len(agents) + self.distance_map = np.inf * np.ones(shape=(nb_agents, + self.env.height, + self.env.width, + 4)) + self.max_dist = np.zeros(nb_agents) + self.max_dist = [self._distance_map_walker(agent.target, i) for i, agent in enumerate(agents)] + # Update local lookup table for all agents' target locations + self.location_has_target = {tuple(agent.target): 1 for agent in agents} + + def _distance_map_walker(self, position, target_nr): + """ + Utility function to compute distance maps from each cell in the rail network (and each possible + orientation within it) to each agent's target cell. + """ + # Returns max distance to target, from the farthest away node, while filling in distance_map + self.distance_map[target_nr, position[0], position[1], :] = 0 + + # Fill in the (up to) 4 neighboring nodes + # direction is the direction of movement, meaning that at least a possible orientation of an agent + # in cell (row,col) allows a movement in direction `direction' + nodes_queue = deque(self._get_and_update_neighbors(position, target_nr, 0, enforce_target_direction=-1)) + + # BFS from target `position' to all the reachable nodes in the grid + # Stop the search if the target position is re-visited, in any direction + visited = {(position[0], position[1], 0), (position[0], position[1], 1), (position[0], position[1], 2), + (position[0], position[1], 3)} + + max_distance = 0 + + while nodes_queue: + node = nodes_queue.popleft() + + node_id = (node[0], node[1], node[2]) + + if node_id not in visited: + visited.add(node_id) + + # From the list of possible neighbors that have at least a path to the current node, only keep those + # whose new orientation in the current cell would allow a transition to direction node[2] + valid_neighbors = self._get_and_update_neighbors((node[0], node[1]), target_nr, node[3], node[2]) + + for n in valid_neighbors: + nodes_queue.append(n) + + if len(valid_neighbors) > 0: + max_distance = max(max_distance, node[3] + 1) + + return max_distance + + def _get_and_update_neighbors(self, position, target_nr, current_distance, enforce_target_direction=-1): + """ + Utility function used by _distance_map_walker to perform a BFS walk over the rail, filling in the + minimum distances from each target cell. + """ + neighbors = [] + + possible_directions = [0, 1, 2, 3] + if enforce_target_direction >= 0: + # The agent must land into the current cell with orientation `enforce_target_direction'. + # This is only possible if the agent has arrived from the cell in the opposite direction! + possible_directions = [(enforce_target_direction + 2) % 4] + + for neigh_direction in possible_directions: + new_cell = self._new_position(position, neigh_direction) + + if new_cell[0] >= 0 and new_cell[0] < self.env.height and new_cell[1] >= 0 and new_cell[1] < self.env.width: + + desired_movement_from_new_cell = (neigh_direction + 2) % 4 + + # Check all possible transitions in new_cell + for agent_orientation in range(4): + # Is a transition along movement `desired_movement_from_new_cell' to the current cell possible? + is_valid = self.env.rail.get_transition((new_cell[0], new_cell[1], agent_orientation), + desired_movement_from_new_cell) + + if is_valid: + """ + # TODO: check that it works with deadends! -- still bugged! + movement = desired_movement_from_new_cell + if isNextCellDeadEnd: + movement = (desired_movement_from_new_cell+2) % 4 + """ + new_distance = min(self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation], + current_distance + 1) + neighbors.append((new_cell[0], new_cell[1], agent_orientation, new_distance)) + self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation] = new_distance + + return neighbors + + def _new_position(self, position, movement): + """ + Utility function that converts a compass movement over a 2D grid to new positions (r, c). + """ + if movement == Grid4TransitionsEnum.NORTH: + return (position[0] - 1, position[1]) + elif movement == Grid4TransitionsEnum.EAST: + return (position[0], position[1] + 1) + elif movement == Grid4TransitionsEnum.SOUTH: + return (position[0] + 1, position[1]) + elif movement == Grid4TransitionsEnum.WEST: + return (position[0], position[1] - 1) + + def get_many(self, handles=None): + """ + Called whenever an observation has to be computed for the `env' environment, for each agent with handle + in the `handles' list. + """ + + if handles is None: + handles = [] + if self.predictor: + self.max_prediction_depth = 0 + self.predicted_pos = {} + self.predicted_dir = {} + self.predictions = self.predictor.get(custom_args={'distance_map': self.distance_map}) + if self.predictions: + + for t in range(len(self.predictions[0])): + pos_list = [] + dir_list = [] + for a in handles: + pos_list.append(self.predictions[a][t][1:3]) + dir_list.append(self.predictions[a][t][3]) + self.predicted_pos.update({t: coordinate_to_position(self.env.width, pos_list)}) + self.predicted_dir.update({t: dir_list}) + self.max_prediction_depth = len(self.predicted_pos) + observations = {} + for h in handles: + observations[h] = self.get(h) + return observations + + def get(self, handle): + """ + Computes the current observation for agent `handle' in env + + The observation vector is composed of 4 sequential parts, corresponding to data from the up to 4 possible + movements in a RailEnv (up to because only a subset of possible transitions are allowed in RailEnv). + The possible movements are sorted relative to the current orientation of the agent, rather than NESW as for + the transitions. The order is: + [data from 'left'] + [data from 'forward'] + [data from 'right'] + [data from 'back'] + + Each branch data is organized as: + [root node information] + + [recursive branch data from 'left'] + + [... from 'forward'] + + [... from 'right] + + [... from 'back'] + + Each node information is composed of 9 features: + + #1: if own target lies on the explored branch the current distance from the agent in number of cells is stored. + + #2: if another agents target is detected the distance in number of cells from the agents current locaiton + is stored + + #3: if another agent is detected the distance in number of cells from current agent position is stored. + + #4: possible conflict detected + tot_dist = Other agent predicts to pass along this cell at the same time as the agent, we store the + distance in number of cells from current agent position + + 0 = No other agent reserve the same cell at similar time + + #5: if an not usable switch (for agent) is detected we store the distance. + + #6: This feature stores the distance in number of cells to the next branching (current node) + + #7: minimum distance from node to the agent's target given the direction of the agent if this path is chosen + + #8: agent in the same direction + n = number of agents present same direction + (possible future use: number of other agents in the same direction in this branch) + 0 = no agent present same direction + + #9: agent in the opposite direction + n = number of agents present other direction than myself (so conflict) + (possible future use: number of other agents in other direction in this branch, ie. number of conflicts) + 0 = no agent present other direction than myself + + + + + Missing/padding nodes are filled in with -inf (truncated). + Missing values in present node are filled in with +inf (truncated). + + + In case of the root node, the values are [0, 0, 0, 0, distance from agent to target]. + In case the target node is reached, the values are [0, 0, 0, 0, 0]. + """ + + # Update local lookup table for all agents' positions + self.location_has_agent = {tuple(agent.position): 1 for agent in self.env.agents} + self.location_has_agent_direction = {tuple(agent.position): agent.direction for agent in self.env.agents} + if handle > len(self.env.agents): + print("ERROR: obs _get - handle ", handle, " len(agents)", len(self.env.agents)) + agent = self.env.agents[handle] # TODO: handle being treated as index + possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction) + num_transitions = np.count_nonzero(possible_transitions) + + # Root node - current position + observation = [0, 0, 0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], 0, 0] + + visited = set() + # Start from the current orientation, and see which transitions are available; + # organize them as [left, forward, right, back], relative to the current orientation + # If only one transition is possible, the tree is oriented with this transition as the forward branch. + orientation = agent.direction + + if num_transitions == 1: + orientation = np.argmax(possible_transitions) + + for branch_direction in [(orientation + i) % 4 for i in range(-1, 3)]: + if possible_transitions[branch_direction]: + new_cell = self._new_position(agent.position, branch_direction) + branch_observation, branch_visited = \ + self._explore_branch(handle, new_cell, branch_direction, 1, 1) + observation = observation + branch_observation + visited = visited.union(branch_visited) + else: + # add cells filled with infinity if no transition is possible + observation = observation + [-np.inf] * self._num_cells_to_fill_in(self.max_depth) + self.env.dev_obs_dict[handle] = visited + return observation + + def _num_cells_to_fill_in(self, remaining_depth): + """Computes the length of observation vector: sum_{i=0,depth-1} 2^i * observation_dim.""" + num_observations = 0 + pow4 = 1 + for i in range(remaining_depth): + num_observations += pow4 + pow4 *= 4 + return num_observations * self.observation_dim + + def _explore_branch(self, handle, position, direction, tot_dist, depth): + """ + Utility function to compute tree-based observations. + We walk along the branch and collect the information documented in the get() function. + If there is a branching point a new node is created and each possible branch is explored. + """ + # [Recursive branch opened] + if depth >= self.max_depth + 1: + return [], [] + + # Continue along direction until next switch or + # until no transitions are possible along the current direction (i.e., dead-ends) + # We treat dead-ends as nodes, instead of going back, to avoid loops + exploring = True + last_is_switch = False + last_is_dead_end = False + last_is_terminal = False # wrong cell OR cycle; either way, we don't want the agent to land here + last_is_target = False + + visited = set() + agent = self.env.agents[handle] + own_target_encountered = np.inf + other_agent_encountered = np.inf + other_target_encountered = np.inf + potential_conflict = np.inf + unusable_switch = np.inf + other_agent_same_direction = 0 + other_agent_opposite_direction = 0 + + num_steps = 1 + while exploring: + # ############################# + # ############################# + # Modify here to compute any useful data required to build the end node's features. This code is called + # for each cell visited between the previous branching node and the next switch / target / dead-end. + if position in self.location_has_agent: + if tot_dist < other_agent_encountered: + other_agent_encountered = tot_dist + + if self.location_has_agent_direction[position] == direction: + # Cummulate the number of agents on branch with same direction + other_agent_same_direction += 1 + + if self.location_has_agent_direction[position] != direction: + # Cummulate the number of agents on branch with other direction + other_agent_opposite_direction += 1 + + # Check number of possible transitions for agent and total number of transitions in cell (type) + cell_transitions = self.env.rail.get_transitions(*position, direction) + transition_bit = bin(self.env.rail.get_full_transitions(*position)) + total_transitions = transition_bit.count("1") + crossing_found = False + if int(transition_bit, 2) == int('1000010000100001', 2): + crossing_found = True + + # Register possible future conflict + if self.predictor and num_steps < self.max_prediction_depth: + int_position = coordinate_to_position(self.env.width, [position]) + if tot_dist < self.max_prediction_depth: + pre_step = max(0, tot_dist - 1) + post_step = min(self.max_prediction_depth - 1, tot_dist + 1) + + # Look for conflicting paths at distance num_step + if int_position in np.delete(self.predicted_pos[tot_dist], handle, 0): + conflicting_agent = np.where(self.predicted_pos[tot_dist] == int_position) + for ca in conflicting_agent[0]: + if direction != self.predicted_dir[tot_dist][ca] and tot_dist < potential_conflict: + potential_conflict = tot_dist + if self.env.dones[ca] and tot_dist < potential_conflict: + potential_conflict = tot_dist + + # Look for conflicting paths at distance num_step-1 + elif int_position in np.delete(self.predicted_pos[pre_step], handle, 0): + conflicting_agent = np.where(self.predicted_pos[pre_step] == int_position) + for ca in conflicting_agent[0]: + if direction != self.predicted_dir[pre_step][ca] and tot_dist < potential_conflict: + potential_conflict = tot_dist + if self.env.dones[ca] and tot_dist < potential_conflict: + potential_conflict = tot_dist + + # Look for conflicting paths at distance num_step+1 + elif int_position in np.delete(self.predicted_pos[post_step], handle, 0): + conflicting_agent = np.where(self.predicted_pos[post_step] == int_position) + for ca in conflicting_agent[0]: + if direction != self.predicted_dir[post_step][ca] and tot_dist < potential_conflict: + potential_conflict = tot_dist + if self.env.dones[ca] and tot_dist < potential_conflict: + potential_conflict = tot_dist + + if position in self.location_has_target and position != agent.target: + if tot_dist < other_target_encountered: + other_target_encountered = tot_dist + + if position == agent.target and tot_dist < own_target_encountered: + own_target_encountered = tot_dist + + # ############################# + # ############################# + if (position[0], position[1], direction) in visited: + last_is_terminal = True + break + visited.add((position[0], position[1], direction)) + + # If the target node is encountered, pick that as node. Also, no further branching is possible. + if np.array_equal(position, self.env.agents[handle].target): + last_is_target = True + break + + # Check if crossing is found --> Not an unusable switch + if crossing_found: + # Treat the crossing as a straight rail cell + total_transitions = 2 + num_transitions = np.count_nonzero(cell_transitions) + + exploring = False + + # Detect Switches that can only be used by other agents. + if total_transitions > 2 > num_transitions and tot_dist < unusable_switch: + unusable_switch = tot_dist + + if num_transitions == 1: + # Check if dead-end, or if we can go forward along direction + nbits = total_transitions + if nbits == 1: + # Dead-end! + last_is_dead_end = True + + if not last_is_dead_end: + # Keep walking through the tree along `direction' + exploring = True + # convert one-hot encoding to 0,1,2,3 + direction = np.argmax(cell_transitions) + position = self._new_position(position, direction) + num_steps += 1 + tot_dist += 1 + elif num_transitions > 0: + # Switch detected + last_is_switch = True + break + + elif num_transitions == 0: + # Wrong cell type, but let's cover it and treat it as a dead-end, just in case + print("WRONG CELL TYPE detected in tree-search (0 transitions possible) at cell", position[0], + position[1], direction) + last_is_terminal = True + break + + # `position' is either a terminal node or a switch + + # ############################# + # ############################# + # Modify here to append new / different features for each visited cell! + + if last_is_target: + observation = [own_target_encountered, + other_target_encountered, + other_agent_encountered, + potential_conflict, + unusable_switch, + tot_dist, + 0, + other_agent_same_direction, + other_agent_opposite_direction + ] + + elif last_is_terminal: + observation = [own_target_encountered, + other_target_encountered, + other_agent_encountered, + potential_conflict, + unusable_switch, + np.inf, + self.distance_map[handle, position[0], position[1], direction], + other_agent_same_direction, + other_agent_opposite_direction + ] + + else: + observation = [own_target_encountered, + other_target_encountered, + other_agent_encountered, + potential_conflict, + unusable_switch, + tot_dist, + self.distance_map[handle, position[0], position[1], direction], + other_agent_same_direction, + other_agent_opposite_direction, + ] + # ############################# + # ############################# + # Start from the current orientation, and see which transitions are available; + # organize them as [left, forward, right, back], relative to the current orientation + # Get the possible transitions + possible_transitions = self.env.rail.get_transitions(*position, direction) + for branch_direction in [(direction + 4 + i) % 4 for i in range(-1, 3)]: + if last_is_dead_end and self.env.rail.get_transition((*position, direction), + (branch_direction + 2) % 4): + # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes + # it back + new_cell = self._new_position(position, (branch_direction + 2) % 4) + branch_observation, branch_visited = self._explore_branch(handle, + new_cell, + (branch_direction + 2) % 4, + tot_dist + 1, + depth + 1) + observation = observation + branch_observation + if len(branch_visited) != 0: + visited = visited.union(branch_visited) + elif last_is_switch and possible_transitions[branch_direction]: + new_cell = self._new_position(position, branch_direction) + branch_observation, branch_visited = self._explore_branch(handle, + new_cell, + branch_direction, + tot_dist + 1, + depth + 1) + observation = observation + branch_observation + if len(branch_visited) != 0: + visited = visited.union(branch_visited) + else: + # no exploring possible, add just cells with infinity + observation = observation + [-np.inf] * self._num_cells_to_fill_in(self.max_depth - depth) + + return observation, visited + + def util_print_obs_subtree(self, tree): + """ + Utility function to pretty-print tree observations returned by this object. + """ + pp = pprint.PrettyPrinter(indent=4) + pp.pprint(self.unfold_observation_tree(tree)) + + def unfold_observation_tree(self, tree, current_depth=0, actions_for_display=True): + """ + Utility function to pretty-print tree observations returned by this object. + """ + if len(tree) < self.observation_dim: + return + + depth = 0 + tmp = len(tree) / self.observation_dim - 1 + pow4 = 4 + while tmp > 0: + tmp -= pow4 + depth += 1 + pow4 *= 4 + + unfolded = {} + unfolded[''] = tree[0:self.observation_dim] + child_size = (len(tree) - self.observation_dim) // 4 + for child in range(4): + child_tree = tree[(self.observation_dim + child * child_size): + (self.observation_dim + (child + 1) * child_size)] + observation_tree = self.unfold_observation_tree(child_tree, current_depth=current_depth + 1) + if observation_tree is not None: + if actions_for_display: + label = self.tree_explorted_actions_char[child] + else: + label = self.tree_explored_actions[child] + unfolded[label] = observation_tree + return unfolded + + def _set_env(self, env): + self.env = env + if self.predictor: + self.predictor._set_env(self.env) diff --git a/torch_training/predictors/__init__.py b/torch_training/predictors/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/torch_training/predictors/predictions.py b/torch_training/predictors/predictions.py new file mode 100644 index 0000000..b3836ca --- /dev/null +++ b/torch_training/predictors/predictions.py @@ -0,0 +1,103 @@ +""" +Collection of environment-specific PredictionBuilder. +""" + +import numpy as np + +from flatland.core.env_prediction_builder import PredictionBuilder +from flatland.core.grid.grid4_utils import get_new_position +from flatland.envs.rail_env import RailEnvActions + + +class ShortestPathPredictorForRailEnv(PredictionBuilder): + """ + ShortestPathPredictorForRailEnv object. + + This object returns shortest-path predictions for agents in the RailEnv environment. + The prediction acts as if no other agent is in the environment and always takes the forward action. + """ + + def get(self, custom_args=None, handle=None): + """ + Called whenever get_many in the observation build is called. + Requires distance_map to extract the shortest path. + + Parameters + ------- + custom_args: dict + - distance_map : dict + handle : int (optional) + Handle of the agent for which to compute the observation vector. + + Returns + ------- + np.array + Returns a dictionary indexed by the agent handle and for each agent a vector of (max_depth + 1)x5 elements: + - time_offset + - position axis 0 + - position axis 1 + - direction + - action taken to come here + The prediction at 0 is the current position, direction etc. + """ + agents = self.env.agents + if handle: + agents = [self.env.agents[handle]] + assert custom_args is not None + distance_map = custom_args.get('distance_map') + assert distance_map is not None + + prediction_dict = {} + for agent in agents: + _agent_initial_position = agent.position + _agent_initial_direction = agent.direction + prediction = np.zeros(shape=(self.max_depth + 1, 5)) + prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0] + visited = set() + for index in range(1, self.max_depth + 1): + # if we're at the target, stop moving... + if agent.position == agent.target: + prediction[index] = [index, *agent.target, agent.direction, RailEnvActions.STOP_MOVING] + visited.add((agent.position[0], agent.position[1], agent.direction)) + continue + if not agent.moving: + prediction[index] = [index, *agent.position, agent.direction, RailEnvActions.STOP_MOVING] + visited.add((agent.position[0], agent.position[1], agent.direction)) + continue + # Take shortest possible path + cell_transitions = self.env.rail.get_transitions(*agent.position, agent.direction) + + new_position = None + new_direction = None + if np.sum(cell_transitions) == 1: + new_direction = np.argmax(cell_transitions) + new_position = get_new_position(agent.position, new_direction) + elif np.sum(cell_transitions) > 1: + min_dist = np.inf + no_dist_found = True + for direction in range(4): + if cell_transitions[direction] == 1: + neighbour_cell = get_new_position(agent.position, direction) + target_dist = distance_map[agent.handle, neighbour_cell[0], neighbour_cell[1], direction] + if target_dist < min_dist or no_dist_found: + min_dist = target_dist + new_direction = direction + no_dist_found = False + new_position = get_new_position(agent.position, new_direction) + else: + raise Exception("No transition possible {}".format(cell_transitions)) + + # update the agent's position and direction + agent.position = new_position + agent.direction = new_direction + + # prediction is ready + prediction[index] = [index, *new_position, new_direction, 0] + visited.add((new_position[0], new_position[1], new_direction)) + self.env.dev_pred_dict[agent.handle] = visited + prediction_dict[agent.handle] = prediction + + # cleanup: reset initial position + agent.position = _agent_initial_position + agent.direction = _agent_initial_direction + return prediction_dict diff --git a/utils/observation_utils.py b/utils/observation_utils.py index 26108cc..7891c28 100644 --- a/utils/observation_utils.py +++ b/utils/observation_utils.py @@ -97,7 +97,6 @@ def split_tree(tree, num_features_per_node, current_depth=0): agent_data.extend(tmp_agent_data) return tree_data, distance_data, agent_data - def normalize_observation(observation, num_features_per_node=9, observation_radius=0): data, distance, agent_data = split_tree(tree=np.array(observation), num_features_per_node=num_features_per_node, current_depth=0) -- GitLab