From 81513d7b1d15bda2c4f8cb463df7a792abad8018 Mon Sep 17 00:00:00 2001 From: u229589 <christian.baumberger@sbb.ch> Date: Thu, 26 Sep 2019 10:01:58 +0200 Subject: [PATCH] remove unused observation and prediction files --- scoring/score_test.py | 2 +- torch_training/dueling_double_dqn.py | 13 +- torch_training/model.py | 35 - torch_training/multi_agent_inference.py | 6 +- torch_training/multi_agent_training.py | 2 +- .../multi_agent_two_time_step_training.py | 7 +- .../observation_builders/observations.py | 865 ------------------ torch_training/predictors/predictions.py | 179 ---- torch_training/render_agent_behavior.py | 2 +- torch_training/training_navigation.py | 2 +- 10 files changed, 15 insertions(+), 1098 deletions(-) delete mode 100644 torch_training/observation_builders/observations.py delete mode 100644 torch_training/predictors/predictions.py diff --git a/scoring/score_test.py b/scoring/score_test.py index 5665d44..fa65f4c 100644 --- a/scoring/score_test.py +++ b/scoring/score_test.py @@ -28,7 +28,7 @@ test_dones = [] sequential_agent_test = False # Load your agent -agent = Agent(state_size, action_size, "FC", 0) +agent = Agent(state_size, action_size, 0) agent.qnetwork_local.load_state_dict(torch.load('../torch_training/Nets/avoid_checkpoint60000.pth')) # Load the necessary Observation Builder and Predictor diff --git a/torch_training/dueling_double_dqn.py b/torch_training/dueling_double_dqn.py index cf2f7d5..5dfa5c1 100644 --- a/torch_training/dueling_double_dqn.py +++ b/torch_training/dueling_double_dqn.py @@ -8,7 +8,7 @@ import torch import torch.nn.functional as F import torch.optim as optim -from torch_training.model import QNetwork, QNetwork2 +from torch_training.model import QNetwork BUFFER_SIZE = int(1e5) # replay buffer size BATCH_SIZE = 512 # minibatch size @@ -27,7 +27,7 @@ print(device) class Agent: """Interacts with and learns from the environment.""" - def __init__(self, state_size, action_size, net_type, seed, double_dqn=True, input_channels=5): + def __init__(self, state_size, action_size, seed, double_dqn=True, input_channels=5): """Initialize an Agent object. 
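# Illustrative sketch, not part of the patch: the update rule this Agent's
# learn() step applies when self.double_dqn is True. This is the standard
# double-DQN formulation; the tensor names and the gamma default here are
# assumptions, not values taken from the file.
import torch

def double_dqn_targets(qnetwork_local, qnetwork_target, rewards, next_states, dones, gamma=0.99):
    # action selection with the online net, value estimation with the target net
    best_actions = qnetwork_local(next_states).detach().argmax(1, keepdim=True)
    q_next = qnetwork_target(next_states).detach().gather(1, best_actions)
    return rewards + gamma * q_next * (1 - dones)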
Params @@ -39,15 +39,10 @@ class Agent: self.state_size = state_size self.action_size = action_size self.seed = random.seed(seed) - self.version = net_type self.double_dqn = double_dqn # Q-Network - if self.version == "Conv": - self.qnetwork_local = QNetwork2(state_size, action_size, seed, input_channels).to(device) - self.qnetwork_target = copy.deepcopy(self.qnetwork_local) - else: - self.qnetwork_local = QNetwork(state_size, action_size, seed).to(device) - self.qnetwork_target = copy.deepcopy(self.qnetwork_local) + self.qnetwork_local = QNetwork(state_size, action_size, seed).to(device) + self.qnetwork_target = copy.deepcopy(self.qnetwork_local) self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=LR) diff --git a/torch_training/model.py b/torch_training/model.py index 7a5b3d6..f7bc3d5 100644 --- a/torch_training/model.py +++ b/torch_training/model.py @@ -24,38 +24,3 @@ class QNetwork(nn.Module): adv = F.relu(self.fc2_adv(adv)) adv = self.fc3_adv(adv) return val + adv - adv.mean() - - -class QNetwork2(nn.Module): - def __init__(self, state_size, action_size, seed, input_channels, hidsize1=128, hidsize2=64): - super(QNetwork2, self).__init__() - self.conv1 = nn.Conv2d(input_channels, 16, kernel_size=3, stride=1) - self.bn1 = nn.BatchNorm2d(16) - self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=3) - self.bn2 = nn.BatchNorm2d(32) - self.conv3 = nn.Conv2d(32, 64, kernel_size=5, stride=3) - self.bn3 = nn.BatchNorm2d(64) - - self.fc1_val = nn.Linear(6400, hidsize1) - self.fc2_val = nn.Linear(hidsize1, hidsize2) - self.fc3_val = nn.Linear(hidsize2, 1) - - self.fc1_adv = nn.Linear(6400, hidsize1) - self.fc2_adv = nn.Linear(hidsize1, hidsize2) - self.fc3_adv = nn.Linear(hidsize2, action_size) - - def forward(self, x): - x = F.relu(self.conv1(x)) - x = F.relu(self.conv2(x)) - x = F.relu(self.conv3(x)) - - # value function approximation - val = F.relu(self.fc1_val(x.view(x.size(0), -1))) - val = F.relu(self.fc2_val(val)) - val = self.fc3_val(val) - - # advantage calculation - adv = F.relu(self.fc1_adv(x.view(x.size(0), -1))) - adv = F.relu(self.fc2_adv(adv)) - adv = self.fc3_adv(adv) - return val + adv - adv.mean() diff --git a/torch_training/multi_agent_inference.py b/torch_training/multi_agent_inference.py index 2fbbe61..96c8ee7 100644 --- a/torch_training/multi_agent_inference.py +++ b/torch_training/multi_agent_inference.py @@ -4,8 +4,8 @@ from collections import deque import numpy as np import torch from importlib_resources import path -from observation_builders.observations import TreeObsForRailEnv -from predictors.predictions import ShortestPathPredictorForRailEnv +from flatland.envs.observations import TreeObsForRailEnv +from flatland.envs.predictions import ShortestPathPredictorForRailEnv import torch_training.Nets from flatland.envs.rail_env import RailEnv @@ -87,7 +87,7 @@ dones_list = [] action_prob = [0] * action_size agent_obs = [None] * env.get_num_agents() agent_next_obs = [None] * env.get_num_agents() -agent = Agent(state_size, action_size, "FC", 0) +agent = Agent(state_size, action_size, 0) with path(torch_training.Nets, "avoid_checkpoint100.pth") as file_in: agent.qnetwork_local.load_state_dict(torch.load(file_in)) diff --git a/torch_training/multi_agent_training.py b/torch_training/multi_agent_training.py index 222430d..6747460 100644 --- a/torch_training/multi_agent_training.py +++ b/torch_training/multi_agent_training.py @@ -121,7 +121,7 @@ def main(argv): observation_radius = 10 # Initialize the agent - agent = Agent(state_size, action_size, "FC", 0) + agent = 
Agent(state_size, action_size, 0) # Here you can pre-load an agent if False: diff --git a/torch_training/multi_agent_two_time_step_training.py b/torch_training/multi_agent_two_time_step_training.py index 08cd84c..2687cd5 100644 --- a/torch_training/multi_agent_two_time_step_training.py +++ b/torch_training/multi_agent_two_time_step_training.py @@ -43,10 +43,10 @@ def main(argv): min_dist = int(0.75 * min(x_dim, y_dim)) tree_depth = 3 print("main2") + demo = False # Get an observation builder and predictor - predictor = ShortestPathPredictorForRailEnv() - observation_helper = TreeObsForRailEnv(max_depth=tree_depth, predictor=predictor()) + observation_helper = TreeObsForRailEnv(max_depth=tree_depth, predictor=ShortestPathPredictorForRailEnv()) env = RailEnv(width=x_dim, height=y_dim, @@ -87,7 +87,7 @@ def main(argv): agent_obs = [None] * env.get_num_agents() agent_next_obs = [None] * env.get_num_agents() # Initialize the agent - agent = Agent(state_size, action_size, "FC", 0) + agent = Agent(state_size, action_size, 0) # Here you can pre-load an agent if False: @@ -132,6 +132,7 @@ def main(argv): # Build agent specific observations for a in range(env.get_num_agents()): data, distance, agent_data = split_tree(tree=np.array(obs[a]), + num_features_per_node=11, current_depth=0) data = norm_obs_clip(data) distance = norm_obs_clip(distance) diff --git a/torch_training/observation_builders/observations.py b/torch_training/observation_builders/observations.py deleted file mode 100644 index 66c3830..0000000 --- a/torch_training/observation_builders/observations.py +++ /dev/null @@ -1,865 +0,0 @@ -""" -Collection of environment-specific ObservationBuilder. -""" -import pprint -from collections import deque - -import numpy as np - -from flatland.core.env_observation_builder import ObservationBuilder -from flatland.core.grid.grid4_utils import get_new_position -from flatland.core.grid.grid_utils import coordinate_to_position - - -class TreeObsForRailEnv(ObservationBuilder): - """ - TreeObsForRailEnv object. - - This object returns observation vectors for agents in the RailEnv environment. - The information is local to each agent and exploits the graph structure of the rail - network to simplify the representation of the state of the environment for each agent. - - For details about the features in the tree observation see the get() function. 
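# Sketch of the per-agent preprocessing used by the training hunk above,
# assuming the repo's utils.observation_utils helpers (the import path is an
# assumption). The calls mirror the hunk for
# multi_agent_two_time_step_training.py, where split_tree now receives
# num_features_per_node=11 to match the tree builder's 11 features per node.
# (The same hunk also stops calling predictor() on an already-constructed
# ShortestPathPredictorForRailEnv instance.)
import numpy as np
from utils.observation_utils import norm_obs_clip, split_tree  # assumed helper module

def preprocess_tree_obs(raw_obs):
    data, distance, agent_data = split_tree(tree=np.array(raw_obs),
                                            num_features_per_node=11,
                                            current_depth=0)
    # normalize/clip each block, then concatenate into the network input
    return np.concatenate((norm_obs_clip(data),
                           norm_obs_clip(distance),
                           np.clip(agent_data, -1, 1)))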
- """ - - def __init__(self, max_depth, predictor=None): - super().__init__() - self.max_depth = max_depth - self.observation_dim = 11 - # Compute the size of the returned observation vector - size = 0 - pow4 = 1 - for i in range(self.max_depth + 1): - size += pow4 - pow4 *= 4 - self.observation_space = [size * self.observation_dim] - self.location_has_agent = {} - self.location_has_agent_direction = {} - self.predictor = predictor - self.agents_previous_reset = None - self.tree_explored_actions = [1, 2, 3, 0] - self.tree_explorted_actions_char = ['L', 'F', 'R', 'B'] - self.distance_map = None - self.distance_map_computed = False - - def reset(self): - agents = self.env.agents - nb_agents = len(agents) - compute_distance_map = True - if self.agents_previous_reset is not None and nb_agents == len(self.agents_previous_reset): - compute_distance_map = False - for i in range(nb_agents): - if agents[i].target != self.agents_previous_reset[i].target: - compute_distance_map = True - # Don't compute the distance map if it was loaded - if self.agents_previous_reset is None and self.distance_map is not None: - self.location_has_target = {tuple(agent.target): 1 for agent in agents} - compute_distance_map = False - - if compute_distance_map: - self._compute_distance_map() - - self.agents_previous_reset = agents - - def _compute_distance_map(self): - agents = self.env.agents - # For testing only --> To assert if a distance map need to be recomputed. - self.distance_map_computed = True - nb_agents = len(agents) - self.distance_map = np.inf * np.ones(shape=(nb_agents, - self.env.height, - self.env.width, - 4)) - self.max_dist = np.zeros(nb_agents) - self.max_dist = [self._distance_map_walker(agent.target, i) for i, agent in enumerate(agents)] - # Update local lookup table for all agents' target locations - self.location_has_target = {tuple(agent.target): 1 for agent in agents} - - def _distance_map_walker(self, position, target_nr): - """ - Utility function to compute distance maps from each cell in the rail network (and each possible - orientation within it) to each agent's target cell. 
- """ - # Returns max distance to target, from the farthest away node, while filling in distance_map - self.distance_map[target_nr, position[0], position[1], :] = 0 - - # Fill in the (up to) 4 neighboring nodes - # direction is the direction of movement, meaning that at least a possible orientation of an agent - # in cell (row,col) allows a movement in direction `direction` - nodes_queue = deque(self._get_and_update_neighbors(position, target_nr, 0, enforce_target_direction=-1)) - - # BFS from target `position` to all the reachable nodes in the grid - # Stop the search if the target position is re-visited, in any direction - visited = {(position[0], position[1], 0), (position[0], position[1], 1), (position[0], position[1], 2), - (position[0], position[1], 3)} - - max_distance = 0 - - while nodes_queue: - node = nodes_queue.popleft() - - node_id = (node[0], node[1], node[2]) - - if node_id not in visited: - visited.add(node_id) - - # From the list of possible neighbors that have at least a path to the current node, only keep those - # whose new orientation in the current cell would allow a transition to direction node[2] - valid_neighbors = self._get_and_update_neighbors((node[0], node[1]), target_nr, node[3], node[2]) - - for n in valid_neighbors: - nodes_queue.append(n) - - if len(valid_neighbors) > 0: - max_distance = max(max_distance, node[3] + 1) - - return max_distance - - def _get_and_update_neighbors(self, position, target_nr, current_distance, enforce_target_direction=-1): - """ - Utility function used by _distance_map_walker to perform a BFS walk over the rail, filling in the - minimum distances from each target cell. - """ - neighbors = [] - - possible_directions = [0, 1, 2, 3] - if enforce_target_direction >= 0: - # The agent must land into the current cell with orientation `enforce_target_direction`. - # This is only possible if the agent has arrived from the cell in the opposite direction! - possible_directions = [(enforce_target_direction + 2) % 4] - - for neigh_direction in possible_directions: - new_cell = get_new_position(position, neigh_direction) - - if new_cell[0] >= 0 and new_cell[0] < self.env.height and new_cell[1] >= 0 and new_cell[1] < self.env.width: - - desired_movement_from_new_cell = (neigh_direction + 2) % 4 - - # Check all possible transitions in new_cell - for agent_orientation in range(4): - # Is a transition along movement `desired_movement_from_new_cell` to the current cell possible? - is_valid = self.env.rail.get_transition((new_cell[0], new_cell[1], agent_orientation), - desired_movement_from_new_cell) - - if is_valid: - """ - # TODO: check that it works with deadends! -- still bugged! - movement = desired_movement_from_new_cell - if isNextCellDeadEnd: - movement = (desired_movement_from_new_cell+2) % 4 - """ - new_distance = min(self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation], - current_distance + 1) - neighbors.append((new_cell[0], new_cell[1], agent_orientation, new_distance)) - self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation] = new_distance - - return neighbors - - def get_many(self, handles=None): - """ - Called whenever an observation has to be computed for the `env` environment, for each agent with handle - in the `handles` list. 
- """ - - if handles is None: - handles = [] - if self.predictor: - self.max_prediction_depth = 0 - self.predicted_pos = {} - self.predicted_dir = {} - self.predictions = self.predictor.get(custom_args={'distance_map': self.distance_map}) - if self.predictions: - - for t in range(len(self.predictions[0])): - pos_list = [] - dir_list = [] - for a in handles: - pos_list.append(self.predictions[a][t][1:3]) - dir_list.append(self.predictions[a][t][3]) - self.predicted_pos.update({t: coordinate_to_position(self.env.width, pos_list)}) - self.predicted_dir.update({t: dir_list}) - self.max_prediction_depth = len(self.predicted_pos) - observations = {} - for h in handles: - observations[h] = self.get(h) - return observations - - def get(self, handle): - """ - Computes the current observation for agent `handle` in env - - The observation vector is composed of 4 sequential parts, corresponding to data from the up to 4 possible - movements in a RailEnv (up to because only a subset of possible transitions are allowed in RailEnv). - The possible movements are sorted relative to the current orientation of the agent, rather than NESW as for - the transitions. The order is: - [data from 'left'] + [data from 'forward'] + [data from 'right'] + [data from 'back'] - - Each branch data is organized as: - [root node information] + - [recursive branch data from 'left'] + - [... from 'forward'] + - [... from 'right] + - [... from 'back'] - - Each node information is composed of 9 features: - - #1: if own target lies on the explored branch the current distance from the agent in number of cells is stored. - - #2: if another agents target is detected the distance in number of cells from the agents current location - is stored - - #3: if another agent is detected the distance in number of cells from current agent position is stored. - - #4: possible conflict detected - tot_dist = Other agent predicts to pass along this cell at the same time as the agent, we store the - distance in number of cells from current agent position - - 0 = No other agent reserve the same cell at similar time - - #5: if an not usable switch (for agent) is detected we store the distance. - - #6: This feature stores the distance in number of cells to the next branching (current node) - - #7: minimum distance from node to the agent's target given the direction of the agent if this path is chosen - - #8: agent in the same direction - n = number of agents present same direction - (possible future use: number of other agents in the same direction in this branch) - 0 = no agent present same direction - - #9: agent in the opposite direction - n = number of agents present other direction than myself (so conflict) - (possible future use: number of other agents in other direction in this branch, ie. number of conflicts) - 0 = no agent present other direction than myself - - #10: malfunctioning/blokcing agents - n = number of time steps the oberved agent remains blocked - - #11: slowest observed speed of an agent in same direction - 1 if no agent is observed - - min_fractional speed otherwise - - - - - - Missing/padding nodes are filled in with -inf (truncated). - Missing values in present node are filled in with +inf (truncated). - - - In case of the root node, the values are [0, 0, 0, 0, distance from agent to target, own malfunction, own speed] - In case the target node is reached, the values are [0, 0, 0, 0, 0]. 
- """ - - # Update local lookup table for all agents' positions - self.location_has_agent = {tuple(agent.position): 1 for agent in self.env.agents} - self.location_has_agent_direction = {tuple(agent.position): agent.direction for agent in self.env.agents} - self.location_has_agent_speed = {tuple(agent.position): agent.speed_data['speed'] for agent in self.env.agents} - self.location_has_agent_malfunction = {tuple(agent.position): agent.malfunction_data['malfunction'] for agent in - self.env.agents} - - if handle > len(self.env.agents): - print("ERROR: obs _get - handle ", handle, " len(agents)", len(self.env.agents)) - agent = self.env.agents[handle] # TODO: handle being treated as index - possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction) - num_transitions = np.count_nonzero(possible_transitions) - - # Root node - current position - # Here information about the agent itself is stored - observation = [0, 0, 0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], 0, 0, - agent.malfunction_data['malfunction'], agent.speed_data['speed']] - - visited = set() - - # Start from the current orientation, and see which transitions are available; - # organize them as [left, forward, right, back], relative to the current orientation - # If only one transition is possible, the tree is oriented with this transition as the forward branch. - orientation = agent.direction - - if num_transitions == 1: - orientation = np.argmax(possible_transitions) - - for branch_direction in [(orientation + i) % 4 for i in range(-1, 3)]: - if possible_transitions[branch_direction]: - new_cell = get_new_position(agent.position, branch_direction) - branch_observation, branch_visited = \ - self._explore_branch(handle, new_cell, branch_direction, 1, 1) - observation = observation + branch_observation - visited = visited.union(branch_visited) - else: - # add cells filled with infinity if no transition is possible - observation = observation + [-np.inf] * self._num_cells_to_fill_in(self.max_depth) - self.env.dev_obs_dict[handle] = visited - - return observation - - def _num_cells_to_fill_in(self, remaining_depth): - """Computes the length of observation vector: sum_{i=0,depth-1} 2^i * observation_dim.""" - num_observations = 0 - pow4 = 1 - for i in range(remaining_depth): - num_observations += pow4 - pow4 *= 4 - return num_observations * self.observation_dim - - def _explore_branch(self, handle, position, direction, tot_dist, depth): - """ - Utility function to compute tree-based observations. - We walk along the branch and collect the information documented in the get() function. - If there is a branching point a new node is created and each possible branch is explored. 
- """ - - # [Recursive branch opened] - if depth >= self.max_depth + 1: - return [], [] - - # Continue along direction until next switch or - # until no transitions are possible along the current direction (i.e., dead-ends) - # We treat dead-ends as nodes, instead of going back, to avoid loops - exploring = True - last_is_switch = False - last_is_dead_end = False - last_is_terminal = False # wrong cell OR cycle; either way, we don't want the agent to land here - last_is_target = False - - visited = set() - agent = self.env.agents[handle] - time_per_cell = np.reciprocal(agent.speed_data["speed"]) - own_target_encountered = np.inf - other_agent_encountered = np.inf - other_target_encountered = np.inf - potential_conflict = np.inf - unusable_switch = np.inf - other_agent_same_direction = 0 - other_agent_opposite_direction = 0 - malfunctioning_agent = 0. - min_fractional_speed = 1. - num_steps = 1 - while exploring: - # ############################# - # ############################# - # Modify here to compute any useful data required to build the end node's features. This code is called - # for each cell visited between the previous branching node and the next switch / target / dead-end. - if position in self.location_has_agent: - if tot_dist < other_agent_encountered: - other_agent_encountered = tot_dist - - # Check if any of the observed agents is malfunctioning, store agent with longest duration left - if self.location_has_agent_malfunction[position] > malfunctioning_agent: - malfunctioning_agent = self.location_has_agent_malfunction[position] - - if self.location_has_agent_direction[position] == direction: - # Cummulate the number of agents on branch with same direction - other_agent_same_direction += 1 - - # Check fractional speed of agents - current_fractional_speed = self.location_has_agent_speed[position] - if current_fractional_speed < min_fractional_speed: - min_fractional_speed = current_fractional_speed - - if self.location_has_agent_direction[position] != direction: - # Cummulate the number of agents on branch with other direction - other_agent_opposite_direction += 1 - - # Check number of possible transitions for agent and total number of transitions in cell (type) - cell_transitions = self.env.rail.get_transitions(*position, direction) - transition_bit = bin(self.env.rail.get_full_transitions(*position)) - total_transitions = transition_bit.count("1") - crossing_found = False - if int(transition_bit, 2) == int('1000010000100001', 2): - crossing_found = True - - # Register possible future conflict - predicted_time = int(tot_dist * time_per_cell) - if self.predictor and predicted_time < self.max_prediction_depth: - int_position = coordinate_to_position(self.env.width, [position]) - if tot_dist < self.max_prediction_depth: - - pre_step = max(0, predicted_time - 1) - post_step = min(self.max_prediction_depth - 1, predicted_time + 1) - - # Look for conflicting paths at distance tot_dist - if int_position in np.delete(self.predicted_pos[predicted_time], handle, 0): - conflicting_agent = np.where(self.predicted_pos[predicted_time] == int_position) - for ca in conflicting_agent[0]: - if direction != self.predicted_dir[predicted_time][ca] and cell_transitions[ - self._reverse_dir( - self.predicted_dir[predicted_time][ca])] == 1 and tot_dist < potential_conflict: - potential_conflict = tot_dist - if self.env.dones[ca] and tot_dist < potential_conflict: - potential_conflict = tot_dist - - # Look for conflicting paths at distance num_step-1 - elif int_position in 
np.delete(self.predicted_pos[pre_step], handle, 0): - conflicting_agent = np.where(self.predicted_pos[pre_step] == int_position) - for ca in conflicting_agent[0]: - if direction != self.predicted_dir[pre_step][ca] \ - and cell_transitions[self._reverse_dir(self.predicted_dir[pre_step][ca])] == 1 \ - and tot_dist < potential_conflict: # noqa: E125 - potential_conflict = tot_dist - if self.env.dones[ca] and tot_dist < potential_conflict: - potential_conflict = tot_dist - - # Look for conflicting paths at distance num_step+1 - elif int_position in np.delete(self.predicted_pos[post_step], handle, 0): - conflicting_agent = np.where(self.predicted_pos[post_step] == int_position) - for ca in conflicting_agent[0]: - if direction != self.predicted_dir[post_step][ca] and cell_transitions[self._reverse_dir( - self.predicted_dir[post_step][ca])] == 1 \ - and tot_dist < potential_conflict: # noqa: E125 - potential_conflict = tot_dist - if self.env.dones[ca] and tot_dist < potential_conflict: - potential_conflict = tot_dist - - if position in self.location_has_target and position != agent.target: - if tot_dist < other_target_encountered: - other_target_encountered = tot_dist - - if position == agent.target and tot_dist < own_target_encountered: - own_target_encountered = tot_dist - - # ############################# - # ############################# - if (position[0], position[1], direction) in visited: - last_is_terminal = True - break - visited.add((position[0], position[1], direction)) - - # If the target node is encountered, pick that as node. Also, no further branching is possible. - if np.array_equal(position, self.env.agents[handle].target): - last_is_target = True - break - - # Check if crossing is found --> Not an unusable switch - if crossing_found: - # Treat the crossing as a straight rail cell - total_transitions = 2 - num_transitions = np.count_nonzero(cell_transitions) - - exploring = False - - # Detect Switches that can only be used by other agents. - if total_transitions > 2 > num_transitions and tot_dist < unusable_switch: - unusable_switch = tot_dist - - if num_transitions == 1: - # Check if dead-end, or if we can go forward along direction - nbits = total_transitions - if nbits == 1: - # Dead-end! - last_is_dead_end = True - - if not last_is_dead_end: - # Keep walking through the tree along `direction` - exploring = True - # convert one-hot encoding to 0,1,2,3 - direction = np.argmax(cell_transitions) - position = get_new_position(position, direction) - num_steps += 1 - tot_dist += 1 - elif num_transitions > 0: - # Switch detected - last_is_switch = True - break - - elif num_transitions == 0: - # Wrong cell type, but let's cover it and treat it as a dead-end, just in case - print("WRONG CELL TYPE detected in tree-search (0 transitions possible) at cell", position[0], - position[1], direction) - last_is_terminal = True - break - - # `position` is either a terminal node or a switch - - # ############################# - # ############################# - # Modify here to append new / different features for each visited cell! 
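# Simplified sketch of the conflict registration above. The real code also
# requires that the opposing agent could actually enter our cell
# (cell_transitions[self._reverse_dir(...)] == 1) and clamps the pre/post
# steps to the prediction horizon. Here predicted_pos maps a time step to one
# flattened cell id per agent, and predicted_dir to one heading per agent.
import numpy as np

def conflict_predicted(my_cell, my_dir, predicted_pos, predicted_dir, t, handle):
    for step in (t - 1, t, t + 1):
        if step not in predicted_pos:
            continue
        if my_cell in np.delete(predicted_pos[step], handle, 0):
            opponents = np.where(predicted_pos[step] == my_cell)[0]
            if any(predicted_dir[step][ca] != my_dir for ca in opponents):
                return True
    return False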
- - if last_is_target: - observation = [own_target_encountered, - other_target_encountered, - other_agent_encountered, - potential_conflict, - unusable_switch, - tot_dist, - 0, - other_agent_same_direction, - other_agent_opposite_direction, - malfunctioning_agent, - min_fractional_speed - ] - - elif last_is_terminal: - observation = [own_target_encountered, - other_target_encountered, - other_agent_encountered, - potential_conflict, - unusable_switch, - np.inf, - self.distance_map[handle, position[0], position[1], direction], - other_agent_same_direction, - other_agent_opposite_direction, - malfunctioning_agent, - min_fractional_speed - ] - - else: - observation = [own_target_encountered, - other_target_encountered, - other_agent_encountered, - potential_conflict, - unusable_switch, - tot_dist, - self.distance_map[handle, position[0], position[1], direction], - other_agent_same_direction, - other_agent_opposite_direction, - malfunctioning_agent, - min_fractional_speed - ] - # ############################# - # ############################# - # Start from the current orientation, and see which transitions are available; - # organize them as [left, forward, right, back], relative to the current orientation - # Get the possible transitions - possible_transitions = self.env.rail.get_transitions(*position, direction) - for branch_direction in [(direction + 4 + i) % 4 for i in range(-1, 3)]: - if last_is_dead_end and self.env.rail.get_transition((*position, direction), - (branch_direction + 2) % 4): - # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes - # it back - new_cell = get_new_position(position, (branch_direction + 2) % 4) - branch_observation, branch_visited = self._explore_branch(handle, - new_cell, - (branch_direction + 2) % 4, - tot_dist + 1, - depth + 1) - observation = observation + branch_observation - if len(branch_visited) != 0: - visited = visited.union(branch_visited) - elif last_is_switch and possible_transitions[branch_direction]: - new_cell = get_new_position(position, branch_direction) - branch_observation, branch_visited = self._explore_branch(handle, - new_cell, - branch_direction, - tot_dist + 1, - depth + 1) - observation = observation + branch_observation - if len(branch_visited) != 0: - visited = visited.union(branch_visited) - else: - # no exploring possible, add just cells with infinity - observation = observation + [-np.inf] * self._num_cells_to_fill_in(self.max_depth - depth) - - return observation, visited - - def util_print_obs_subtree(self, tree): - """ - Utility function to pretty-print tree observations returned by this object. - """ - pp = pprint.PrettyPrinter(indent=4) - pp.pprint(self.unfold_observation_tree(tree)) - - def unfold_observation_tree(self, tree, current_depth=0, actions_for_display=True): - """ - Utility function to pretty-print tree observations returned by this object. 
- """ - if len(tree) < self.observation_dim: - return - - depth = 0 - tmp = len(tree) / self.observation_dim - 1 - pow4 = 4 - while tmp > 0: - tmp -= pow4 - depth += 1 - pow4 *= 4 - - unfolded = {} - unfolded[''] = tree[0:self.observation_dim] - child_size = (len(tree) - self.observation_dim) // 4 - for child in range(4): - child_tree = tree[(self.observation_dim + child * child_size): - (self.observation_dim + (child + 1) * child_size)] - observation_tree = self.unfold_observation_tree(child_tree, current_depth=current_depth + 1) - if observation_tree is not None: - if actions_for_display: - label = self.tree_explorted_actions_char[child] - else: - label = self.tree_explored_actions[child] - unfolded[label] = observation_tree - return unfolded - - def _set_env(self, env): - self.env = env - if self.predictor: - self.predictor._set_env(self.env) - - def _reverse_dir(self, direction): - return int((direction + 2) % 4) - - -class GlobalObsForRailEnv(ObservationBuilder): - """ - Gives a global observation of the entire rail environment. - The observation is composed of the following elements: - - - transition map array with dimensions (env.height, env.width, 16), - assuming 16 bits encoding of transitions. - - - Two 2D arrays (map_height, map_width, 2) containing respectively the position of the given agent - target and the positions of the other agents targets. - - - A 3D array (map_height, map_width, 8) with the 4 first channels containing the one hot encoding - of the direction of the given agent and the 4 second channels containing the positions - of the other agents at their position coordinates. - """ - - def __init__(self): - self.observation_space = () - super(GlobalObsForRailEnv, self).__init__() - - def _set_env(self, env): - super()._set_env(env) - - self.observation_space = [4, self.env.height, self.env.width] - - def reset(self): - self.rail_obs = np.zeros((self.env.height, self.env.width, 16)) - for i in range(self.rail_obs.shape[0]): - for j in range(self.rail_obs.shape[1]): - bitlist = [int(digit) for digit in bin(self.env.rail.get_full_transitions(i, j))[2:]] - bitlist = [0] * (16 - len(bitlist)) + bitlist - self.rail_obs[i, j] = np.array(bitlist) - - def get(self, handle): - obs_targets = np.zeros((self.env.height, self.env.width, 2)) - obs_agents_state = np.zeros((self.env.height, self.env.width, 8)) - agents = self.env.agents - agent = agents[handle] - - direction = np.zeros(4) - direction[agent.direction] = 1 - agent_pos = agents[handle].position - obs_agents_state[agent_pos][:4] = direction - obs_targets[agent.target][0] += 1 - - for i in range(len(agents)): - if i != handle: # TODO: handle used as index...? - agent2 = agents[i] - obs_agents_state[agent2.position][4 + agent2.direction] = 1 - obs_targets[agent2.target][1] += 1 - - direction = self._get_one_hot_for_agent_direction(agent) - - return self.rail_obs, obs_agents_state, obs_targets, direction - - -class GlobalObsForRailEnvDirectionDependent(ObservationBuilder): - """ - Gives a global observation of the entire rail environment. - The observation is composed of the following elements: - - - transition map array with dimensions (env.height, env.width, 16), - assuming 16 bits encoding of transitions, flipped in the direction of the agent - (the agent is always heading north on the flipped view). - - - Two 2D arrays (map_height, map_width, 2) containing respectively the position of the given agent - target and the positions of the other agents targets, also flipped depending on the agent's direction. 
- - - A 3D array (map_height, map_width, 5) containing the one hot encoding of the direction of the other - agents at their position coordinates, and the last channel containing the position of the given agent. - - - A 4 elements array with one hot encoding of the direction. - """ - - def __init__(self): - self.observation_space = () - super(GlobalObsForRailEnvDirectionDependent, self).__init__() - - def _set_env(self, env): - super()._set_env(env) - - self.observation_space = [4, self.env.height, self.env.width] - - def reset(self): - self.rail_obs = np.zeros((self.env.height, self.env.width, 16)) - for i in range(self.rail_obs.shape[0]): - for j in range(self.rail_obs.shape[1]): - bitlist = [int(digit) for digit in bin(self.env.rail.get_full_transitions(i, j))[2:]] - bitlist = [0] * (16 - len(bitlist)) + bitlist - self.rail_obs[i, j] = np.array(bitlist) - - def get(self, handle): - obs_targets = np.zeros((self.env.height, self.env.width, 2)) - obs_agents_state = np.zeros((self.env.height, self.env.width, 5)) - agents = self.env.agents - agent = agents[handle] - direction = agent.direction - - idx = np.tile(np.arange(16), 2) - - rail_obs = self.rail_obs[:, :, idx[direction * 4: direction * 4 + 16]] - - if direction == 1: - rail_obs = np.flip(rail_obs, axis=1) - elif direction == 2: - rail_obs = np.flip(rail_obs) - elif direction == 3: - rail_obs = np.flip(rail_obs, axis=0) - - agent_pos = agents[handle].position - obs_agents_state[agent_pos][0] = 1 - obs_targets[agent.target][0] += 1 - - idx = np.tile(np.arange(4), 2) - for i in range(len(agents)): - if i != handle: # TODO: handle used as index...? - agent2 = agents[i] - obs_agents_state[agent2.position][1 + idx[4 + (agent2.direction - direction)]] = 1 - obs_targets[agent2.target][1] += 1 - - direction = self._get_one_hot_for_agent_direction(agent) - - return rail_obs, obs_agents_state, obs_targets, direction - - -class LocalObsForRailEnv(ObservationBuilder): - """ - Gives a local observation of the rail environment around the agent. - The observation is composed of the following elements: - - - transition map array of the local environment around the given agent, - with dimensions (view_height,2*view_width+1, 16), - assuming 16 bits encoding of transitions. - - - Two 2D arrays (view_height,2*view_width+1, 2) containing respectively, - if they are in the agent's vision range, its target position, the positions of the other targets. - - - A 2D array (view_height,2*view_width+1, 4) containing the one hot encoding of directions - of the other agents at their position coordinates, if they are in the agent's vision range. - - - A 4 elements array with one hot encoding of the direction. - - Use the parameters view_width and view_height to define the rectangular view of the agent. - The center parameters moves the agent along the height axis of this rectangle. If it is 0 the agent only has - observation in front of it. - """ - - def __init__(self, view_width, view_height, center): - - super(LocalObsForRailEnv, self).__init__() - self.view_width = view_width - self.view_height = view_height - self.center = center - self.max_padding = max(self.view_width, self.view_height - self.center) - - def reset(self): - # We build the transition map with a view_radius empty cells expansion on each side. - # This helps to collect the local transition map view when the agent is close to a border. 
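# The 16-bit transition decoding repeated in these builders' reset() methods,
# extracted for clarity: each cell's transition code becomes a fixed-length
# 0/1 vector, most significant bit first.
def transitions_to_bits(code):
    bits = [int(digit) for digit in bin(code)[2:]]
    return [0] * (16 - len(bits)) + bits

assert len(transitions_to_bits(0)) == 16
assert transitions_to_bits(int('1000010000100001', 2))[0] == 1   # the crossing cell type checked earlier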
- self.max_padding = max(self.view_width, self.view_height) - self.rail_obs = np.zeros((self.env.height, - self.env.width, 16)) - for i in range(self.env.height): - for j in range(self.env.width): - bitlist = [int(digit) for digit in bin(self.env.rail.get_full_transitions(i, j))[2:]] - bitlist = [0] * (16 - len(bitlist)) + bitlist - self.rail_obs[i, j] = np.array(bitlist) - - def get(self, handle): - agents = self.env.agents - agent = agents[handle] - - # Correct agents position for padding - # agent_rel_pos[0] = agent.position[0] + self.max_padding - # agent_rel_pos[1] = agent.position[1] + self.max_padding - - # Collect visible cells as set to be plotted - visited, rel_coords = self.field_of_view(agent.position, agent.direction, ) - local_rail_obs = None - - # Add the visible cells to the observed cells - self.env.dev_obs_dict[handle] = set(visited) - - # Locate observed agents and their coresponding targets - local_rail_obs = np.zeros((self.view_height, 2 * self.view_width + 1, 16)) - obs_map_state = np.zeros((self.view_height, 2 * self.view_width + 1, 2)) - obs_other_agents_state = np.zeros((self.view_height, 2 * self.view_width + 1, 4)) - _idx = 0 - for pos in visited: - curr_rel_coord = rel_coords[_idx] - local_rail_obs[curr_rel_coord[0], curr_rel_coord[1], :] = self.rail_obs[pos[0], pos[1], :] - if pos == agent.target: - obs_map_state[curr_rel_coord[0], curr_rel_coord[1], 0] = 1 - else: - for tmp_agent in agents: - if pos == tmp_agent.target: - obs_map_state[curr_rel_coord[0], curr_rel_coord[1], 1] = 1 - if pos != agent.position: - for tmp_agent in agents: - if pos == tmp_agent.position: - obs_other_agents_state[curr_rel_coord[0], curr_rel_coord[1], :] = np.identity(4)[ - tmp_agent.direction] - - _idx += 1 - - direction = np.identity(4)[agent.direction] - return local_rail_obs, obs_map_state, obs_other_agents_state, direction - - def get_many(self, handles=None): - """ - Called whenever an observation has to be computed for the `env' environment, for each agent with handle - in the `handles' list. 
- """ - - observations = {} - for h in handles: - observations[h] = self.get(h) - return observations - - def field_of_view(self, position, direction, state=None): - # Compute the local field of view for an agent in the environment - data_collection = False - if state is not None: - temp_visible_data = np.zeros(shape=(self.view_height, 2 * self.view_width + 1, 16)) - data_collection = True - if direction == 0: - origin = (position[0] + self.center, position[1] - self.view_width) - elif direction == 1: - origin = (position[0] - self.view_width, position[1] - self.center) - elif direction == 2: - origin = (position[0] - self.center, position[1] + self.view_width) - else: - origin = (position[0] + self.view_width, position[1] + self.center) - visible = list() - rel_coords = list() - for h in range(self.view_height): - for w in range(2 * self.view_width + 1): - if direction == 0: - if 0 <= origin[0] - h < self.env.height and 0 <= origin[1] + w < self.env.width: - visible.append((origin[0] - h, origin[1] + w)) - rel_coords.append((h, w)) - # if data_collection: - # temp_visible_data[h, w, :] = state[origin[0] - h, origin[1] + w, :] - elif direction == 1: - if 0 <= origin[0] + w < self.env.height and 0 <= origin[1] + h < self.env.width: - visible.append((origin[0] + w, origin[1] + h)) - rel_coords.append((h, w)) - # if data_collection: - # temp_visible_data[h, w, :] = state[origin[0] + w, origin[1] + h, :] - elif direction == 2: - if 0 <= origin[0] + h < self.env.height and 0 <= origin[1] - w < self.env.width: - visible.append((origin[0] + h, origin[1] - w)) - rel_coords.append((h, w)) - # if data_collection: - # temp_visible_data[h, w, :] = state[origin[0] + h, origin[1] - w, :] - else: - if 0 <= origin[0] - w < self.env.height and 0 <= origin[1] - h < self.env.width: - visible.append((origin[0] - w, origin[1] - h)) - rel_coords.append((h, w)) - # if data_collection: - # temp_visible_data[h, w, :] = state[origin[0] - w, origin[1] - h, :] - if data_collection: - return temp_visible_data - else: - return visible, rel_coords diff --git a/torch_training/predictors/predictions.py b/torch_training/predictors/predictions.py deleted file mode 100644 index 8306b72..0000000 --- a/torch_training/predictors/predictions.py +++ /dev/null @@ -1,179 +0,0 @@ -""" -Collection of environment-specific PredictionBuilder. -""" - -import numpy as np - -from flatland.core.env_prediction_builder import PredictionBuilder -from flatland.core.grid.grid4_utils import get_new_position -from flatland.envs.rail_env import RailEnvActions - - -class DummyPredictorForRailEnv(PredictionBuilder): - """ - DummyPredictorForRailEnv object. - - This object returns predictions for agents in the RailEnv environment. - The prediction acts as if no other agent is in the environment and always takes the forward action. - """ - - def get(self, custom_args=None, handle=None): - """ - Called whenever get_many in the observation build is called. - - Parameters - ------- - custom_args: dict - Not used in this dummy implementation. - handle : int (optional) - Handle of the agent for which to compute the observation vector. - - Returns - ------- - np.array - Returns a dictionary indexed by the agent handle and for each agent a vector of (max_depth + 1)x5 elements: - - time_offset - - position axis 0 - - position axis 1 - - direction - - action taken to come here - The prediction at 0 is the current position, direction etc. 
- - """ - agents = self.env.agents - if handle: - agents = [self.env.agents[handle]] - - prediction_dict = {} - - for agent in agents: - action_priorities = [RailEnvActions.MOVE_FORWARD, RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT] - _agent_initial_position = agent.position - _agent_initial_direction = agent.direction - prediction = np.zeros(shape=(self.max_depth + 1, 5)) - prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0] - for index in range(1, self.max_depth + 1): - action_done = False - # if we're at the target, stop moving... - if agent.position == agent.target: - prediction[index] = [index, *agent.target, agent.direction, RailEnvActions.STOP_MOVING] - - continue - for action in action_priorities: - cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = \ - self.env._check_action_on_agent(action, agent) - if all([new_cell_isValid, transition_isValid]): - # move and change direction to face the new_direction that was - # performed - agent.position = new_position - agent.direction = new_direction - prediction[index] = [index, *new_position, new_direction, action] - action_done = True - break - if not action_done: - raise Exception("Cannot move further. Something is wrong") - prediction_dict[agent.handle] = prediction - agent.position = _agent_initial_position - agent.direction = _agent_initial_direction - return prediction_dict - - -class ShortestPathPredictorForRailEnv(PredictionBuilder): - """ - ShortestPathPredictorForRailEnv object. - - This object returns shortest-path predictions for agents in the RailEnv environment. - The prediction acts as if no other agent is in the environment and always takes the forward action. - """ - - def __init__(self, max_depth=20): - # Initialize with depth 20 - self.max_depth = max_depth - - def get(self, custom_args=None, handle=None): - """ - Called whenever get_many in the observation build is called. - Requires distance_map to extract the shortest path. - - Parameters - ---------- - custom_args: dict - - distance_map : dict - handle : int, optional - Handle of the agent for which to compute the observation vector. - - Returns - ------- - np.array - Returns a dictionary indexed by the agent handle and for each agent a vector of (max_depth + 1)x5 elements: - - time_offset - - position axis 0 - - position axis 1 - - direction - - action taken to come here - The prediction at 0 is the current position, direction etc. - """ - agents = self.env.agents - if handle: - agents = [self.env.agents[handle]] - assert custom_args is not None - distance_map = custom_args.get('distance_map') - assert distance_map is not None - - prediction_dict = {} - for agent in agents: - _agent_initial_position = agent.position - _agent_initial_direction = agent.direction - agent_speed = agent.speed_data["speed"] - times_per_cell = int(np.reciprocal(agent_speed)) - prediction = np.zeros(shape=(self.max_depth + 1, 5)) - prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0] - new_direction = _agent_initial_direction - new_position = _agent_initial_position - visited = set() - for index in range(1, self.max_depth + 1): - # if we're at the target, stop moving... 
- if agent.position == agent.target: - prediction[index] = [index, *agent.target, agent.direction, RailEnvActions.STOP_MOVING] - visited.add((agent.position[0], agent.position[1], agent.direction)) - continue - if not agent.moving: - prediction[index] = [index, *agent.position, agent.direction, RailEnvActions.STOP_MOVING] - visited.add((agent.position[0], agent.position[1], agent.direction)) - continue - # Take shortest possible path - cell_transitions = self.env.rail.get_transitions(*agent.position, agent.direction) - - if np.sum(cell_transitions) == 1 and index % times_per_cell == 0: - new_direction = np.argmax(cell_transitions) - new_position = get_new_position(agent.position, new_direction) - elif np.sum(cell_transitions) > 1 and index % times_per_cell == 0: - min_dist = np.inf - no_dist_found = True - for direction in range(4): - if cell_transitions[direction] == 1: - neighbour_cell = get_new_position(agent.position, direction) - target_dist = distance_map[agent.handle, neighbour_cell[0], neighbour_cell[1], direction] - if target_dist < min_dist or no_dist_found: - min_dist = target_dist - new_direction = direction - no_dist_found = False - new_position = get_new_position(agent.position, new_direction) - elif index % times_per_cell == 0: - raise Exception("No transition possible {}".format(cell_transitions)) - - # update the agent's position and direction - agent.position = new_position - agent.direction = new_direction - - # prediction is ready - prediction[index] = [index, *new_position, new_direction, 0] - visited.add((new_position[0], new_position[1], new_direction)) - self.env.dev_pred_dict[agent.handle] = visited - prediction_dict[agent.handle] = prediction - - # cleanup: reset initial position - agent.position = _agent_initial_position - agent.direction = _agent_initial_direction - - return prediction_dict diff --git a/torch_training/render_agent_behavior.py b/torch_training/render_agent_behavior.py index 589d610..4befcd0 100644 --- a/torch_training/render_agent_behavior.py +++ b/torch_training/render_agent_behavior.py @@ -101,7 +101,7 @@ dones_list = [] action_prob = [0] * action_size agent_obs = [None] * env.get_num_agents() agent_next_obs = [None] * env.get_num_agents() -agent = Agent(state_size, action_size, "FC", 0) +agent = Agent(state_size, action_size, 0) with path(torch_training.Nets, "navigator_checkpoint10700.pth") as file_in: agent.qnetwork_local.load_state_dict(torch.load(file_in)) diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py index f69929f..80e4767 100644 --- a/torch_training/training_navigation.py +++ b/torch_training/training_navigation.py @@ -11,7 +11,7 @@ sys.path.append(str(base_dir)) import matplotlib.pyplot as plt import numpy as np import torch -from dueling_double_dqn import Agent +from torch_training.dueling_double_dqn import Agent from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.rail_env import RailEnv -- GitLab
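Taken together, the patch leaves the scripts building their observation stack from the flatland package itself and constructing the Agent without the removed net_type argument. A minimal sketch of the resulting setup, mirroring the '+' lines above (the state_size and action_size values are illustrative, assuming a depth-2 tree and flatland's five discrete actions):

from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from torch_training.dueling_double_dqn import Agent

observation_helper = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
state_size = 231      # (1 + 4 + 16) nodes * 11 features per node
action_size = 5
agent = Agent(state_size, action_size, 0)   # the "FC"/"Conv" selector is gone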