diff --git a/flatland/baselines/dueling_double_dqn.py b/flatland/baselines/dueling_double_dqn.py
index 66fe3a3effec0dfa9dc35d07fec887eaa05be6fc..41a27bf8431df7812f1b4f63e797aa426c17edf1 100644
--- a/flatland/baselines/dueling_double_dqn.py
+++ b/flatland/baselines/dueling_double_dqn.py
@@ -1,12 +1,14 @@
-import numpy as np
-import random
-from collections import namedtuple, deque
+import copy
 import os
-from flatland.baselines.model import QNetwork, QNetwork2
+import random
+from collections import namedtuple, deque, Iterable
+
+import numpy as np
 import torch
 import torch.nn.functional as F
 import torch.optim as optim
-import copy
+
+from flatland.baselines.model import QNetwork, QNetwork2
 
 BUFFER_SIZE = int(1e5)  # replay buffer size
 BATCH_SIZE = 512  # minibatch size
@@ -175,16 +177,24 @@ class ReplayBuffer:
         """Randomly sample a batch of experiences from memory."""
         experiences = random.sample(self.memory, k=self.batch_size)
 
-        states = torch.from_numpy(np.vstack([e.state for e in experiences if e is not None])).float().to(device)
-        actions = torch.from_numpy(np.vstack([e.action for e in experiences if e is not None])).long().to(device)
-        rewards = torch.from_numpy(np.vstack([e.reward for e in experiences if e is not None])).float().to(device)
-        next_states = torch.from_numpy(np.vstack([e.next_state for e in experiences if e is not None])).float().to(
-            device)
-        dones = torch.from_numpy(np.vstack([e.done for e in experiences if e is not None]).astype(np.uint8)).float().to(
-            device)
+        states = torch.from_numpy(self.__v_stack_impr([e.state for e in experiences if e is not None])) \
+            .float().to(device)
+        actions = torch.from_numpy(self.__v_stack_impr([e.action for e in experiences if e is not None])) \
+            .long().to(device)
+        rewards = torch.from_numpy(self.__v_stack_impr([e.reward for e in experiences if e is not None])) \
+            .float().to(device)
+        next_states = torch.from_numpy(self.__v_stack_impr([e.next_state for e in experiences if e is not None])) \
+            .float().to(device)
+        dones = torch.from_numpy(self.__v_stack_impr([e.done for e in experiences if e is not None]).astype(np.uint8)) \
+            .float().to(device)
 
         return (states, actions, rewards, next_states, dones)
 
     def __len__(self):
         """Return the current size of internal memory."""
         return len(self.memory)
+
+    def __v_stack_impr(self, states):
+        sub_dim = len(states[0][0]) if isinstance(states[0], Iterable) else 1
+        np_states = np.reshape(np.array(states), (len(states), sub_dim))
+        return np_states
diff --git a/flatland/core/transitions.py b/flatland/core/transitions.py
index 622d900598bba6bbc48750d6bb48923975af9b5e..add047b6c7895e391211258bb10561110e0f1a19 100644
--- a/flatland/core/transitions.py
+++ b/flatland/core/transitions.py
@@ -556,16 +556,16 @@ class RailEnvTransitions(Grid4Transitions):
         self.maskDeadEnds = 0b0010000110000100
 
         # create this to make validation faster
-        self.transitions_all = []
+        self.transitions_all = set()
         for index, trans in enumerate(self.transitions):
-            self.transitions_all.append(trans)
+            self.transitions_all.add(trans)
             if index in (2, 4, 6, 7, 8, 9, 10):
                 for _ in range(3):
                     trans = self.rotate_transition(trans, rotation=90)
-                    self.transitions_all.append(trans)
+                    self.transitions_all.add(trans)
             elif index in (1, 5):
                 trans = self.rotate_transition(trans, rotation=90)
-                self.transitions_all.append(trans)
+                self.transitions_all.add(trans)
 
     def print(self, cell_transition):
         print("  NESW")
@@ -620,10 +620,7 @@ class RailEnvTransitions(Grid4Transitions):
             Boolean True or False
 
         """
-        for trans in self.transitions_all:
-            if cell_transition == trans:
-                return True
-        return False
+        return cell_transition in self.transitions_all
 
     def has_deadend(self, cell_transition):
         if cell_transition & self.maskDeadEnds > 0:
diff --git a/flatland/envs/env_utils.py b/flatland/envs/env_utils.py
index b58604c6d7ededa28a33d30e87e13777a3cd54ec..1482b4438bebd82638b873f3232198172a05e6d0 100644
--- a/flatland/envs/env_utils.py
+++ b/flatland/envs/env_utils.py
@@ -1,12 +1,13 @@
-
 """
 Definition of the RailEnv environment and related level-generation functions.
 
 Generator functions are functions that take width, height and num_resets as
 arguments and return a GridTransitionMap object.
 """
+
 import numpy as np
+
 # from flatland.core.env import Environment
 # from flatland.envs.observations import TreeObsForRailEnv
@@ -53,7 +54,6 @@ def validate_new_transition(rail_trans, rail_array, prev_pos, current_pos, new_p
         else:
             # check if matches existing layout
             new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
-            # new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1)
     else:
         # set the forward path
         new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
@@ -68,7 +68,6 @@ def validate_new_transition(rail_trans, rail_array, prev_pos, current_pos, new_p
         else:
             # check if matches existing layout
             new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1)
-            # new_trans_e = rail_trans.set_transition(new_trans_e, mirror(new_dir), mirror(new_dir), 1)
 
     if not rail_trans.is_valid(new_trans_e):
         return False
@@ -90,6 +89,9 @@ class AStarNode():
     def __eq__(self, other):
         return self.pos == other.pos
 
+    def __hash__(self):
+        return hash(self.pos)
+
     def update_if_better(self, other):
         if other.g < self.g:
             self.parent = other.parent
@@ -106,30 +108,23 @@ def a_star(rail_trans, rail_array, start, end):
     rail_shape = rail_array.shape
     start_node = AStarNode(None, start)
     end_node = AStarNode(None, end)
-    open_list = []
-    closed_list = []
+    open_nodes = set()
+    closed_nodes = set()
+    open_nodes.add(start_node)
 
-    open_list.append(start_node)
-
-    # this could be optimized
-    def is_node_in_list(node, the_list):
-        for o_node in the_list:
-            if node == o_node:
-                return o_node
-        return None
-
-    while len(open_list) > 0:
+    while len(open_nodes) > 0:
         # get node with current shortest est. path (lowest f)
-        current_node = open_list[0]
-        current_index = 0
-        for index, item in enumerate(open_list):
+        current_node = None
+        for item in open_nodes:
+            if current_node is None:
+                current_node = item
+                continue
             if item.f < current_node.f:
                 current_node = item
-                current_index = index
 
         # pop current off open list, add to closed list
-        open_list.pop(current_index)
-        closed_list.append(current_node)
+        open_nodes.remove(current_node)
+        closed_nodes.add(current_node)
 
         # found the goal
         if current_node == end_node:
@@ -149,10 +144,7 @@ def a_star(rail_trans, rail_array, start, end):
         prev_pos = None
         for new_pos in [(0, -1), (0, 1), (-1, 0), (1, 0)]:
             node_pos = (current_node.pos[0] + new_pos[0], current_node.pos[1] + new_pos[1])
-            if node_pos[0] >= rail_shape[0] or \
-                node_pos[0] < 0 or \
-                node_pos[1] >= rail_shape[1] or \
-                node_pos[1] < 0:
+            if node_pos[0] >= rail_shape[0] or node_pos[0] < 0 or node_pos[1] >= rail_shape[1] or node_pos[1] < 0:
                 continue
 
             # validate positions
@@ -166,8 +158,7 @@ def a_star(rail_trans, rail_array, start, end):
         # loop through children
         for child in children:
            # already in closed list?
-            closed_node = is_node_in_list(child, closed_list)
-            if closed_node is not None:
+            if child in closed_nodes:
                 continue
 
             # create the f, g, and h values
@@ -180,16 +171,14 @@ def a_star(rail_trans, rail_array, start, end):
             child.f = child.g + child.h
 
             # already in the open list?
-            open_node = is_node_in_list(child, open_list)
-            if open_node is not None:
-                open_node.update_if_better(child)
+            if child in open_nodes:
                 continue
 
             # add the child to the open list
-            open_list.append(child)
+            open_nodes.add(child)
 
         # no full path found
-        if len(open_list) == 0:
+        if len(open_nodes) == 0:
             return []
 
@@ -323,8 +312,7 @@ def get_rnd_agents_pos_tgt_dir_on_rail(rail, num_agents):
             valid_starting_directions = []
             for m in valid_movements:
                 new_position = get_new_position(agents_position[i], m[1])
-                if m[0] not in valid_starting_directions and \
-                    _path_exists(rail, new_position, m[0], agents_target[i]):
+                if m[0] not in valid_starting_directions and _path_exists(rail, new_position, m[0], agents_target[i]):
                     valid_starting_directions.append(m[0])
 
             if len(valid_starting_directions) == 0:
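
The `ReplayBuffer` change routes every sampled field through the new private `__v_stack_impr` helper, which stacks both scalar entries (actions, rewards, done flags) and per-agent observation vectors into a consistent `(batch, dim)` array before the `torch.from_numpy` conversion. Below is a minimal sketch of that reshaping as a free function; the name `v_stack` and the toy inputs are illustrative only, and `Iterable` is imported from `collections.abc`, which is required on Python 3.10+ where the bare `collections` alias no longer exists.

```python
from collections.abc import Iterable  # "collections.Iterable" was removed in Python 3.10

import numpy as np


def v_stack(states):
    """Stack a batch of scalar or per-agent vector entries into shape (batch, dim)."""
    # Vector entries (e.g. observations stored as a 1 x obs_dim array) keep their
    # feature dimension; scalars (actions, rewards, done flags) become a (batch, 1) column.
    sub_dim = len(states[0][0]) if isinstance(states[0], Iterable) else 1
    return np.reshape(np.array(states), (len(states), sub_dim))


actions = [0, 2, 1]                                    # scalar fields -> shape (3, 1)
observations = [[[0.1, 0.2, 0.3]], [[0.4, 0.5, 0.6]]]  # wrapped vectors -> shape (2, 3)
print(v_stack(actions).shape)
print(v_stack(observations).shape)
```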
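In `RailEnvTransitions`, the constructor now stores every rotated variant of the base transition templates in a set, so `is_valid` collapses to a single O(1) membership test instead of a linear scan. The sketch below reproduces that precompute-the-closure pattern on the 16-bit Grid4 bitmap; it assumes the layout of four 4-bit blocks (one per heading N, E, S, W, most significant block first) and uses only two base templates, so it illustrates the technique rather than the full template table.

```python
def rotate_transition(cell_transition, rotation=90):
    """Rotate a 16-bit Grid4 transition bitmap by a multiple of 90 degrees."""
    steps = (rotation // 90) % 4
    # Split into 4 blocks of 4 bits, one block per incoming heading (N, E, S, W).
    blocks = [(cell_transition >> (4 * (3 - i))) & 0xF for i in range(4)]
    # A 90-degree turn relabels N->E->S->W for both headings and exits:
    # rotate the bits inside each block, then rotate the blocks themselves.
    blocks = [((b >> steps) | (b << (4 - steps))) & 0xF for b in blocks]
    blocks = blocks[-steps:] + blocks[:-steps] if steps else blocks
    value = 0
    for b in blocks:
        value = (value << 4) | b
    return value


# Precompute the closure of the templates under rotation once; validity checks
# then reduce to a set-membership test, exactly as in the new is_valid().
base_templates = [
    0b0000000000000000,  # Case 0: empty cell
    0b1000000000100000,  # Case 1: straight track
]
transitions_all = set()
for trans in base_templates:
    for _ in range(4):
        transitions_all.add(trans)
        trans = rotate_transition(trans, rotation=90)


def is_valid(cell_transition):
    return cell_transition in transitions_all


print(is_valid(0b0000010000000001))  # True: the straight track rotated by 90 degrees
print(is_valid(0b1111111111111111))  # False: not a template or a rotation of one
```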
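The A* rewrite in `env_utils.py` works because `AStarNode` now defines `__hash__` consistently with `__eq__` (both keyed on `pos`), which is what makes `child in closed_nodes` an O(1) lookup by position. Here is a stripped-down sketch of the same bookkeeping on a plain occupancy grid (no rail-transition validation, Manhattan heuristic, illustrative names throughout); like the diff, it skips a child that is already in the open set rather than re-relaxing its cost.

```python
import numpy as np


class Node:
    def __init__(self, parent, pos):
        self.parent = parent
        self.pos = pos
        self.g = self.h = self.f = 0.0

    # __eq__ and __hash__ must agree: two nodes at the same cell count as the
    # same entry in the open/closed sets, regardless of parent or cost.
    def __eq__(self, other):
        return self.pos == other.pos

    def __hash__(self):
        return hash(self.pos)


def a_star_grid(grid, start, end):
    """A* on a 0/1 occupancy grid using sets for the open/closed bookkeeping."""
    start_node, end_node = Node(None, start), Node(None, end)
    open_nodes, closed_nodes = {start_node}, set()

    while open_nodes:
        current = min(open_nodes, key=lambda n: n.f)   # lowest estimated total cost
        open_nodes.remove(current)
        closed_nodes.add(current)

        if current == end_node:                        # rebuild the path from parents
            path = []
            while current is not None:
                path.append(current.pos)
                current = current.parent
            return path[::-1]

        for dr, dc in ((0, -1), (0, 1), (-1, 0), (1, 0)):
            r, c = current.pos[0] + dr, current.pos[1] + dc
            if not (0 <= r < grid.shape[0] and 0 <= c < grid.shape[1]) or grid[r, c]:
                continue
            child = Node(current, (r, c))
            if child in closed_nodes or child in open_nodes:   # O(1) membership tests
                continue
            child.g = current.g + 1
            child.h = abs(r - end[0]) + abs(c - end[1])        # Manhattan heuristic
            child.f = child.g + child.h
            open_nodes.add(child)

    return []                                          # no path found


grid = np.zeros((5, 5), dtype=int)
grid[1:4, 2] = 1                                       # a wall to route around
print(a_star_grid(grid, (0, 0), (4, 4)))
```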