diff --git a/flatland/baselines/dueling_double_dqn.py b/flatland/baselines/dueling_double_dqn.py
index 66fe3a3effec0dfa9dc35d07fec887eaa05be6fc..41a27bf8431df7812f1b4f63e797aa426c17edf1 100644
--- a/flatland/baselines/dueling_double_dqn.py
+++ b/flatland/baselines/dueling_double_dqn.py
@@ -1,12 +1,15 @@
-import numpy as np
-import random
-from collections import namedtuple, deque
+import copy
 import os
-from flatland.baselines.model import QNetwork, QNetwork2
+import random
+from collections import namedtuple, deque
+from collections.abc import Iterable
+
+import numpy as np
 import torch
 import torch.nn.functional as F
 import torch.optim as optim
-import copy
+
+from flatland.baselines.model import QNetwork, QNetwork2
 
 BUFFER_SIZE = int(1e5)  # replay buffer size
 BATCH_SIZE = 512  # minibatch size
@@ -175,16 +177,29 @@ class ReplayBuffer:
         """Randomly sample a batch of experiences from memory."""
         experiences = random.sample(self.memory, k=self.batch_size)
 
-        states = torch.from_numpy(np.vstack([e.state for e in experiences if e is not None])).float().to(device)
-        actions = torch.from_numpy(np.vstack([e.action for e in experiences if e is not None])).long().to(device)
-        rewards = torch.from_numpy(np.vstack([e.reward for e in experiences if e is not None])).float().to(device)
-        next_states = torch.from_numpy(np.vstack([e.next_state for e in experiences if e is not None])).float().to(
-            device)
-        dones = torch.from_numpy(np.vstack([e.done for e in experiences if e is not None]).astype(np.uint8)).float().to(
-            device)
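+        # Stack each experience field into a (batch, sub_dim) array via the
+        # __v_stack_impr helper below, then move the tensors to the device.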
+        states = torch.from_numpy(self.__v_stack_impr([e.state for e in experiences if e is not None])) \
+            .float().to(device)
+        actions = torch.from_numpy(self.__v_stack_impr([e.action for e in experiences if e is not None])) \
+            .long().to(device)
+        rewards = torch.from_numpy(self.__v_stack_impr([e.reward for e in experiences if e is not None])) \
+            .float().to(device)
+        next_states = torch.from_numpy(self.__v_stack_impr([e.next_state for e in experiences if e is not None])) \
+            .float().to(device)
+        dones = torch.from_numpy(self.__v_stack_impr([e.done for e in experiences if e is not None]).astype(np.uint8)) \
+            .float().to(device)
 
         return (states, actions, rewards, next_states, dones)
 
     def __len__(self):
         """Return the current size of internal memory."""
         return len(self.memory)
+
+    def __v_stack_impr(self, states):
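+        """Stack a list of stored items into a (batch, sub_dim) array.
+        Each item is assumed to be either a scalar or a nested sequence
+        (e.g. shape (1, d)) whose inner length gives the feature width."""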
+        sub_dim = len(states[0][0]) if isinstance(states[0], Iterable) else 1
+        np_states = np.reshape(np.array(states), (len(states), sub_dim))
+        return np_states
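
Review note: a minimal sketch of the stacking behavior introduced above, assuming
states are stored with a leading unit axis (e.g. via np.expand_dims(state, 0));
the sample values below are illustrative only:

    import numpy as np

    # Nested items: sub_dim = len(states[0][0]) == 5, so the result is (3, 5),
    # matching what np.vstack produced before this change.
    states = [np.expand_dims(np.arange(5, dtype=np.float32), 0) for _ in range(3)]
    batch = np.reshape(np.array(states), (len(states), 5))
    assert batch.shape == (3, 5)

    # Scalar items (e.g. done flags) fall back to sub_dim == 1.
    dones = [False, True, False]
    dones_col = np.reshape(np.array(dones), (len(dones), 1))
    assert dones_col.shape == (3, 1)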