From 81513d7b1d15bda2c4f8cb463df7a792abad8018 Mon Sep 17 00:00:00 2001 From: u229589 <christian.baumberger@sbb.ch> Date: Thu, 26 Sep 2019 10:01:58 +0200 Subject: [PATCH] remove unused observation and prediction files --- scoring/score_test.py | 2 +- torch_training/dueling_double_dqn.py | 13 +- torch_training/model.py | 35 - torch_training/multi_agent_inference.py | 6 +- torch_training/multi_agent_training.py | 2 +- .../multi_agent_two_time_step_training.py | 7 +- .../observation_builders/observations.py | 865 ------------------ torch_training/predictors/predictions.py | 179 ---- torch_training/render_agent_behavior.py | 2 +- torch_training/training_navigation.py | 2 +- 10 files changed, 15 insertions(+), 1098 deletions(-) delete mode 100644 torch_training/observation_builders/observations.py delete mode 100644 torch_training/predictors/predictions.py diff --git a/scoring/score_test.py b/scoring/score_test.py index 5665d44..fa65f4c 100644 --- a/scoring/score_test.py +++ b/scoring/score_test.py @@ -28,7 +28,7 @@ test_dones = [] sequential_agent_test = False # Load your agent -agent = Agent(state_size, action_size, "FC", 0) +agent = Agent(state_size, action_size, 0) agent.qnetwork_local.load_state_dict(torch.load('../torch_training/Nets/avoid_checkpoint60000.pth')) # Load the necessary Observation Builder and Predictor diff --git a/torch_training/dueling_double_dqn.py b/torch_training/dueling_double_dqn.py index cf2f7d5..5dfa5c1 100644 --- a/torch_training/dueling_double_dqn.py +++ b/torch_training/dueling_double_dqn.py @@ -8,7 +8,7 @@ import torch import torch.nn.functional as F import torch.optim as optim -from torch_training.model import QNetwork, QNetwork2 +from torch_training.model import QNetwork BUFFER_SIZE = int(1e5) # replay buffer size BATCH_SIZE = 512 # minibatch size @@ -27,7 +27,7 @@ print(device) class Agent: """Interacts with and learns from the environment.""" - def __init__(self, state_size, action_size, net_type, seed, double_dqn=True, input_channels=5): + def __init__(self, state_size, action_size, seed, double_dqn=True, input_channels=5): """Initialize an Agent object. 
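# Illustrative sketch, not part of the patch: the update rule this Agent's
# learn() step applies when self.double_dqn is True. This is the standard
# double-DQN formulation; the tensor names and the gamma default here are
# assumptions, not values taken from the file.
import torch

def double_dqn_targets(qnetwork_local, qnetwork_target, rewards, next_states, dones, gamma=0.99):
    # action selection with the online net, value estimation with the target net
    best_actions = qnetwork_local(next_states).detach().argmax(1, keepdim=True)
    q_next = qnetwork_target(next_states).detach().gather(1, best_actions)
    return rewards + gamma * q_next * (1 - dones)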
Params @@ -39,15 +39,10 @@ class Agent: self.state_size = state_size self.action_size = action_size self.seed = random.seed(seed) - self.version = net_type self.double_dqn = double_dqn # Q-Network - if self.version == "Conv": - self.qnetwork_local = QNetwork2(state_size, action_size, seed, input_channels).to(device) - self.qnetwork_target = copy.deepcopy(self.qnetwork_local) - else: - self.qnetwork_local = QNetwork(state_size, action_size, seed).to(device) - self.qnetwork_target = copy.deepcopy(self.qnetwork_local) + self.qnetwork_local = QNetwork(state_size, action_size, seed).to(device) + self.qnetwork_target = copy.deepcopy(self.qnetwork_local) self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=LR) diff --git a/torch_training/model.py b/torch_training/model.py index 7a5b3d6..f7bc3d5 100644 --- a/torch_training/model.py +++ b/torch_training/model.py @@ -24,38 +24,3 @@ class QNetwork(nn.Module): adv = F.relu(self.fc2_adv(adv)) adv = self.fc3_adv(adv) return val + adv - adv.mean() - - -class QNetwork2(nn.Module): - def __init__(self, state_size, action_size, seed, input_channels, hidsize1=128, hidsize2=64): - super(QNetwork2, self).__init__() - self.conv1 = nn.Conv2d(input_channels, 16, kernel_size=3, stride=1) - self.bn1 = nn.BatchNorm2d(16) - self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=3) - self.bn2 = nn.BatchNorm2d(32) - self.conv3 = nn.Conv2d(32, 64, kernel_size=5, stride=3) - self.bn3 = nn.BatchNorm2d(64) - - self.fc1_val = nn.Linear(6400, hidsize1) - self.fc2_val = nn.Linear(hidsize1, hidsize2) - self.fc3_val = nn.Linear(hidsize2, 1) - - self.fc1_adv = nn.Linear(6400, hidsize1) - self.fc2_adv = nn.Linear(hidsize1, hidsize2) - self.fc3_adv = nn.Linear(hidsize2, action_size) - - def forward(self, x): - x = F.relu(self.conv1(x)) - x = F.relu(self.conv2(x)) - x = F.relu(self.conv3(x)) - - # value function approximation - val = F.relu(self.fc1_val(x.view(x.size(0), -1))) - val = F.relu(self.fc2_val(val)) - val = self.fc3_val(val) - - # advantage calculation - adv = F.relu(self.fc1_adv(x.view(x.size(0), -1))) - adv = F.relu(self.fc2_adv(adv)) - adv = self.fc3_adv(adv) - return val + adv - adv.mean() diff --git a/torch_training/multi_agent_inference.py b/torch_training/multi_agent_inference.py index 2fbbe61..96c8ee7 100644 --- a/torch_training/multi_agent_inference.py +++ b/torch_training/multi_agent_inference.py @@ -4,8 +4,8 @@ from collections import deque import numpy as np import torch from importlib_resources import path -from observation_builders.observations import TreeObsForRailEnv -from predictors.predictions import ShortestPathPredictorForRailEnv +from flatland.envs.observations import TreeObsForRailEnv +from flatland.envs.predictions import ShortestPathPredictorForRailEnv import torch_training.Nets from flatland.envs.rail_env import RailEnv @@ -87,7 +87,7 @@ dones_list = [] action_prob = [0] * action_size agent_obs = [None] * env.get_num_agents() agent_next_obs = [None] * env.get_num_agents() -agent = Agent(state_size, action_size, "FC", 0) +agent = Agent(state_size, action_size, 0) with path(torch_training.Nets, "avoid_checkpoint100.pth") as file_in: agent.qnetwork_local.load_state_dict(torch.load(file_in)) diff --git a/torch_training/multi_agent_training.py b/torch_training/multi_agent_training.py index 222430d..6747460 100644 --- a/torch_training/multi_agent_training.py +++ b/torch_training/multi_agent_training.py @@ -121,7 +121,7 @@ def main(argv): observation_radius = 10 # Initialize the agent - agent = Agent(state_size, action_size, "FC", 0) + agent = 
Agent(state_size, action_size, 0) # Here you can pre-load an agent if False: diff --git a/torch_training/multi_agent_two_time_step_training.py b/torch_training/multi_agent_two_time_step_training.py index 08cd84c..2687cd5 100644 --- a/torch_training/multi_agent_two_time_step_training.py +++ b/torch_training/multi_agent_two_time_step_training.py @@ -43,10 +43,10 @@ def main(argv): min_dist = int(0.75 * min(x_dim, y_dim)) tree_depth = 3 print("main2") + demo = False # Get an observation builder and predictor - predictor = ShortestPathPredictorForRailEnv() - observation_helper = TreeObsForRailEnv(max_depth=tree_depth, predictor=predictor()) + observation_helper = TreeObsForRailEnv(max_depth=tree_depth, predictor=ShortestPathPredictorForRailEnv()) env = RailEnv(width=x_dim, height=y_dim, @@ -87,7 +87,7 @@ def main(argv): agent_obs = [None] * env.get_num_agents() agent_next_obs = [None] * env.get_num_agents() # Initialize the agent - agent = Agent(state_size, action_size, "FC", 0) + agent = Agent(state_size, action_size, 0) # Here you can pre-load an agent if False: @@ -132,6 +132,7 @@ def main(argv): # Build agent specific observations for a in range(env.get_num_agents()): data, distance, agent_data = split_tree(tree=np.array(obs[a]), + num_features_per_node=11, current_depth=0) data = norm_obs_clip(data) distance = norm_obs_clip(distance) diff --git a/torch_training/observation_builders/observations.py b/torch_training/observation_builders/observations.py deleted file mode 100644 index 66c3830..0000000 --- a/torch_training/observation_builders/observations.py +++ /dev/null @@ -1,865 +0,0 @@ -""" -Collection of environment-specific ObservationBuilder. -""" -import pprint -from collections import deque - -import numpy as np - -from flatland.core.env_observation_builder import ObservationBuilder -from flatland.core.grid.grid4_utils import get_new_position -from flatland.core.grid.grid_utils import coordinate_to_position - - -class TreeObsForRailEnv(ObservationBuilder): - """ - TreeObsForRailEnv object. - - This object returns observation vectors for agents in the RailEnv environment. - The information is local to each agent and exploits the graph structure of the rail - network to simplify the representation of the state of the environment for each agent. - - For details about the features in the tree observation see the get() function. 
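# Sketch of the per-agent preprocessing used by the training hunk above,
# assuming the repo's utils.observation_utils helpers (the import path is an
# assumption). The calls mirror the hunk for
# multi_agent_two_time_step_training.py, where split_tree now receives
# num_features_per_node=11 to match the tree builder's 11 features per node.
# (The same hunk also stops calling predictor() on an already-constructed
# ShortestPathPredictorForRailEnv instance.)
import numpy as np
from utils.observation_utils import norm_obs_clip, split_tree  # assumed helper module

def preprocess_tree_obs(raw_obs):
    data, distance, agent_data = split_tree(tree=np.array(raw_obs),
                                            num_features_per_node=11,
                                            current_depth=0)
    # normalize/clip each block, then concatenate into the network input
    return np.concatenate((norm_obs_clip(data),
                           norm_obs_clip(distance),
                           np.clip(agent_data, -1, 1)))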
- """ - - def __init__(self, max_depth, predictor=None): - super().__init__() - self.max_depth = max_depth - self.observation_dim = 11 - # Compute the size of the returned observation vector - size = 0 - pow4 = 1 - for i in range(self.max_depth + 1): - size += pow4 - pow4 *= 4 - self.observation_space = [size * self.observation_dim] - self.location_has_agent = {} - self.location_has_agent_direction = {} - self.predictor = predictor - self.agents_previous_reset = None - self.tree_explored_actions = [1, 2, 3, 0] - self.tree_explorted_actions_char = ['L', 'F', 'R', 'B'] - self.distance_map = None - self.distance_map_computed = False - - def reset(self): - agents = self.env.agents - nb_agents = len(agents) - compute_distance_map = True - if self.agents_previous_reset is not None and nb_agents == len(self.agents_previous_reset): - compute_distance_map = False - for i in range(nb_agents): - if agents[i].target != self.agents_previous_reset[i].target: - compute_distance_map = True - # Don't compute the distance map if it was loaded - if self.agents_previous_reset is None and self.distance_map is not None: - self.location_has_target = {tuple(agent.target): 1 for agent in agents} - compute_distance_map = False - - if compute_distance_map: - self._compute_distance_map() - - self.agents_previous_reset = agents - - def _compute_distance_map(self): - agents = self.env.agents - # For testing only --> To assert if a distance map need to be recomputed. - self.distance_map_computed = True - nb_agents = len(agents) - self.distance_map = np.inf * np.ones(shape=(nb_agents, - self.env.height, - self.env.width, - 4)) - self.max_dist = np.zeros(nb_agents) - self.max_dist = [self._distance_map_walker(agent.target, i) for i, agent in enumerate(agents)] - # Update local lookup table for all agents' target locations - self.location_has_target = {tuple(agent.target): 1 for agent in agents} - - def _distance_map_walker(self, position, target_nr): - """ - Utility function to compute distance maps from each cell in the rail network (and each possible - orientation within it) to each agent's target cell. 
- """ - # Returns max distance to target, from the farthest away node, while filling in distance_map - self.distance_map[target_nr, position[0], position[1], :] = 0 - - # Fill in the (up to) 4 neighboring nodes - # direction is the direction of movement, meaning that at least a possible orientation of an agent - # in cell (row,col) allows a movement in direction `direction` - nodes_queue = deque(self._get_and_update_neighbors(position, target_nr, 0, enforce_target_direction=-1)) - - # BFS from target `position` to all the reachable nodes in the grid - # Stop the search if the target position is re-visited, in any direction - visited = {(position[0], position[1], 0), (position[0], position[1], 1), (position[0], position[1], 2), - (position[0], position[1], 3)} - - max_distance = 0 - - while nodes_queue: - node = nodes_queue.popleft() - - node_id = (node[0], node[1], node[2]) - - if node_id not in visited: - visited.add(node_id) - - # From the list of possible neighbors that have at least a path to the current node, only keep those - # whose new orientation in the current cell would allow a transition to direction node[2] - valid_neighbors = self._get_and_update_neighbors((node[0], node[1]), target_nr, node[3], node[2]) - - for n in valid_neighbors: - nodes_queue.append(n) - - if len(valid_neighbors) > 0: - max_distance = max(max_distance, node[3] + 1) - - return max_distance - - def _get_and_update_neighbors(self, position, target_nr, current_distance, enforce_target_direction=-1): - """ - Utility function used by _distance_map_walker to perform a BFS walk over the rail, filling in the - minimum distances from each target cell. - """ - neighbors = [] - - possible_directions = [0, 1, 2, 3] - if enforce_target_direction >= 0: - # The agent must land into the current cell with orientation `enforce_target_direction`. - # This is only possible if the agent has arrived from the cell in the opposite direction! - possible_directions = [(enforce_target_direction + 2) % 4] - - for neigh_direction in possible_directions: - new_cell = get_new_position(position, neigh_direction) - - if new_cell[0] >= 0 and new_cell[0] < self.env.height and new_cell[1] >= 0 and new_cell[1] < self.env.width: - - desired_movement_from_new_cell = (neigh_direction + 2) % 4 - - # Check all possible transitions in new_cell - for agent_orientation in range(4): - # Is a transition along movement `desired_movement_from_new_cell` to the current cell possible? - is_valid = self.env.rail.get_transition((new_cell[0], new_cell[1], agent_orientation), - desired_movement_from_new_cell) - - if is_valid: - """ - # TODO: check that it works with deadends! -- still bugged! - movement = desired_movement_from_new_cell - if isNextCellDeadEnd: - movement = (desired_movement_from_new_cell+2) % 4 - """ - new_distance = min(self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation], - current_distance + 1) - neighbors.append((new_cell[0], new_cell[1], agent_orientation, new_distance)) - self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation] = new_distance - - return neighbors - - def get_many(self, handles=None): - """ - Called whenever an observation has to be computed for the `env` environment, for each agent with handle - in the `handles` list. 
- """ - - if handles is None: - handles = [] - if self.predictor: - self.max_prediction_depth = 0 - self.predicted_pos = {} - self.predicted_dir = {} - self.predictions = self.predictor.get(custom_args={'distance_map': self.distance_map}) - if self.predictions: - - for t in range(len(self.predictions[0])): - pos_list = [] - dir_list = [] - for a in handles: - pos_list.append(self.predictions[a][t][1:3]) - dir_list.append(self.predictions[a][t][3]) - self.predicted_pos.update({t: coordinate_to_position(self.env.width, pos_list)}) - self.predicted_dir.update({t: dir_list}) - self.max_prediction_depth = len(self.predicted_pos) - observations = {} - for h in handles: - observations[h] = self.get(h) - return observations - - def get(self, handle): - """ - Computes the current observation for agent `handle` in env - - The observation vector is composed of 4 sequential parts, corresponding to data from the up to 4 possible - movements in a RailEnv (up to because only a subset of possible transitions are allowed in RailEnv). - The possible movements are sorted relative to the current orientation of the agent, rather than NESW as for - the transitions. The order is: - [data from 'left'] + [data from 'forward'] + [data from 'right'] + [data from 'back'] - - Each branch data is organized as: - [root node information] + - [recursive branch data from 'left'] + - [... from 'forward'] + - [... from 'right] + - [... from 'back'] - - Each node information is composed of 9 features: - - #1: if own target lies on the explored branch the current distance from the agent in number of cells is stored. - - #2: if another agents target is detected the distance in number of cells from the agents current location - is stored - - #3: if another agent is detected the distance in number of cells from current agent position is stored. - - #4: possible conflict detected - tot_dist = Other agent predicts to pass along this cell at the same time as the agent, we store the - distance in number of cells from current agent position - - 0 = No other agent reserve the same cell at similar time - - #5: if an not usable switch (for agent) is detected we store the distance. - - #6: This feature stores the distance in number of cells to the next branching (current node) - - #7: minimum distance from node to the agent's target given the direction of the agent if this path is chosen - - #8: agent in the same direction - n = number of agents present same direction - (possible future use: number of other agents in the same direction in this branch) - 0 = no agent present same direction - - #9: agent in the opposite direction - n = number of agents present other direction than myself (so conflict) - (possible future use: number of other agents in other direction in this branch, ie. number of conflicts) - 0 = no agent present other direction than myself - - #10: malfunctioning/blokcing agents - n = number of time steps the oberved agent remains blocked - - #11: slowest observed speed of an agent in same direction - 1 if no agent is observed - - min_fractional speed otherwise - - - - - - Missing/padding nodes are filled in with -inf (truncated). - Missing values in present node are filled in with +inf (truncated). - - - In case of the root node, the values are [0, 0, 0, 0, distance from agent to target, own malfunction, own speed] - In case the target node is reached, the values are [0, 0, 0, 0, 0]. 
- """ - - # Update local lookup table for all agents' positions - self.location_has_agent = {tuple(agent.position): 1 for agent in self.env.agents} - self.location_has_agent_direction = {tuple(agent.position): agent.direction for agent in self.env.agents} - self.location_has_agent_speed = {tuple(agent.position): agent.speed_data['speed'] for agent in self.env.agents} - self.location_has_agent_malfunction = {tuple(agent.position): agent.malfunction_data['malfunction'] for agent in - self.env.agents} - - if handle > len(self.env.agents): - print("ERROR: obs _get - handle ", handle, " len(agents)", len(self.env.agents)) - agent = self.env.agents[handle] # TODO: handle being treated as index - possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction) - num_transitions = np.count_nonzero(possible_transitions) - - # Root node - current position - # Here information about the agent itself is stored - observation = [0, 0, 0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], 0, 0, - agent.malfunction_data['malfunction'], agent.speed_data['speed']] - - visited = set() - - # Start from the current orientation, and see which transitions are available; - # organize them as [left, forward, right, back], relative to the current orientation - # If only one transition is possible, the tree is oriented with this transition as the forward branch. - orientation = agent.direction - - if num_transitions == 1: - orientation = np.argmax(possible_transitions) - - for branch_direction in [(orientation + i) % 4 for i in range(-1, 3)]: - if possible_transitions[branch_direction]: - new_cell = get_new_position(agent.position, branch_direction) - branch_observation, branch_visited = \ - self._explore_branch(handle, new_cell, branch_direction, 1, 1) - observation = observation + branch_observation - visited = visited.union(branch_visited) - else: - # add cells filled with infinity if no transition is possible - observation = observation + [-np.inf] * self._num_cells_to_fill_in(self.max_depth) - self.env.dev_obs_dict[handle] = visited - - return observation - - def _num_cells_to_fill_in(self, remaining_depth): - """Computes the length of observation vector: sum_{i=0,depth-1} 2^i * observation_dim.""" - num_observations = 0 - pow4 = 1 - for i in range(remaining_depth): - num_observations += pow4 - pow4 *= 4 - return num_observations * self.observation_dim - - def _explore_branch(self, handle, position, direction, tot_dist, depth): - """ - Utility function to compute tree-based observations. - We walk along the branch and collect the information documented in the get() function. - If there is a branching point a new node is created and each possible branch is explored. 
- """ - - # [Recursive branch opened] - if depth >= self.max_depth + 1: - return [], [] - - # Continue along direction until next switch or - # until no transitions are possible along the current direction (i.e., dead-ends) - # We treat dead-ends as nodes, instead of going back, to avoid loops - exploring = True - last_is_switch = False - last_is_dead_end = False - last_is_terminal = False # wrong cell OR cycle; either way, we don't want the agent to land here - last_is_target = False - - visited = set() - agent = self.env.agents[handle] - time_per_cell = np.reciprocal(agent.speed_data["speed"]) - own_target_encountered = np.inf - other_agent_encountered = np.inf - other_target_encountered = np.inf - potential_conflict = np.inf - unusable_switch = np.inf - other_agent_same_direction = 0 - other_agent_opposite_direction = 0 - malfunctioning_agent = 0. - min_fractional_speed = 1. - num_steps = 1 - while exploring: - # ############################# - # ############################# - # Modify here to compute any useful data required to build the end node's features. This code is called - # for each cell visited between the previous branching node and the next switch / target / dead-end. - if position in self.location_has_agent: - if tot_dist < other_agent_encountered: - other_agent_encountered = tot_dist - - # Check if any of the observed agents is malfunctioning, store agent with longest duration left - if self.location_has_agent_malfunction[position] > malfunctioning_agent: - malfunctioning_agent = self.location_has_agent_malfunction[position] - - if self.location_has_agent_direction[position] == direction: - # Cummulate the number of agents on branch with same direction - other_agent_same_direction += 1 - - # Check fractional speed of agents - current_fractional_speed = self.location_has_agent_speed[position] - if current_fractional_speed < min_fractional_speed: - min_fractional_speed = current_fractional_speed - - if self.location_has_agent_direction[position] != direction: - # Cummulate the number of agents on branch with other direction - other_agent_opposite_direction += 1 - - # Check number of possible transitions for agent and total number of transitions in cell (type) - cell_transitions = self.env.rail.get_transitions(*position, direction) - transition_bit = bin(self.env.rail.get_full_transitions(*position)) - total_transitions = transition_bit.count("1") - crossing_found = False - if int(transition_bit, 2) == int('1000010000100001', 2): - crossing_found = True - - # Register possible future conflict - predicted_time = int(tot_dist * time_per_cell) - if self.predictor and predicted_time < self.max_prediction_depth: - int_position = coordinate_to_position(self.env.width, [position]) - if tot_dist < self.max_prediction_depth: - - pre_step = max(0, predicted_time - 1) - post_step = min(self.max_prediction_depth - 1, predicted_time + 1) - - # Look for conflicting paths at distance tot_dist - if int_position in np.delete(self.predicted_pos[predicted_time], handle, 0): - conflicting_agent = np.where(self.predicted_pos[predicted_time] == int_position) - for ca in conflicting_agent[0]: - if direction != self.predicted_dir[predicted_time][ca] and cell_transitions[ - self._reverse_dir( - self.predicted_dir[predicted_time][ca])] == 1 and tot_dist < potential_conflict: - potential_conflict = tot_dist - if self.env.dones[ca] and tot_dist < potential_conflict: - potential_conflict = tot_dist - - # Look for conflicting paths at distance num_step-1 - elif int_position in 
np.delete(self.predicted_pos[pre_step], handle, 0): - conflicting_agent = np.where(self.predicted_pos[pre_step] == int_position) - for ca in conflicting_agent[0]: - if direction != self.predicted_dir[pre_step][ca] \ - and cell_transitions[self._reverse_dir(self.predicted_dir[pre_step][ca])] == 1 \ - and tot_dist < potential_conflict: # noqa: E125 - potential_conflict = tot_dist - if self.env.dones[ca] and tot_dist < potential_conflict: - potential_conflict = tot_dist - - # Look for conflicting paths at distance num_step+1 - elif int_position in np.delete(self.predicted_pos[post_step], handle, 0): - conflicting_agent = np.where(self.predicted_pos[post_step] == int_position) - for ca in conflicting_agent[0]: - if direction != self.predicted_dir[post_step][ca] and cell_transitions[self._reverse_dir( - self.predicted_dir[post_step][ca])] == 1 \ - and tot_dist < potential_conflict: # noqa: E125 - potential_conflict = tot_dist - if self.env.dones[ca] and tot_dist < potential_conflict: - potential_conflict = tot_dist - - if position in self.location_has_target and position != agent.target: - if tot_dist < other_target_encountered: - other_target_encountered = tot_dist - - if position == agent.target and tot_dist < own_target_encountered: - own_target_encountered = tot_dist - - # ############################# - # ############################# - if (position[0], position[1], direction) in visited: - last_is_terminal = True - break - visited.add((position[0], position[1], direction)) - - # If the target node is encountered, pick that as node. Also, no further branching is possible. - if np.array_equal(position, self.env.agents[handle].target): - last_is_target = True - break - - # Check if crossing is found --> Not an unusable switch - if crossing_found: - # Treat the crossing as a straight rail cell - total_transitions = 2 - num_transitions = np.count_nonzero(cell_transitions) - - exploring = False - - # Detect Switches that can only be used by other agents. - if total_transitions > 2 > num_transitions and tot_dist < unusable_switch: - unusable_switch = tot_dist - - if num_transitions == 1: - # Check if dead-end, or if we can go forward along direction - nbits = total_transitions - if nbits == 1: - # Dead-end! - last_is_dead_end = True - - if not last_is_dead_end: - # Keep walking through the tree along `direction` - exploring = True - # convert one-hot encoding to 0,1,2,3 - direction = np.argmax(cell_transitions) - position = get_new_position(position, direction) - num_steps += 1 - tot_dist += 1 - elif num_transitions > 0: - # Switch detected - last_is_switch = True - break - - elif num_transitions == 0: - # Wrong cell type, but let's cover it and treat it as a dead-end, just in case - print("WRONG CELL TYPE detected in tree-search (0 transitions possible) at cell", position[0], - position[1], direction) - last_is_terminal = True - break - - # `position` is either a terminal node or a switch - - # ############################# - # ############################# - # Modify here to append new / different features for each visited cell! 
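# Simplified sketch of the conflict registration above. The real code also
# requires that the opposing agent could actually enter our cell
# (cell_transitions[self._reverse_dir(...)] == 1) and clamps the pre/post
# steps to the prediction horizon. Here predicted_pos maps a time step to one
# flattened cell id per agent, and predicted_dir to one heading per agent.
import numpy as np

def conflict_predicted(my_cell, my_dir, predicted_pos, predicted_dir, t, handle):
    for step in (t - 1, t, t + 1):
        if step not in predicted_pos:
            continue
        if my_cell in np.delete(predicted_pos[step], handle, 0):
            opponents = np.where(predicted_pos[step] == my_cell)[0]
            if any(predicted_dir[step][ca] != my_dir for ca in opponents):
                return True
    return False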
- - if last_is_target: - observation = [own_target_encountered, - other_target_encountered, - other_agent_encountered, - potential_conflict, - unusable_switch, - tot_dist, - 0, - other_agent_same_direction, - other_agent_opposite_direction, - malfunctioning_agent, - min_fractional_speed - ] - - elif last_is_terminal: - observation = [own_target_encountered, - other_target_encountered, - other_agent_encountered, - potential_conflict, - unusable_switch, - np.inf, - self.distance_map[handle, position[0], position[1], direction], - other_agent_same_direction, - other_agent_opposite_direction, - malfunctioning_agent, - min_fractional_speed - ] - - else: - observation = [own_target_encountered, - other_target_encountered, - other_agent_encountered, - potential_conflict, - unusable_switch, - tot_dist, - self.distance_map[handle, position[0], position[1], direction], - other_agent_same_direction, - other_agent_opposite_direction, - malfunctioning_agent, - min_fractional_speed - ] - # ############################# - # ############################# - # Start from the current orientation, and see which transitions are available; - # organize them as [left, forward, right, back], relative to the current orientation - # Get the possible transitions - possible_transitions = self.env.rail.get_transitions(*position, direction) - for branch_direction in [(direction + 4 + i) % 4 for i in range(-1, 3)]: - if last_is_dead_end and self.env.rail.get_transition((*position, direction), - (branch_direction + 2) % 4): - # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes - # it back - new_cell = get_new_position(position, (branch_direction + 2) % 4) - branch_observation, branch_visited = self._explore_branch(handle, - new_cell, - (branch_direction + 2) % 4, - tot_dist + 1, - depth + 1) - observation = observation + branch_observation - if len(branch_visited) != 0: - visited = visited.union(branch_visited) - elif last_is_switch and possible_transitions[branch_direction]: - new_cell = get_new_position(position, branch_direction) - branch_observation, branch_visited = self._explore_branch(handle, - new_cell, - branch_direction, - tot_dist + 1, - depth + 1) - observation = observation + branch_observation - if len(branch_visited) != 0: - visited = visited.union(branch_visited) - else: - # no exploring possible, add just cells with infinity - observation = observation + [-np.inf] * self._num_cells_to_fill_in(self.max_depth - depth) - - return observation, visited - - def util_print_obs_subtree(self, tree): - """ - Utility function to pretty-print tree observations returned by this object. - """ - pp = pprint.PrettyPrinter(indent=4) - pp.pprint(self.unfold_observation_tree(tree)) - - def unfold_observation_tree(self, tree, current_depth=0, actions_for_display=True): - """ - Utility function to pretty-print tree observations returned by this object. 
- """ - if len(tree) < self.observation_dim: - return - - depth = 0 - tmp = len(tree) / self.observation_dim - 1 - pow4 = 4 - while tmp > 0: - tmp -= pow4 - depth += 1 - pow4 *= 4 - - unfolded = {} - unfolded[''] = tree[0:self.observation_dim] - child_size = (len(tree) - self.observation_dim) // 4 - for child in range(4): - child_tree = tree[(self.observation_dim + child * child_size): - (self.observation_dim + (child + 1) * child_size)] - observation_tree = self.unfold_observation_tree(child_tree, current_depth=current_depth + 1) - if observation_tree is not None: - if actions_for_display: - label = self.tree_explorted_actions_char[child] - else: - label = self.tree_explored_actions[child] - unfolded[label] = observation_tree - return unfolded - - def _set_env(self, env): - self.env = env - if self.predictor: - self.predictor._set_env(self.env) - - def _reverse_dir(self, direction): - return int((direction + 2) % 4) - - -class GlobalObsForRailEnv(ObservationBuilder): - """ - Gives a global observation of the entire rail environment. - The observation is composed of the following elements: - - - transition map array with dimensions (env.height, env.width, 16), - assuming 16 bits encoding of transitions. - - - Two 2D arrays (map_height, map_width, 2) containing respectively the position of the given agent - target and the positions of the other agents targets. - - - A 3D array (map_height, map_width, 8) with the 4 first channels containing the one hot encoding - of the direction of the given agent and the 4 second channels containing the positions - of the other agents at their position coordinates. - """ - - def __init__(self): - self.observation_space = () - super(GlobalObsForRailEnv, self).__init__() - - def _set_env(self, env): - super()._set_env(env) - - self.observation_space = [4, self.env.height, self.env.width] - - def reset(self): - self.rail_obs = np.zeros((self.env.height, self.env.width, 16)) - for i in range(self.rail_obs.shape[0]): - for j in range(self.rail_obs.shape[1]): - bitlist = [int(digit) for digit in bin(self.env.rail.get_full_transitions(i, j))[2:]] - bitlist = [0] * (16 - len(bitlist)) + bitlist - self.rail_obs[i, j] = np.array(bitlist) - - def get(self, handle): - obs_targets = np.zeros((self.env.height, self.env.width, 2)) - obs_agents_state = np.zeros((self.env.height, self.env.width, 8)) - agents = self.env.agents - agent = agents[handle] - - direction = np.zeros(4) - direction[agent.direction] = 1 - agent_pos = agents[handle].position - obs_agents_state[agent_pos][:4] = direction - obs_targets[agent.target][0] += 1 - - for i in range(len(agents)): - if i != handle: # TODO: handle used as index...? - agent2 = agents[i] - obs_agents_state[agent2.position][4 + agent2.direction] = 1 - obs_targets[agent2.target][1] += 1 - - direction = self._get_one_hot_for_agent_direction(agent) - - return self.rail_obs, obs_agents_state, obs_targets, direction - - -class GlobalObsForRailEnvDirectionDependent(ObservationBuilder): - """ - Gives a global observation of the entire rail environment. - The observation is composed of the following elements: - - - transition map array with dimensions (env.height, env.width, 16), - assuming 16 bits encoding of transitions, flipped in the direction of the agent - (the agent is always heading north on the flipped view). - - - Two 2D arrays (map_height, map_width, 2) containing respectively the position of the given agent - target and the positions of the other agents targets, also flipped depending on the agent's direction. 
- - - A 3D array (map_height, map_width, 5) containing the one hot encoding of the direction of the other - agents at their position coordinates, and the last channel containing the position of the given agent. - - - A 4 elements array with one hot encoding of the direction. - """ - - def __init__(self): - self.observation_space = () - super(GlobalObsForRailEnvDirectionDependent, self).__init__() - - def _set_env(self, env): - super()._set_env(env) - - self.observation_space = [4, self.env.height, self.env.width] - - def reset(self): - self.rail_obs = np.zeros((self.env.height, self.env.width, 16)) - for i in range(self.rail_obs.shape[0]): - for j in range(self.rail_obs.shape[1]): - bitlist = [int(digit) for digit in bin(self.env.rail.get_full_transitions(i, j))[2:]] - bitlist = [0] * (16 - len(bitlist)) + bitlist - self.rail_obs[i, j] = np.array(bitlist) - - def get(self, handle): - obs_targets = np.zeros((self.env.height, self.env.width, 2)) - obs_agents_state = np.zeros((self.env.height, self.env.width, 5)) - agents = self.env.agents - agent = agents[handle] - direction = agent.direction - - idx = np.tile(np.arange(16), 2) - - rail_obs = self.rail_obs[:, :, idx[direction * 4: direction * 4 + 16]] - - if direction == 1: - rail_obs = np.flip(rail_obs, axis=1) - elif direction == 2: - rail_obs = np.flip(rail_obs) - elif direction == 3: - rail_obs = np.flip(rail_obs, axis=0) - - agent_pos = agents[handle].position - obs_agents_state[agent_pos][0] = 1 - obs_targets[agent.target][0] += 1 - - idx = np.tile(np.arange(4), 2) - for i in range(len(agents)): - if i != handle: # TODO: handle used as index...? - agent2 = agents[i] - obs_agents_state[agent2.position][1 + idx[4 + (agent2.direction - direction)]] = 1 - obs_targets[agent2.target][1] += 1 - - direction = self._get_one_hot_for_agent_direction(agent) - - return rail_obs, obs_agents_state, obs_targets, direction - - -class LocalObsForRailEnv(ObservationBuilder): - """ - Gives a local observation of the rail environment around the agent. - The observation is composed of the following elements: - - - transition map array of the local environment around the given agent, - with dimensions (view_height,2*view_width+1, 16), - assuming 16 bits encoding of transitions. - - - Two 2D arrays (view_height,2*view_width+1, 2) containing respectively, - if they are in the agent's vision range, its target position, the positions of the other targets. - - - A 2D array (view_height,2*view_width+1, 4) containing the one hot encoding of directions - of the other agents at their position coordinates, if they are in the agent's vision range. - - - A 4 elements array with one hot encoding of the direction. - - Use the parameters view_width and view_height to define the rectangular view of the agent. - The center parameters moves the agent along the height axis of this rectangle. If it is 0 the agent only has - observation in front of it. - """ - - def __init__(self, view_width, view_height, center): - - super(LocalObsForRailEnv, self).__init__() - self.view_width = view_width - self.view_height = view_height - self.center = center - self.max_padding = max(self.view_width, self.view_height - self.center) - - def reset(self): - # We build the transition map with a view_radius empty cells expansion on each side. - # This helps to collect the local transition map view when the agent is close to a border. 
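# The 16-bit transition decoding repeated in these builders' reset() methods,
# extracted for clarity: each cell's transition code becomes a fixed-length
# 0/1 vector, most significant bit first.
def transitions_to_bits(code):
    bits = [int(digit) for digit in bin(code)[2:]]
    return [0] * (16 - len(bits)) + bits

assert len(transitions_to_bits(0)) == 16
assert transitions_to_bits(int('1000010000100001', 2))[0] == 1   # the crossing cell type checked earlier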
- self.max_padding = max(self.view_width, self.view_height) - self.rail_obs = np.zeros((self.env.height, - self.env.width, 16)) - for i in range(self.env.height): - for j in range(self.env.width): - bitlist = [int(digit) for digit in bin(self.env.rail.get_full_transitions(i, j))[2:]] - bitlist = [0] * (16 - len(bitlist)) + bitlist - self.rail_obs[i, j] = np.array(bitlist) - - def get(self, handle): - agents = self.env.agents - agent = agents[handle] - - # Correct agents position for padding - # agent_rel_pos[0] = agent.position[0] + self.max_padding - # agent_rel_pos[1] = agent.position[1] + self.max_padding - - # Collect visible cells as set to be plotted - visited, rel_coords = self.field_of_view(agent.position, agent.direction, ) - local_rail_obs = None - - # Add the visible cells to the observed cells - self.env.dev_obs_dict[handle] = set(visited) - - # Locate observed agents and their coresponding targets - local_rail_obs = np.zeros((self.view_height, 2 * self.view_width + 1, 16)) - obs_map_state = np.zeros((self.view_height, 2 * self.view_width + 1, 2)) - obs_other_agents_state = np.zeros((self.view_height, 2 * self.view_width + 1, 4)) - _idx = 0 - for pos in visited: - curr_rel_coord = rel_coords[_idx] - local_rail_obs[curr_rel_coord[0], curr_rel_coord[1], :] = self.rail_obs[pos[0], pos[1], :] - if pos == agent.target: - obs_map_state[curr_rel_coord[0], curr_rel_coord[1], 0] = 1 - else: - for tmp_agent in agents: - if pos == tmp_agent.target: - obs_map_state[curr_rel_coord[0], curr_rel_coord[1], 1] = 1 - if pos != agent.position: - for tmp_agent in agents: - if pos == tmp_agent.position: - obs_other_agents_state[curr_rel_coord[0], curr_rel_coord[1], :] = np.identity(4)[ - tmp_agent.direction] - - _idx += 1 - - direction = np.identity(4)[agent.direction] - return local_rail_obs, obs_map_state, obs_other_agents_state, direction - - def get_many(self, handles=None): - """ - Called whenever an observation has to be computed for the `env' environment, for each agent with handle - in the `handles' list. 
- """ - - observations = {} - for h in handles: - observations[h] = self.get(h) - return observations - - def field_of_view(self, position, direction, state=None): - # Compute the local field of view for an agent in the environment - data_collection = False - if state is not None: - temp_visible_data = np.zeros(shape=(self.view_height, 2 * self.view_width + 1, 16)) - data_collection = True - if direction == 0: - origin = (position[0] + self.center, position[1] - self.view_width) - elif direction == 1: - origin = (position[0] - self.view_width, position[1] - self.center) - elif direction == 2: - origin = (position[0] - self.center, position[1] + self.view_width) - else: - origin = (position[0] + self.view_width, position[1] + self.center) - visible = list() - rel_coords = list() - for h in range(self.view_height): - for w in range(2 * self.view_width + 1): - if direction == 0: - if 0 <= origin[0] - h < self.env.height and 0 <= origin[1] + w < self.env.width: - visible.append((origin[0] - h, origin[1] + w)) - rel_coords.append((h, w)) - # if data_collection: - # temp_visible_data[h, w, :] = state[origin[0] - h, origin[1] + w, :] - elif direction == 1: - if 0 <= origin[0] + w < self.env.height and 0 <= origin[1] + h < self.env.width: - visible.append((origin[0] + w, origin[1] + h)) - rel_coords.append((h, w)) - # if data_collection: - # temp_visible_data[h, w, :] = state[origin[0] + w, origin[1] + h, :] - elif direction == 2: - if 0 <= origin[0] + h < self.env.height and 0 <= origin[1] - w < self.env.width: - visible.append((origin[0] + h, origin[1] - w)) - rel_coords.append((h, w)) - # if data_collection: - # temp_visible_data[h, w, :] = state[origin[0] + h, origin[1] - w, :] - else: - if 0 <= origin[0] - w < self.env.height and 0 <= origin[1] - h < self.env.width: - visible.append((origin[0] - w, origin[1] - h)) - rel_coords.append((h, w)) - # if data_collection: - # temp_visible_data[h, w, :] = state[origin[0] - w, origin[1] - h, :] - if data_collection: - return temp_visible_data - else: - return visible, rel_coords diff --git a/torch_training/predictors/predictions.py b/torch_training/predictors/predictions.py deleted file mode 100644 index 8306b72..0000000 --- a/torch_training/predictors/predictions.py +++ /dev/null @@ -1,179 +0,0 @@ -""" -Collection of environment-specific PredictionBuilder. -""" - -import numpy as np - -from flatland.core.env_prediction_builder import PredictionBuilder -from flatland.core.grid.grid4_utils import get_new_position -from flatland.envs.rail_env import RailEnvActions - - -class DummyPredictorForRailEnv(PredictionBuilder): - """ - DummyPredictorForRailEnv object. - - This object returns predictions for agents in the RailEnv environment. - The prediction acts as if no other agent is in the environment and always takes the forward action. - """ - - def get(self, custom_args=None, handle=None): - """ - Called whenever get_many in the observation build is called. - - Parameters - ------- - custom_args: dict - Not used in this dummy implementation. - handle : int (optional) - Handle of the agent for which to compute the observation vector. - - Returns - ------- - np.array - Returns a dictionary indexed by the agent handle and for each agent a vector of (max_depth + 1)x5 elements: - - time_offset - - position axis 0 - - position axis 1 - - direction - - action taken to come here - The prediction at 0 is the current position, direction etc. 
- - """ - agents = self.env.agents - if handle: - agents = [self.env.agents[handle]] - - prediction_dict = {} - - for agent in agents: - action_priorities = [RailEnvActions.MOVE_FORWARD, RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT] - _agent_initial_position = agent.position - _agent_initial_direction = agent.direction - prediction = np.zeros(shape=(self.max_depth + 1, 5)) - prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0] - for index in range(1, self.max_depth + 1): - action_done = False - # if we're at the target, stop moving... - if agent.position == agent.target: - prediction[index] = [index, *agent.target, agent.direction, RailEnvActions.STOP_MOVING] - - continue - for action in action_priorities: - cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = \ - self.env._check_action_on_agent(action, agent) - if all([new_cell_isValid, transition_isValid]): - # move and change direction to face the new_direction that was - # performed - agent.position = new_position - agent.direction = new_direction - prediction[index] = [index, *new_position, new_direction, action] - action_done = True - break - if not action_done: - raise Exception("Cannot move further. Something is wrong") - prediction_dict[agent.handle] = prediction - agent.position = _agent_initial_position - agent.direction = _agent_initial_direction - return prediction_dict - - -class ShortestPathPredictorForRailEnv(PredictionBuilder): - """ - ShortestPathPredictorForRailEnv object. - - This object returns shortest-path predictions for agents in the RailEnv environment. - The prediction acts as if no other agent is in the environment and always takes the forward action. - """ - - def __init__(self, max_depth=20): - # Initialize with depth 20 - self.max_depth = max_depth - - def get(self, custom_args=None, handle=None): - """ - Called whenever get_many in the observation build is called. - Requires distance_map to extract the shortest path. - - Parameters - ---------- - custom_args: dict - - distance_map : dict - handle : int, optional - Handle of the agent for which to compute the observation vector. - - Returns - ------- - np.array - Returns a dictionary indexed by the agent handle and for each agent a vector of (max_depth + 1)x5 elements: - - time_offset - - position axis 0 - - position axis 1 - - direction - - action taken to come here - The prediction at 0 is the current position, direction etc. - """ - agents = self.env.agents - if handle: - agents = [self.env.agents[handle]] - assert custom_args is not None - distance_map = custom_args.get('distance_map') - assert distance_map is not None - - prediction_dict = {} - for agent in agents: - _agent_initial_position = agent.position - _agent_initial_direction = agent.direction - agent_speed = agent.speed_data["speed"] - times_per_cell = int(np.reciprocal(agent_speed)) - prediction = np.zeros(shape=(self.max_depth + 1, 5)) - prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0] - new_direction = _agent_initial_direction - new_position = _agent_initial_position - visited = set() - for index in range(1, self.max_depth + 1): - # if we're at the target, stop moving... 
- if agent.position == agent.target: - prediction[index] = [index, *agent.target, agent.direction, RailEnvActions.STOP_MOVING] - visited.add((agent.position[0], agent.position[1], agent.direction)) - continue - if not agent.moving: - prediction[index] = [index, *agent.position, agent.direction, RailEnvActions.STOP_MOVING] - visited.add((agent.position[0], agent.position[1], agent.direction)) - continue - # Take shortest possible path - cell_transitions = self.env.rail.get_transitions(*agent.position, agent.direction) - - if np.sum(cell_transitions) == 1 and index % times_per_cell == 0: - new_direction = np.argmax(cell_transitions) - new_position = get_new_position(agent.position, new_direction) - elif np.sum(cell_transitions) > 1 and index % times_per_cell == 0: - min_dist = np.inf - no_dist_found = True - for direction in range(4): - if cell_transitions[direction] == 1: - neighbour_cell = get_new_position(agent.position, direction) - target_dist = distance_map[agent.handle, neighbour_cell[0], neighbour_cell[1], direction] - if target_dist < min_dist or no_dist_found: - min_dist = target_dist - new_direction = direction - no_dist_found = False - new_position = get_new_position(agent.position, new_direction) - elif index % times_per_cell == 0: - raise Exception("No transition possible {}".format(cell_transitions)) - - # update the agent's position and direction - agent.position = new_position - agent.direction = new_direction - - # prediction is ready - prediction[index] = [index, *new_position, new_direction, 0] - visited.add((new_position[0], new_position[1], new_direction)) - self.env.dev_pred_dict[agent.handle] = visited - prediction_dict[agent.handle] = prediction - - # cleanup: reset initial position - agent.position = _agent_initial_position - agent.direction = _agent_initial_direction - - return prediction_dict diff --git a/torch_training/render_agent_behavior.py b/torch_training/render_agent_behavior.py index 589d610..4befcd0 100644 --- a/torch_training/render_agent_behavior.py +++ b/torch_training/render_agent_behavior.py @@ -101,7 +101,7 @@ dones_list = [] action_prob = [0] * action_size agent_obs = [None] * env.get_num_agents() agent_next_obs = [None] * env.get_num_agents() -agent = Agent(state_size, action_size, "FC", 0) +agent = Agent(state_size, action_size, 0) with path(torch_training.Nets, "navigator_checkpoint10700.pth") as file_in: agent.qnetwork_local.load_state_dict(torch.load(file_in)) diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py index f69929f..80e4767 100644 --- a/torch_training/training_navigation.py +++ b/torch_training/training_navigation.py @@ -11,7 +11,7 @@ sys.path.append(str(base_dir)) import matplotlib.pyplot as plt import numpy as np import torch -from dueling_double_dqn import Agent +from torch_training.dueling_double_dqn import Agent from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.rail_env import RailEnv -- GitLab
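Taken together, the patch leaves the scripts building their observation stack from the flatland package itself and constructing the Agent without the removed net_type argument. A minimal sketch of the resulting setup, mirroring the '+' lines above (the state_size and action_size values are illustrative, assuming a depth-2 tree and flatland's five discrete actions):

from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from torch_training.dueling_double_dqn import Agent

observation_helper = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
state_size = 231      # (1 + 4 + 16) nodes * 11 features per node
action_size = 5
agent = Agent(state_size, action_size, 0)   # the "FC"/"Conv" selector is gone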