Skip to content
Snippets Groups Projects
Commit 97b22d6f authored by u214892's avatar u214892
Browse files

#22 improved stacking np arrays; improves play_model (20 trials, 5 agents) 55s > 45s

parent 9ef07d0f
No related branches found
No related tags found
No related merge requests found
import numpy as np
import random
from collections import namedtuple, deque
import copy
import os
from flatland.baselines.model import QNetwork, QNetwork2
import random
from collections import namedtuple, deque, Iterable
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
import copy
from flatland.baselines.model import QNetwork, QNetwork2
BUFFER_SIZE = int(1e5) # replay buffer size
BATCH_SIZE = 512 # minibatch size
......@@ -175,16 +177,24 @@ class ReplayBuffer:
"""Randomly sample a batch of experiences from memory."""
experiences = random.sample(self.memory, k=self.batch_size)
states = torch.from_numpy(np.vstack([e.state for e in experiences if e is not None])).float().to(device)
actions = torch.from_numpy(np.vstack([e.action for e in experiences if e is not None])).long().to(device)
rewards = torch.from_numpy(np.vstack([e.reward for e in experiences if e is not None])).float().to(device)
next_states = torch.from_numpy(np.vstack([e.next_state for e in experiences if e is not None])).float().to(
device)
dones = torch.from_numpy(np.vstack([e.done for e in experiences if e is not None]).astype(np.uint8)).float().to(
device)
states = torch.from_numpy(self.__v_stack_impr([e.state for e in experiences if e is not None])) \
.float().to(device)
actions = torch.from_numpy(self.__v_stack_impr([e.action for e in experiences if e is not None])) \
.long().to(device)
rewards = torch.from_numpy(self.__v_stack_impr([e.reward for e in experiences if e is not None])) \
.float().to(device)
next_states = torch.from_numpy(self.__v_stack_impr([e.next_state for e in experiences if e is not None])) \
.float().to(device)
dones = torch.from_numpy(self.__v_stack_impr([e.done for e in experiences if e is not None]).astype(np.uint8)) \
.float().to(device)
return (states, actions, rewards, next_states, dones)
def __len__(self):
"""Return the current size of internal memory."""
return len(self.memory)
def __v_stack_impr(self, states):
sub_dim = len(states[0][0]) if isinstance(states[0], Iterable) else 1
np_states = np.reshape(np.array(states), (len(states), sub_dim))
return np_states
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment