# model.py
# (Repository web-viewer navigation text from the page export was removed here; it is not part of the module.)
"""

"""
import torch.nn as nn
import torch

import torch.nn.functional as F
class FlatNet(nn.Module):
    """
    Definition of a3c model, forward pass and loss.

    The network is a two-stage convolutional trunk (each stage concatenates a
    3x3 convolution branch with a 1x1 convolution branch), followed by a fully
    connected classifier, a single-layer LSTM and two linear heads: a policy
    head over ``num_actions_`` actions and a scalar value head.
    """

    def __init__(self, size_x=20, size_y=10, init_weights=True):
        """Build all layers.

        Args:
            size_x: Not used by the layer definitions; kept so existing
                callers that pass it keep working.
            size_y: Not used by the layer definitions; kept so existing
                callers that pass it keep working.
            init_weights: When truthy, re-initialise the convolutional and
                linear layers with ``_initialize_weights`` (the LSTM keeps
                PyTorch's default initialisation).
        """
        super().__init__()
        # Classifier input width: 512 channels * 5 * 5 pooled cells coming out
        # of ``features`` plus 4 extra values per sample supplied via ``x2``.
        self.observations_ = 12804
        self.num_actions_ = 5
        self.loss_value_ = 0.5       # weight of the value-regression term
        self.loss_entropy_ = 0.01    # weight of the entropy regulariser
        self.epsilon_ = 1.e-14       # keeps log() away from zero

        # Forces a fixed 5x5 spatial size regardless of the input resolution.
        self.avgpool2 = nn.AdaptiveAvgPool2d((5, 5))

        # Stage 1: 22 input channels -> 128 (3x3 branch) + 128 (1x1 branch).
        self.feat_1 = nn.Sequential(
            nn.Conv2d(22, 128, kernel_size=3, padding=1),
            nn.ReLU(True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
        )
        self.feat_1b = nn.Sequential(
            nn.Conv2d(22, 128, kernel_size=1, padding=0),
        )

        # Stage 2: 256 input channels -> 256 (3x3 branch) + 256 (1x1 branch).
        self.feat_2 = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
        )
        self.feat_2b = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=1, padding=0),
        )

        # The first positional argument of nn.LeakyReLU is ``negative_slope``,
        # not ``inplace``; passing ``True`` there yields slope 1.0, i.e. the
        # identity function. The default slope (0.01) is used instead.
        # NOTE(review): two activations with only Dropout between them suggest
        # a linear layer may have been lost from this stack - confirm against
        # the checkpoints this model is expected to load.
        self.classifier_ = nn.Sequential(
            nn.Linear(self.observations_, 512),
            nn.LeakyReLU(),
            nn.Dropout(),
            nn.LeakyReLU(),
        )

        self.lstm_ = nn.LSTM(512, 512, 1)

        # ``loss`` takes the logarithm of the policy output, so the head has
        # to emit probabilities; Softmax has no parameters, so state-dict keys
        # are unaffected.
        # NOTE(review): the original definition of this head was truncated;
        # confirm callers do not apply their own softmax on top.
        self.policy_ = nn.Sequential(
            nn.Linear(512, self.num_actions_),
            nn.Softmax(dim=1),
        )
        self.value_ = nn.Sequential(nn.Linear(512, 1))

        if init_weights:
            for block in (
                self.classifier_,
                self.policy_,
                self.value_,
                self.feat_1,
                self.feat_1b,
                self.feat_2,
                self.feat_2b,
            ):
                self._initialize_weights(block)

    def loss(self, out_policy, y_policy, out_values, y_values):
        """Return the scalar A3C loss averaged over the batch.

        Per sample the loss is the sum of a policy-gradient term
        (``-log_prob * advantage`` with the advantage detached), a weighted
        squared-advantage value term and a weighted entropy regulariser.

        Args:
            out_policy: Action probabilities from the policy head, one row
                per sample.
            y_policy: Per-action weights selecting the taken action in each
                row (presumably one-hot - confirm against the caller).
            out_values: Value-head output, one element per sample; ``(N,)``
                and ``(N, 1)`` are both accepted.
            y_values: Target values, one element per sample; ``(N,)`` and
                ``(N, 1)`` are both accepted.

        Returns:
            A zero-dimensional tensor holding the mean loss.
        """
        log_prob = torch.log(torch.sum(out_policy * y_policy, 1) + self.epsilon_)
        # ``log_prob`` is (N,) while the value head emits (N, 1); flattening
        # keeps every term per-sample instead of broadcasting to (N, N).
        advantage = y_values.reshape(-1) - out_values.reshape(-1)
        # The advantage is treated as a constant in the policy term so that
        # the policy gradient does not flow into the value head.
        policy_loss = -log_prob * advantage.detach()
        value_loss = self.loss_value_ * advantage.pow(2)
        # sum(p * log p) is the negative entropy of the predicted policy, so
        # minimising it favours exploration. It must be computed on the
        # network output: computed on the targets it carries no gradient.
        entropy = self.loss_entropy_ * torch.sum(
            out_policy * torch.log(out_policy + self.epsilon_), 1
        )
        return torch.mean(policy_loss + value_loss + entropy)

    def features_1(self, x):
        """Run both stage-1 branches on ``x`` and concatenate them on the channel axis."""
        x1 = self.feat_1(x)
        x1b = self.feat_1b(x)
        return torch.cat([x1, x1b], 1)

    def features_2(self, x):
        """Run both stage-2 branches on ``x`` and concatenate them on the channel axis."""
        x2 = self.feat_2(x)
        x2b = self.feat_2b(x)
        return torch.cat([x2, x2b], 1)

    def features(self, x):
        """Apply the convolutional trunk and return a ``(N, 512, 5, 5)`` tensor.

        Args:
            x: Image-like input with 22 channels. Each spatial dimension must
                be at least 4 so the two 2x2 poolings leave a non-empty map.
        """
        # F.leaky_relu's second positional argument is ``negative_slope``;
        # the default slope is used (passing ``True`` made it the identity).
        x = F.leaky_relu(self.features_1(x))
        x = F.max_pool2d(x, kernel_size=2, stride=2)
        x = F.leaky_relu(self.features_2(x))
        x = F.max_pool2d(x, kernel_size=2, stride=2)
        return self.avgpool2(x)

    def forward(self, x, x2):
        """Compute the policy and the value estimate.

        Args:
            x: Image-like input with 22 channels, passed through ``features``.
            x2: Extra per-sample input; it is flattened and must contribute 4
                values per sample so that the classifier receives
                ``observations_`` inputs.
                NOTE(review): the argument order was reconstructed from the
                body of this method - confirm against callers.

        Returns:
            Tuple ``(p, v)``: ``p`` holds action probabilities with shape
            ``(N, num_actions_)`` and ``v`` the value estimates with shape
            ``(N, 1)``.
        """
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x2 = x2.view(x2.size(0), -1)
        x = torch.cat([x, x2], 1)
        x = self.classifier_(x)
        # nn.LSTM defaults to (seq_len, batch, features): the N rows are fed
        # as one sequence of length N with batch size 1, and no hidden state
        # is carried over between calls.
        x, _ = self.lstm_(x.view(len(x), 1, -1))
        x = x.view(len(x), -1)
        p = self.policy_(x)
        v = self.value_(x)
        return p, v

    def _initialize_weights(self, modules_):
        """Initialise the layers yielded by iterating ``modules_``.

        Convolutions get Kaiming-normal weights and zero bias, BatchNorm2d
        gets unit weight and zero bias, Linear gets N(0, 0.01) weights and
        zero bias. Other layer types are left untouched.

        Args:
            modules_: Iterable of modules, e.g. an ``nn.Sequential``.
        """
        conv_types = (
            nn.Conv2d,
            nn.ConvTranspose2d,
            nn.Conv3d,
            nn.ConvTranspose3d,
        )
        for m in modules_:
            if isinstance(m, conv_types):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)