diff --git a/flatland/baselines/dueling_double_dqn.py b/flatland/baselines/dueling_double_dqn.py
index 66fe3a3effec0dfa9dc35d07fec887eaa05be6fc..41a27bf8431df7812f1b4f63e797aa426c17edf1 100644
--- a/flatland/baselines/dueling_double_dqn.py
+++ b/flatland/baselines/dueling_double_dqn.py
@@ -1,12 +1,14 @@
-import numpy as np
-import random
-from collections import namedtuple, deque
+import copy
 import os
-from flatland.baselines.model import QNetwork, QNetwork2
+import random
+from collections import namedtuple, deque, Iterable
+
+import numpy as np
 import torch
 import torch.nn.functional as F
 import torch.optim as optim
-import copy
+
+from flatland.baselines.model import QNetwork, QNetwork2
 
 BUFFER_SIZE = int(1e5)  # replay buffer size
 BATCH_SIZE = 512  # minibatch size
@@ -175,16 +177,24 @@ class ReplayBuffer:
         """Randomly sample a batch of experiences from memory."""
         experiences = random.sample(self.memory, k=self.batch_size)
 
-        states = torch.from_numpy(np.vstack([e.state for e in experiences if e is not None])).float().to(device)
-        actions = torch.from_numpy(np.vstack([e.action for e in experiences if e is not None])).long().to(device)
-        rewards = torch.from_numpy(np.vstack([e.reward for e in experiences if e is not None])).float().to(device)
-        next_states = torch.from_numpy(np.vstack([e.next_state for e in experiences if e is not None])).float().to(
-            device)
-        dones = torch.from_numpy(np.vstack([e.done for e in experiences if e is not None]).astype(np.uint8)).float().to(
-            device)
+        states = torch.from_numpy(self.__v_stack_impr([e.state for e in experiences if e is not None])) \
+            .float().to(device)
+        actions = torch.from_numpy(self.__v_stack_impr([e.action for e in experiences if e is not None])) \
+            .long().to(device)
+        rewards = torch.from_numpy(self.__v_stack_impr([e.reward for e in experiences if e is not None])) \
+            .float().to(device)
+        next_states = torch.from_numpy(self.__v_stack_impr([e.next_state for e in experiences if e is not None])) \
+            .float().to(device)
+        dones = torch.from_numpy(self.__v_stack_impr([e.done for e in experiences if e is not None]).astype(np.uint8)) \
+            .float().to(device)
 
         return (states, actions, rewards, next_states, dones)
 
     def __len__(self):
         """Return the current size of internal memory."""
         return len(self.memory)
+
+    def __v_stack_impr(self, states):
+        sub_dim = len(states[0][0]) if isinstance(states[0], Iterable) else 1
+        np_states = np.reshape(np.array(states), (len(states), sub_dim))
+        return np_states
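
For context, a minimal standalone sketch of the stacking behaviour the new `__v_stack_impr` helper introduces in place of `np.vstack`. The helper body is copied from the diff; the sample shapes (scalar rewards, states stored as `(1, n)` arrays) are illustrative assumptions rather than anything taken from the Flatland code, and the sketch imports `Iterable` from `collections.abc` instead of `collections` only so it runs without a deprecation warning on recent Python versions.

```python
from collections.abc import Iterable

import numpy as np


def v_stack_impr(states):
    """Standalone mirror of ReplayBuffer.__v_stack_impr (illustrative only).

    Stacks a list of per-experience entries into a (batch, sub_dim) array:
    scalar entries (actions, rewards, dones) become a (batch, 1) column,
    while iterable entries use the width of their first inner element.
    """
    sub_dim = len(states[0][0]) if isinstance(states[0], Iterable) else 1
    return np.reshape(np.array(states), (len(states), sub_dim))


# Assumed, illustrative batch of three experiences.
rewards = [1.0, 0.0, -1.0]                    # scalar fields
states = [np.zeros((1, 4)), np.ones((1, 4)),  # each state stored as a (1, 4) array,
          np.full((1, 4), 2.0)]               # as the len(states[0][0]) lookup implies

print(v_stack_impr(rewards).shape)  # (3, 1)
print(v_stack_impr(states).shape)   # (3, 4)
```

Whether the entries are lists, tuples or NumPy arrays, the `isinstance(..., Iterable)` check picks the column width from the first element, so the one helper serves the state, action, reward, next_state and done columns in `sample()` alike.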