# Forked from Flatland / Flatland
# (2723 commits behind the upstream repository.)
# Erik Nygren authored
# model.py (2.10 KiB)
import torch
import torch.nn as nn
import torch.nn.functional as F
class QNetwork(nn.Module):
    """Dueling Q-network made of two independent fully connected streams.

    One stream maps the input to a single scalar state value, the other maps
    it to one advantage per action. ``forward`` combines them as
    ``value + advantage - mean(advantage)``, where the mean is taken over the
    action axis of each sample.
    """

    def __init__(self, state_size, action_size, seed, hidsize1=128, hidsize2=128):
        """Create the value and advantage streams.

        Args:
            state_size: Size of the last dimension of the input tensor.
            action_size: Number of outputs of the advantage stream.
            seed: Accepted for interface compatibility; this constructor does
                not use it (no RNG is seeded here).
            hidsize1: Width of the first hidden layer of both streams.
            hidsize2: Width of the second hidden layer of both streams.
        """
        super().__init__()

        # Value stream: state -> scalar V(s).
        self.fc1_val = nn.Linear(state_size, hidsize1)
        self.fc2_val = nn.Linear(hidsize1, hidsize2)
        self.fc3_val = nn.Linear(hidsize2, 1)

        # Advantage stream: state -> one A(s, a) per action.
        self.fc1_adv = nn.Linear(state_size, hidsize1)
        self.fc2_adv = nn.Linear(hidsize1, hidsize2)
        self.fc3_adv = nn.Linear(hidsize2, action_size)

    def forward(self, x):
        """Return Q-values for ``x``.

        Args:
            x: Tensor whose last dimension has size ``state_size``; any
                leading dimensions (e.g. a batch axis) are preserved.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size ``action_size``.
        """
        # Value function approximation.
        val = F.relu(self.fc1_val(x))
        val = F.relu(self.fc2_val(val))
        val = self.fc3_val(val)

        # Advantage calculation.
        adv = F.relu(self.fc1_adv(x))
        adv = F.relu(self.fc2_adv(adv))
        adv = self.fc3_adv(adv)

        # Centre the advantages per sample, over the action axis only.
        # A global ``adv.mean()`` would also average over the batch axis and
        # make each sample's Q-values depend on the rest of the batch.
        return val + adv - adv.mean(dim=-1, keepdim=True)
class QNetwork2(nn.Module):
    """Convolutional dueling Q-network.

    Three unpadded convolutions feed two fully connected streams: one
    produces a scalar state value, the other one advantage per action.
    ``forward`` combines them as ``value + advantage - mean(advantage)``,
    with the mean taken over the action axis of each sample.

    The first layer of each stream expects exactly 6400 flattened features
    (64 channels times 100 spatial positions), so the input resolution must
    produce that size after the convolutions; for example, a 100x100 input
    yields a 64x10x10 feature map.
    """

    def __init__(self, state_size, action_size, seed, input_channels, hidsize1=128, hidsize2=64):
        """Create the convolutional trunk and the two dueling streams.

        Args:
            state_size: Accepted for interface compatibility; not used by
                this constructor.
            action_size: Number of outputs of the advantage stream.
            seed: Accepted for interface compatibility; not used by this
                constructor (no RNG is seeded here).
            input_channels: Number of channels of the input image tensor.
            hidsize1: Width of the first hidden layer of both streams.
            hidsize2: Width of the second hidden layer of both streams.
        """
        super().__init__()

        # Convolutional trunk. The batch-norm layers are registered as
        # sub-modules (and so appear in the state dict) but forward() does
        # not apply them.
        # NOTE(review): confirm whether skipping bn1-bn3 is intentional.
        self.conv1 = nn.Conv2d(input_channels, 16, kernel_size=3, stride=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=3)
        self.bn2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=5, stride=3)
        self.bn3 = nn.BatchNorm2d(64)

        # Value stream: flattened features -> scalar V(s).
        self.fc1_val = nn.Linear(6400, hidsize1)
        self.fc2_val = nn.Linear(hidsize1, hidsize2)
        self.fc3_val = nn.Linear(hidsize2, 1)

        # Advantage stream: flattened features -> one A(s, a) per action.
        self.fc1_adv = nn.Linear(6400, hidsize1)
        self.fc2_adv = nn.Linear(hidsize1, hidsize2)
        self.fc3_adv = nn.Linear(hidsize2, action_size)

    def forward(self, x):
        """Return Q-values for a batch of image-like inputs.

        Args:
            x: Tensor of shape ``(batch, input_channels, H, W)`` whose
                convolved feature map flattens to 6400 values per sample.

        Returns:
            Tensor of shape ``(batch, action_size)``.
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))

        # Flatten once; both streams read the same features.
        features = x.view(x.size(0), -1)

        # Value function approximation.
        val = F.relu(self.fc1_val(features))
        val = F.relu(self.fc2_val(val))
        val = self.fc3_val(val)

        # Advantage calculation.
        adv = F.relu(self.fc1_adv(features))
        adv = F.relu(self.fc2_adv(adv))
        adv = self.fc3_adv(adv)

        # Centre the advantages per sample, over the action axis only.
        # A global ``adv.mean()`` would also average over the batch axis and
        # make each sample's Q-values depend on the rest of the batch.
        return val + adv - adv.mean(dim=-1, keepdim=True)