diff --git a/checkpoints/ppo/model_checkpoint.meta b/checkpoints/ppo/model_checkpoint.meta
index 7617876cf3d7031f066a779fde687404b0a1cc6f..b45d9b8937c572df6febfa2f0ac5a9d4cda4eb0e 100644
Binary files a/checkpoints/ppo/model_checkpoint.meta and b/checkpoints/ppo/model_checkpoint.meta differ
diff --git a/checkpoints/ppo/model_checkpoint.optimizer b/checkpoints/ppo/model_checkpoint.optimizer
index b93d28155360433cbb1574b2e797bf1e293c2f6c..190ef25976343f4c1cca9b751f78fc8fdcadfa28 100644
Binary files a/checkpoints/ppo/model_checkpoint.optimizer and b/checkpoints/ppo/model_checkpoint.optimizer differ
diff --git a/checkpoints/ppo/model_checkpoint.policy b/checkpoints/ppo/model_checkpoint.policy
index bc21bc40897b530a65966b1cbbbaeb41835f7b69..c4492df60aaec91709c87ae729bf71480866b31e 100644
Binary files a/checkpoints/ppo/model_checkpoint.policy and b/checkpoints/ppo/model_checkpoint.policy differ
diff --git a/src/extra.py b/src/extra.py
index 84f20d9def1cad022f90abacadd646ecb34b04dc..025c8f1355f5c4c1122a3b5272049234e53c71da 100644
--- a/src/extra.py
+++ b/src/extra.py
@@ -67,19 +67,9 @@ class Extra(ObservationBuilder):
     def __init__(self, max_depth):
         self.max_depth = max_depth
-        self.observation_dim = 22
+        self.observation_dim = 30
         self.agent = None
 
-    def loadAgent(self):
-        if self.agent is not None:
-            return
-        self.state_size = self.env.obs_builder.observation_dim
-        self.action_size = 5
-        print("action_size: ", self.action_size)
-        print("state_size: ", self.state_size)
-        self.agent = Agent(self.state_size, self.action_size, 0)
-        self.agent.load('./checkpoints/', 0, 1.0)
-
     def build_data(self):
         if self.env is not None:
             self.env.dev_obs_dict = {}
@@ -197,6 +187,9 @@ class Extra(ObservationBuilder):
     def normalize_observation(self, obsData):
         return obsData
 
+    def is_collision(self, obsData):
+        return False
+
     def reset(self):
         self.build_data()
         return
@@ -210,15 +203,16 @@ class Extra(ObservationBuilder):
             return 2
         return 3
 
-    def _explore(self, handle, new_position, new_direction, depth=0):
-
+    def _explore(self, handle, distance_map, new_position, new_direction, depth=0):
         has_opp_agent = 0
         has_same_agent = 0
         visited = []
+        visited_direction = []
+        visited_min_distance = np.inf
 
         # stop exploring (max_depth reached)
        if depth >= self.max_depth:
-            return has_opp_agent, has_same_agent, visited
+            return has_opp_agent, has_same_agent, visited, visited_direction, visited_min_distance
 
         # max_explore_steps = 100
         cnt = 0
@@ -226,15 +220,22 @@ class Extra(ObservationBuilder):
             cnt += 1
 
             visited.append(new_position)
+            visited_direction.append(new_direction)
+
+            new_cell_dist = distance_map[handle,
+                                         new_position[0], new_position[1],
+                                         new_direction]
+            visited_min_distance = min(visited_min_distance, new_cell_dist)
+
             opp_a = self.env.agent_positions[new_position]
             if opp_a != -1 and opp_a != handle:
                 if self.env.agents[opp_a].direction != new_direction:
                     # opp agent found
                     has_opp_agent = 1
-                    return has_opp_agent, has_same_agent, visited
+                    return has_opp_agent, has_same_agent, visited, visited_direction, visited_min_distance
                 else:
                     has_same_agent = 1
-                    return has_opp_agent, has_same_agent, visited
+                    return has_opp_agent, has_same_agent, visited, visited_direction, visited_min_distance
 
             # convert one-hot encoding to 0,1,2,3
             possible_transitions = self.env.rail.get_transitions(*new_position, new_direction)
@@ -243,20 +244,28 @@ class Extra(ObservationBuilder):
                 agents_near_to_switch_all = \
                 self.check_agent_descision(new_position, new_direction)
             if agents_near_to_switch:
-                return has_opp_agent, has_same_agent, visited
+                return has_opp_agent, has_same_agent, visited, visited_direction, visited_min_distance
             if agents_on_switch:
                 for dir_loop in range(4):
                     if possible_transitions[dir_loop] == 1:
-                        hoa, hsa, v = self._explore(handle, new_position, new_direction, depth + 1)
-                        visited.append(v)
+                        hoa, hsa, v, d, min_dist = self._explore(handle,
+                                                                 distance_map,
+                                                                 get_new_position(new_position, dir_loop),
+                                                                 dir_loop,
+                                                                 depth + 1)
+                        if np.math.isinf(min_dist) == False:
+                            visited_min_distance = min(visited_min_distance, min_dist)
+
+                        visited = visited + v
+                        visited_direction = visited_direction + d
                         has_opp_agent = 0.5 * (has_opp_agent + hoa)
                         has_same_agent = 0.5 * (has_same_agent + hsa)
-                return has_opp_agent, has_same_agent, visited
+                return has_opp_agent, has_same_agent, visited, visited_direction, visited_min_distance
             else:
                 new_direction = fast_argmax(possible_transitions)
                 new_position = get_new_position(new_position, new_direction)
 
-        return has_opp_agent, has_same_agent, visited
+        return has_opp_agent, has_same_agent, visited, visited_direction, visited_min_distance
 
     def get(self, handle):
         # all values are [0,1]
@@ -285,6 +294,7 @@ class Extra(ObservationBuilder):
         observation = np.zeros(self.observation_dim)
 
         visited = []
+        visited_direction = []
         agent = self.env.agents[handle]
 
         agent_done = False
@@ -301,6 +311,7 @@ class Extra(ObservationBuilder):
         if not agent_done:
             visited.append(agent_virtual_position)
+            visited_direction.append(agent.direction)
             distance_map = self.env.distance_map.get()
             current_cell_dist = distance_map[handle,
                                              agent_virtual_position[0], agent_virtual_position[1],
@@ -319,8 +330,12 @@ class Extra(ObservationBuilder):
                 if not (np.math.isinf(new_cell_dist) and np.math.isinf(current_cell_dist)):
                     observation[dir_loop] = int(new_cell_dist < current_cell_dist)
 
-                has_opp_agent, has_same_agent, v = self._explore(handle, new_position, branch_direction)
-                visited.append(v)
+                has_opp_agent, has_same_agent, vis, dir, min_dist = self._explore(handle,
+                                                                                  distance_map,
+                                                                                  new_position,
+                                                                                  branch_direction)
+                visited = visited + vis
+                visited_direction = visited_direction + dir
 
                 observation[10 + dir_loop] = 1
                 observation[14 + dir_loop] = has_opp_agent
@@ -334,6 +349,16 @@ class Extra(ObservationBuilder):
         observation[8] = int(agents_near_to_switch)
         observation[9] = int(agents_near_to_switch_all)
 
+        observation[22] = int(self.env._elapsed_steps % 4 == 0)
+        observation[23] = int(self.env._elapsed_steps % 4 == 1)
+        observation[24] = int(self.env._elapsed_steps % 4 == 2)
+        observation[25] = int(self.env._elapsed_steps % 4 == 3)
+
+        observation[26] = int(agent.direction % 4 == 0)
+        observation[27] = int(agent.direction % 4 == 1)
+        observation[28] = int(agent.direction % 4 == 2)
+        observation[29] = int(agent.direction % 4 == 3)
+
         self.env.dev_obs_dict.update({handle: visited})
 
         return observation
@@ -349,3 +374,13 @@ class Extra(ObservationBuilder):
                 action_dict[a] = RailEnvActions.DO_NOTHING
 
         return action_dict
+
+    def loadAgent(self):
+        if self.agent is not None:
+            return
+        self.state_size = self.env.obs_builder.observation_dim
+        self.action_size = 5
+        print("action_size: ", self.action_size)
+        print("state_size: ", self.state_size)
+        self.agent = Agent(self.state_size, self.action_size, 0)
+        self.agent.load('./checkpoints/', 0, 1.0)
\ No newline at end of file
diff --git a/src/ppo/model.py b/src/ppo/model.py
index 51b86ff16691c03f6a754405352bb4cf48e4b914..421423df6739bbc4b4ed94487de7e3dfa9d973a8 100644
--- a/src/ppo/model.py
+++ b/src/ppo/model.py
@@ -3,7 +3,7 @@ import torch.nn.functional as F
 
 
 class PolicyNetwork(nn.Module):
-    def __init__(self, state_size, action_size, hidsize1=128, hidsize2=128, hidsize3=32):
+    def __init__(self, state_size, action_size, hidsize1=128, hidsize2=256, hidsize3=32):
         super().__init__()
         self.fc1 = nn.Linear(state_size, hidsize1)
         self.fc2 = nn.Linear(hidsize1, hidsize2)
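For reference, a minimal sketch (not part of the diff, using assumed example values) of the eight features this change appends to the observation vector: indices 22-25 one-hot encode the step phase (_elapsed_steps % 4) and indices 26-29 one-hot encode the agent's heading (direction % 4).

# Sketch only: illustrates the new observation slots added in src/extra.py above.
import numpy as np

observation = np.zeros(30)   # observation_dim is now 30 (was 22)
elapsed_steps = 7            # assumed example value of env._elapsed_steps
agent_direction = 2          # assumed example value (flatland: 0=N, 1=E, 2=S, 3=W)

for k in range(4):
    observation[22 + k] = int(elapsed_steps % 4 == k)    # step-phase one-hot
    observation[26 + k] = int(agent_direction % 4 == k)  # heading one-hot

print(observation[22:30])    # -> [0. 0. 0. 1. 0. 0. 1. 0.]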