"""
"""
import torch.nn as nn
import torch
import torch.nn.functional as F


class FlatNet(nn.Module):
    """
    Definition of the A3C model: forward pass and loss.
    """

    def __init__(self, size_x=20, size_y=10, init_weights=True):
        super(FlatNet, self).__init__()
        # Flattened classifier input: 512 * 5 * 5 pooled conv features
        # plus the 4 auxiliary observations concatenated in forward().
        self.observations_ = 12804
        self.num_actions_ = 5
        self.loss_value_ = 0.5     # weight of the value-loss term
        self.loss_entropy_ = 0.01  # weight of the entropy regularizer
        self.epsilon_ = 1.e-14     # numerical guard inside log()
        self.avgpool2 = nn.AdaptiveAvgPool2d((5, 5))
        # Two parallel feature branches: a 3x3 conv stack and a 1x1
        # shortcut, concatenated along channels in features_1().
        self.feat_1 = nn.Sequential(
            nn.Conv2d(22, 128, kernel_size=3, padding=1),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
        )
        self.feat_1b = nn.Sequential(
            nn.Conv2d(22, 128, kernel_size=1, padding=0),
        )
        self.feat_2 = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
        )
        self.feat_2b = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=1, padding=0),
        )
        self.classifier_ = nn.Sequential(
            nn.Linear(self.observations_, 512),
            # nn.LeakyReLU takes negative_slope as its first positional
            # argument, so inplace must be passed by keyword.
            nn.LeakyReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(512, 512),
        )
        self.lstm_ = nn.LSTM(512, 512, 1)
        # Policy head; the softmax is assumed here so that loss() receives
        # action probabilities for its log-probability term.
        self.policy_ = nn.Sequential(
            nn.Linear(512, self.num_actions_),
            nn.Softmax(dim=1),
        )
        self.value_ = nn.Sequential(nn.Linear(512, 1))
        if init_weights:
            self._initialize_weights(self.classifier_)
            self._initialize_weights(self.policy_)
            self._initialize_weights(self.value_)
            self._initialize_weights(self.feat_1)
            self._initialize_weights(self.feat_1b)
            self._initialize_weights(self.feat_2)
            self._initialize_weights(self.feat_2b)
    def loss(self, out_policy, y_policy, out_values, y_values):
        """A3C loss: policy gradient + weighted value loss + entropy term."""
        # Log-probability of the taken action under the predicted policy.
        log_prob = torch.log(torch.sum(out_policy * y_policy, 1) + self.epsilon_)
        # Flatten both value tensors to shape (N,) so the per-sample
        # products below broadcast against log_prob instead of forming
        # an unintended (N, N) matrix.
        advantage = y_values.view(-1) - out_values.view(-1)
        # The advantage is detached so the value head is trained only
        # through the value loss, not through the policy term.
        policy_loss = -log_prob * advantage.detach()
        value_loss = self.loss_value_ * advantage.pow(2)
        # Sum of p*log(p) is the negative entropy; adding it to the loss
        # rewards higher-entropy (more exploratory) policies.
        entropy = self.loss_entropy_ * torch.sum(y_policy * torch.log(y_policy + self.epsilon_), 1)
        return torch.mean(policy_loss + value_loss + entropy)
    def features_1(self, x):
        # Concatenate the 3x3 stack and the 1x1 shortcut along channels.
        x1 = self.feat_1(x)
        x1b = self.feat_1b(x)
        return torch.cat([x1, x1b], 1)

    def features_2(self, x):
        x2 = self.feat_2(x)
        x2b = self.feat_2b(x)
        return torch.cat([x2, x2b], 1)
    def features(self, x):
        # F.leaky_relu also takes negative_slope as its second positional
        # argument, so inplace is passed by keyword here as well.
        x = F.leaky_relu(self.features_1(x), inplace=True)
        x = F.max_pool2d(x, kernel_size=2, stride=2)
        x = F.leaky_relu(self.features_2(x), inplace=True)
        x = F.max_pool2d(x, kernel_size=2, stride=2)
        x = self.avgpool2(x)
        return x
    def forward(self, x, x2):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x2 = x2.view(x2.size(0), -1)
        x = torch.cat([x, x2], 1)
        # Map the 12804 concatenated observations down to the 512
        # features the LSTM expects.
        x = self.classifier_(x)
        x, _ = self.lstm_(x.view(len(x), 1, -1))
        x = x.view(len(x), -1)
        p = self.policy_(x)
        v = self.value_(x)
        return p, v
    def _initialize_weights(self, modules_):
        for m in modules_:
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Conv3d, nn.ConvTranspose3d)):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)
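

# A minimal smoke test sketch, assuming a 22-channel 20x10 grid observation
# and a 4-element auxiliary vector per sample (the sizes implied by
# observations_ = 512 * 5 * 5 + 4). The batch size of 8 and the random
# targets below are arbitrary, purely illustrative choices.
if __name__ == "__main__":
    net = FlatNet()
    x = torch.randn(8, 22, 20, 10)   # stacked grid observations
    x2 = torch.randn(8, 4)           # auxiliary per-sample observations
    p, v = net(x, x2)
    print(p.shape, v.shape)          # torch.Size([8, 5]) torch.Size([8, 1])

    # Hypothetical targets to exercise loss(): one-hot actions and returns.
    y_policy = F.one_hot(torch.randint(0, 5, (8,)), num_classes=5).float()
    y_values = torch.randn(8, 1)
    print(net.loss(p, y_policy, v, y_values).item())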