Commit faed438e authored by stephan

First prototype of a3c baseline - experimental!

parent b40798ff
Experimental A3C implementation
===============================
Dependencies
------------
- PyTorch 1.0 with GPU support
- numpy
- scipy (imported in linker.py)
- torchvision (imported in loader.py)
Installation
------------
1.) Put all files in this folder into the same folder as flatland
**OR**
2.) Add the path to your flatland checkout to sys.path in "main.py" (see the snippet below)
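For option 2, main.py already appends such a path; a minimal sketch (the path shown is the author's local checkout and must be adapted):

    import sys
    sys.path.append("D:\\ZHAW Master\\FS19\\VT2\\flatland-master_6")  # adjust to your flatland checkout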
Running
-------
* Training: python3 main.py generate
* Replay: python3 main.py play
Settings
--------
**PLEASE NOTE: THIS CODE IS EXPERIMENTAL AND PROVIDED AS IS!**
Tested using the complex rail generator with 5 agents and 40 extra connections. The grid world is 32x32; the local observation is an 8x8 region cropped out of the grid world. Each observation consists of 3 temporal steps, each with 4 channels:
* Transitions: 8x8 float tensor
* Positions (all agents): 8x8 float tensor
* Position (agent x): 8x8 float tensor
* Target (agent x): 8x8 float tensor
Fusing the 3 temporal steps gives a state of size 3x4x8x8 (see the sketch below).
The grid-size and the region-size are hardcoded in main.py!
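A minimal sketch (assuming NumPy) of how this 3x4x8x8 state is assembled, mirroring make_state/update_state in the linker code below:

    import numpy as np

    # One timestep from the observation builder: 4 channels
    # (transitions, all agent positions, own position, own target), each 8x8.
    obs_t = np.zeros((4, 8, 8), dtype=np.float32)

    # Initial state: the same observation stacked 3 times (t1 == t2 == t3).
    state = np.stack((obs_t, obs_t, obs_t))        # shape (3, 4, 8, 8)

    # After each env step, drop the oldest timestep and append the newest one.
    obs_next = np.ones((4, 8, 8), dtype=np.float32)
    state[:-1] = state[1:].copy()
    state[-1] = obs_next                           # still (3, 4, 8, 8)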
Hyperparameters
---------------
* FlatLink.n_step_return_ = 50: length of the n-step return
* FlatLink.gamma_ = 0.5: discount factor
* FlatLink.num_epochs = 20: epochs used for training the model
* FlatLink.play_epochs = 50: overall number of training epochs (number of play/train-model sequences)
* FlatLink.thres = 2.00: threshold for choosing between a random and a predicted action
* FlatLink.max_iter = 300: maximum number of steps allowed per game
* FlatLink.play_game_number = 10: number of games played per training epoch
* FlatLink.rate_decay = 0.97: decay rate of thres per epoch
* FlatLink.replay_buffer_size = 4000: replay buffer size (balanced buffer)
* FlatLink.min_thres = 0.05: minimum value of thres
* FlatNet.loss_value_ = 0.01: weight of the value loss
* FlatNet.loss_entropy_ = 0.05: weight of the entropy loss
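These values are plain instance attributes, so they can be overridden before training; a minimal sketch in the style of main.py (attribute names as in linker.py, values purely illustrative):

    from linker import FlatConvert

    link_ = FlatConvert(size_x=8, size_y=8, gpu_=0)  # gpu_=-1 runs on CPU
    link_.n_step_return_ = 50
    link_.gamma_ = 0.5
    link_.play_epochs = 50
    link_.play_game_number = 10
    link_.replay_buffer_size = 4000
    # link_.perform_training(env) then uses these settings.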
Remarks
-------
Convergence is very slow; no properly trained weights are available at this time.
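# ---------------------------------------------------------------------------
# fake.py (file name inferred from "from fake import fakeEnv" in main.py below)
# ---------------------------------------------------------------------------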
import numpy as np
class fakeEnv():
"""
Dummy environment for mock test of a3c.
"""
def __init__(self):
self.grid_size = tuple([20,10])
self.num_agents = 2
self.done = False
self.iterate = 0
self.grid = np.zeros(self.grid_size)
def reset(self):
self.done = False
self.iterate = 0
grid_ = self.grid
return grid_
def step(self, action):
grid_ = self.grid
fake_rail = np.random.rand(self.grid_size[0],self.grid_size[1])
grid_[fake_rail<0.25]=0.25
grid_[fake_rail>0.45]=0.5
reward = np.array(np.mean(fake_rail)/self.grid_size[0]/self.grid_size[1])
self.iterate+=1
return grid_,reward,self.iterate > 100, None
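# ---------------------------------------------------------------------------
# linker.py (file name inferred from "from linker import FlatLink, FlatConvert"
# in main.py below)
# ---------------------------------------------------------------------------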
import numpy as np
from model import FlatNet
from loader import FlatData
import torch
import torch.nn as nn
import torch.optim as optim
from scipy import stats
import queue
import random
class FlatLink():
"""
Linker for environment.
Training: Generate training data, accumulate n-step reward and train model.
Play: Execute trained model.
"""
def __init__(self,size_x=20,size_y=10, gpu_ = 0):
self.n_step_return_ = 50
self.gamma_ = 0.5
self.batch_size = 1000
self.num_epochs = 20
self.play_epochs = 50
self.a = 3.57
self.thres = 2.00
self.alpha_rv = stats.alpha(self.a)
self.alpha_rv.random_state = np.random.RandomState(seed=342423)
self.model = FlatNet(size_x,size_y)
self.queue = None
self.max_iter = 300
self.dtype = torch.float
self.play_game_number = 10
self.rate_decay = 0.97
self.size_x=size_x
self.size_y=size_y
self.replay_buffer_size = 4000
self.buffer_multiplicator = 1
self.best_loss = 100000.0
self.min_thres = 0.05
if gpu_>-1:
self.cuda = True
self.device = torch.device('cuda', gpu_)
self.model.cuda()
else:
self.cuda = False
self.device = torch.device('cpu', 0)
def load_(self, file='model_a3c_i.pt'):
self.model = torch.load(file)
print(self.model)
def test_random_move(self):
move = np.argmax(self.alpha_rv.rvs(size=5))
print("move ",move)
def get_action(self, policy):
"""
Return either predicted (from policy) or random action
"""
if np.abs(stats.norm(1).rvs()) < self.thres:
move = np.argmax(self.alpha_rv.rvs(size=len(policy)))
return move, 0
else:
xk = np.arange(len(policy))
custm = stats.rv_discrete(name='custm', values=(xk, policy))
return custm.rvs(), 1
def accumulate(self, memory, state_value):
"""
memory has form (state, action, reward, after_state, done)
Accumulate the n-step reward and put it into the training queue
"""
n = min(len(memory), self.n_step_return_)
curr_state = memory[0][0]
action =memory[0][1]
n_step_return = 0
for i in range(n):
reward = memory[i][2]
n_step_return += np.power(self.gamma_,i+1) * reward
done = memory[-1][4]
if not done:
n_step_return += (-1)*np.power(self.gamma_,n+1)*np.abs(state_value)[0]
else:
n_step_return += np.abs(state_value)[0] # Not necessary!
if len(memory)==1:
memory = []
else:
memory = memory[1:-1]
n_step_return = np.clip(n_step_return, -10., 1.)
return curr_state, action, n_step_return , memory
def training(self,observations,targets,rewards):
"""
Run model training on accumulated training experience
"""
self.model.train()
data_train = FlatData(observations,targets,rewards,self.model.num_actions_,self.device, self.dtype)
dataset_sizes = len(data_train)
dataloader = torch.utils.data.DataLoader(data_train, batch_size=self.batch_size, shuffle=False,num_workers=0)
optimizer= optim.Adam(self.model.parameters(),lr=0.0001,weight_decay=0.02)
# TODO early stop
for epoch in range(self.num_epochs):
running_loss = 0.0
for observation, target, reward in dataloader:
optimizer.zero_grad()
policy, value = self.model.forward(observation)
loss = self.model.loss(policy,target,value,reward)
loss.backward()
optimizer.step()
running_loss += torch.abs(loss).item() * observation.size(0)
epoch_loss = running_loss / dataset_sizes
print('Epoch {}/{}, Loss {}'.format(epoch, self.num_epochs - 1,epoch_loss))
if epoch_loss < self.best_loss:
self.best_loss = epoch_loss
try:
torch.save(self.model, 'model_a3c_i.pt')
except:
print("Model is not saved!!!")
def predict(self, observation):
"""
Forward pass on model
Returns: Policy, Value
"""
val = torch.from_numpy(observation.astype(np.float)).to(device = self.device, dtype=self.dtype)
p, v = self.model.forward(val)
return p.cpu().detach().numpy().reshape(-1), v.cpu().detach().numpy().reshape(-1)
def trainer(self,buffer_list = []):
"""
Call training sequence if buffer is not empty
"""
if buffer_list==[]:
return
curr_state_, action, n_step_return = zip(*buffer_list)
self.training( curr_state_, action, n_step_return)
#self.queue.queue.clear()
try:
torch.save(self.model, 'model_a3c.pt')
except:
print("Model is not saved!!!")
class FlatConvert(FlatLink):
def __init__(self,size_x,size_y, gpu_ = 0):
super(FlatConvert, self).__init__(size_x,size_y,gpu_)
def perform_training(self,env,env_renderer=None):
"""
Run simulation in training mode
"""
self.queue = queue.Queue()
buffer_list=[]
self.load_('model_a3c.pt')
print("start training...")
for j in range(self.play_epochs):
reward_mean= 0.0
predicted_mean = 0.0
steps_mean = 0.0
for i in range(self.play_game_number):
predicted_, reward_ ,steps_= self.collect_training_data(env,env_renderer)
reward_mean+=reward_
predicted_mean+=predicted_
steps_mean += steps_
list_ = list(self.queue.queue)
buffer_list = buffer_list + list_ #+ [ [x,y,z] for x,y,z,d in random.sample(list_,count_)]
reward_mean /= float(self.play_game_number)
predicted_mean /= float(self.play_game_number)
steps_mean /= float(self.play_game_number)
print("play epoch: ",j,", total buffer size: ", len(buffer_list))
if len(buffer_list) > self.replay_buffer_size * 1.2:
list_2 =[ [x,y,z] for x,y,z in sorted(buffer_list, key=lambda pair: pair[2],reverse=True)]
size_ = int(self.replay_buffer_size/4)
buffer_list = random.sample(buffer_list,size_*2) + list_2[:size_] + list_2[-size_:]
random.shuffle(buffer_list)
print("play epoch: ",j,", buffer size: ", len(buffer_list),", reward (mean): ",reward_mean,
", steps (mean): ",steps_mean,", predicted (mean): ",predicted_mean)
self.trainer(buffer_list)
self.queue.queue.clear()
self.thres = max( self.thres*self.rate_decay,self.min_thres)
def make_state(self,state):
"""
Stack states 3 times for initial call (t1==t2==t3)
"""
return np.stack((state,state,state))
def update_state(self,states_,state):
"""
Update the state by removing the oldest timestep and appending the current one
"""
states_[:(states_.shape[0]-1),:,:,:] = states_[1:,:,:,:]
states_[(states_.shape[0]-1):,:,:,:] = state.reshape((1,state.shape[0],state.shape[1],state.shape[2]))  # write the newest observation into the last slot
return states_
def init_state_dict(self,state_dict_):
for key in state_dict_.keys():
state_dict_[key] = self.make_state(state_dict_[key])
return state_dict_
def update_state_dict(self,old_state_dict, new_state_dict):
for key in old_state_dict.keys():
new_state_dict[key] = self.update_state(old_state_dict[key],new_state_dict[key])
return new_state_dict
def play(self,env,env_renderer=None,filename = 'model_a3c.pt',epochs_ =10):
"""
Run simulation on trained model
"""
episode_reward = 0
self.thres = 0.0
self.load_(filename)
self.model.eval()
global_reward = 0
for i in range(epochs_):
policy_dict = {}
value_dict = {}
action_dict = {}
#self.max_iter = 40
iter_ = 0
pre_ = 0
act_ = 0
done = False
curr_state_dict = self.init_state_dict(env.reset())
#curr_state_dict = env.reset()
while not done and iter_< self.max_iter:
iter_+=1
for agent_id, state in curr_state_dict.items():
policy, value = self.predict(state.astype(np.float).reshape((1,state.shape[0],state.shape[1],state.shape[2],state.shape[3])))
policy_dict[agent_id] = policy
value_dict[agent_id] = value
action_dict[agent_id], pre = self.get_action(policy)
pre_ += pre
act_+=1
if env_renderer is not None:
env_renderer.renderEnv(show=True)
next_state_dict, reward_dict, done_dict, _ = env.step(action_dict)
next_state_dict = self.update_state_dict(curr_state_dict, next_state_dict)
done = done_dict['__all__']
curr_state_dict = next_state_dict
return float(pre_)/float(act_), global_reward, float(iter_)/float(self.max_iter)
def collect_training_data(self,env,env_renderer = None):
"""
Run a single simulation scenario for training
"""
done = False
curr_state_dict = self.init_state_dict(env.reset())
#curr_state_dict = env.reset()
episode_reward = 0
memory={}
for agent_id, state in curr_state_dict.items():
memory[agent_id] = []
self.model.eval()
policy_dict = {}
value_dict = {}
action_dict = {}
iter_ = 0
pre_ = 0
act_ = 0
global_reward = 0.0
while not done and iter_< self.max_iter:
iter_+=1
for agent_id, state in curr_state_dict.items():
policy, value = self.predict(state.astype(np.float).reshape((1,state.shape[0],state.shape[1],state.shape[2],state.shape[3])))
policy_dict[agent_id] = policy
value_dict[agent_id] = value
action_dict[agent_id], pre = self.get_action(policy)
pre_ += pre
act_+=1
if env_renderer is not None:
env_renderer.renderEnv(show=True)
next_state_dict, reward_dict, done_dict, _ = env.step(action_dict)
next_state_dict = self.update_state_dict(curr_state_dict, next_state_dict)
#reward_dict = self.modify_reward(reward_dict,done_dict,next_state_dict)
done = done_dict['__all__']
for agent_id, state in curr_state_dict.items():
memory[agent_id].append(tuple([curr_state_dict[agent_id],
action_dict[agent_id], reward_dict[agent_id], next_state_dict[agent_id], done_dict[agent_id]]))
global_reward += reward_dict[agent_id]
while ((len(memory[agent_id]) >= self.n_step_return_) or (done and not memory[agent_id] == [])):
curr_state, action, n_step_return, memory[agent_id] = self.accumulate(memory[agent_id],
value_dict[agent_id])
self.queue.put([curr_state, action, n_step_return])
curr_state_dict = next_state_dict
return float(pre_)/float(act_), global_reward, float(iter_)/float(self.max_iter)
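# ---------------------------------------------------------------------------
# loader.py (file name inferred from "from loader import FlatData" in linker.py
# above)
# ---------------------------------------------------------------------------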
import torch
import torch.nn as nn
from torchvision import utils
import numpy as np
class FlatData(torch.utils.data.dataset.Dataset):
"""
Prepare data to read batch-wise, create dataset
"""
def __init__(self,states,targets,rewards,num_action,device,dtype):
self.states_ = [torch.from_numpy(x.astype(np.float)).to(device = device, dtype=dtype) for x in states]
self.targets_ = [torch.from_numpy(np.eye(num_action)[x]).to(device = device, dtype=dtype) for x in targets]
self.rewards_ = [torch.from_numpy(np.array([x]).reshape(-1)).to(device = device, dtype=dtype) for x in rewards]
# Override to give PyTorch access to any item of the dataset
def __getitem__(self, index):
return self.states_[index], self.targets_[index], self.rewards_[index]
# Override to give PyTorch size of dataset
def __len__(self):
return len(self.states_)
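# ---------------------------------------------------------------------------
# main.py (entry point referenced in the README; parses the generate/test/play/
# manual actions)
# ---------------------------------------------------------------------------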
import argparse
from linker import FlatLink, FlatConvert
from fake import fakeEnv
import sys
import copy
import numpy as np
sys.path.append("D:\\ZHAW Master\\FS19\VT2\\flatland-master_6")
#sys.path.insert(0, "D:\\ZHAW Master\\FS19\VT2\\flatland-master\\flatland")
from flatland.envs.rail_env import RailEnv
from flatland.envs.generators import random_rail_generator, complex_rail_generator,rail_from_GridTransitionMap_generator
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.transition_map import GridTransitionMap
from flatland.utils.rendertools import RenderTool
import time
import pickle
size_x=8
size_y=8
def save_map(rail):
np.save("my_rail_grid.npy", rail.grid)
np.save("my_rail_trans_list.npy",np.array(rail.transitions.transition_list))
np.save("my_rail_trans.npy",np.array(rail.transitions.transitions))
def load_map():
lmap = GridTransitionMap(12,12)
lmap.grid = np.load("grid/my_rail_grid.npy")
lmap.transitions.transition_list =list( np.load("grid/my_rail_trans_list.npy"))
lmap.transitions.transitions = list(np.load("grid/my_rail_trans.npy"))
return lmap
"""
def generate_map(rail_generator=random_rail_generator()):
t = time.time()
geny = rail_generator
rail = geny(size_x,size_y)
elapsed = time.time() - t
print("initialization time: ", elapsed)
return rail
"""
class RawObservation(ObservationBuilder):
"""
ObservationBuilder for raw observations.
"""
def __init__(self, size_):
self.reset()
self.size_ = size_
def _set_env(self, env):
self.env = env
def reset(self):
"""
Called after each environment reset.
"""
self.map_ = None
self.agent_positions_ = None
self.agent_handles_ = None
def search_grid(self,x_indices,y_indices):
b_ = None
if x_indices!=[] and y_indices!=[]:
for i in x_indices:
for j in y_indices:
if int(self.env.rail.grid[i][j]) != 0:
b_ = [i,j]
break
return b_
def get_nearest_in_scope(self,position_,size_,target_):
"""
Crop a region of size x,y around the position and set an intermediate target if the
real target is not in the local window.
"""
shape_ = self.env.rail.grid.shape
x_ = position_[0] - target_[0]
y_ = position_[1] - target_[1]
a_ = None
b_ = None
if np.abs(x_) >= size_[0] or np.abs(y_) >= size_[1]:
xd_ = x_
yd_ = y_
if np.abs(x_) >= size_[0]:
xd_ = int(xd_ * size_[0]/np.abs(x_))
if np.abs(y_) >= size_[1]:
yd_ = int(yd_ * size_[1]/np.abs(y_))
#bug here!
x_indices = []
y_indices = []
if xd_ < 0:
x_indices = list(reversed(range(position_[0] ,position_[0] - xd_ -1)))
elif xd_ > 0:
x_indices = list(range(position_[0] -xd_ + 1,position_[0]))
if yd_ < 0:
y_indices = list(reversed(range(position_[1] ,position_[1] - yd_ -1)))
elif yd_ > 0:
y_indices = list(range(position_[1] -yd_ +1 ,position_[1]))
# cheat
x_indices.append(position_[0])
y_indices.append(position_[1])
# "nearest point in new area"
b_ = self.search_grid(x_indices, y_indices)
if b_ is not None:
target_ = b_
else:
# take any existing point
if len (x_indices) < size_[0]/2 -1:
x_indices = list(range(position_[0] - size_x, position_[0]+size_x))
if len (y_indices) < size_[1]/2 -1:
y_indices = list(range(position_[1] - size_y,position_[1]+size_y))
b_ = self.search_grid(x_indices, y_indices)
if b_ is None:
target_ = position_
x_ = position_[0] - target_[0]
y_ = position_[1] - target_[1]
if x_ > 0 and y_ > 0:
a_ = [target_[0], target_[1]]
elif x_<= 0 and y_ <= 0:
a_ = [position_[0], position_[1]]
elif x_> 0 and y_ <= 0:
a_ = [target_[0], position_[1]]
elif x_<= 0 and y_ > 0:
a_ = [position_[0], target_[1]]
if a_[0] + size_[0] >= shape_[0]:
a_[0]-= a_[0] + size_[0] - shape_[0]
if a_[1] + size_[1] >= shape_[1]:
a_[1]-= a_[1] +size_[1] - shape_[1]
if position_[0]-a_[0] >= size_[0] or position_[1]-a_[1] >= size_[1]:
b_ = target_
return a_, target_
def get(self, handle=0):
"""
Called whenever an observation has to be computed for the `env' environment, possibly
for each agent independently (agent id `handle').
Parameters
-------
handle : int (optional)
Handle of the agent for which to compute the observation vector.
Returns
-------
np.ndarray
Transition map as a local window of size x,y, the agent positions (if inside the window) and the target.
"""
map_ = self.env.rail.grid.astype(np.float)/65535.0
target = self.env.agents_target[handle]
position = self.env.agents_position[handle]
a_, target = self.get_nearest_in_scope(position,self.size_, target)
map_ = map_[a_[0]:a_[0]+self.size_[0],a_[1]:a_[1]+self.size_[1]]
agent_positions_ = np.zeros_like(map_)
agent_handles = self.env.get_agent_handles()
for handle_ in agent_handles:
direction_ = float(self.env.agents_direction[handle_] + 1)/5.0
position_ = self.env.agents_position[handle_]
if position_[0]>= a_[0] and position_[0]<a_[0]+self.size_[0] \
and position_[1]>= a_[1] and position_[1]<a_[1]+self.size_[1]:
agent_positions_[position_[0]-a_[0]][position_[1]-a_[1]] = direction_
my_position_ = np.zeros_like(map_)
direction = float(self.env.agents_direction[handle] + 1)/5.0
my_position_[position[0]-a_[0]][position[1]-a_[1]] = direction
my_target_ = np.zeros_like(map_)
my_target_[target[0]-a_[0]][target[1]-a_[1]] = 1
return np.stack(( map_,agent_positions_,my_position_,my_target_))
def fake_generator(rail_l_map):
def generator(width, height, num_resets=0):
return copy.deepcopy(rail_l_map)
return generator
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('action', choices=['generate', 'test','play','manual'])
args = parser.parse_args()
rt = None
if args.action == 'generate':
#rail_l_map_ = generate_map()
#save_map(rail_l_map_)
link_ = FlatConvert(size_x,size_y)
link_.play_epochs = 1
link_.play_game_number = 1
link_.replay_buffer_size = 1000
env = RailEnv(32,32,rail_generator=complex_rail_generator(5,40,10),number_of_agents=5,obs_builder_object=RawObservation([size_x,size_y]))
rt = None#RenderTool(env,gl="QT")
link_.perform_training(env,rt)
elif args.action == 'test':
"""
primitive test
"""
link_ = FlatLink()
link_.collect_training_data(fakeEnv())
link_.trainer()
elif args.action == 'play':
link_ = FlatConvert(64,64)
rail_l_map = load_map()
env = RailEnv(64,64,rail_generator=complex_rail_generator(5,10,5),number_of_agents=5,obs_builder_object=RawObservation([size_x,size_y]))
rt = RenderTool(env,gl="QT")
link_.play(env,rt,"models/model_a3c_i.pt")
elif args.action == 'manual':
link_ = FlatLink()
link_.load_("D:\\ZHAW Master\\FS19\\VT2\\flatland-master_6\\flatland\\baselines\\Nets\\avoid_checkpoint15000.pth")
#from scipy import stats
#for j in range(10):
# su = 0
# size_ = 100
#print("stats.norm(1).rvs() ",stats.norm(j).rvs())
# for i in range(size_):
# su += int(abs(stats.norm(1).rvs()) < 0.1)
# print("hits ",float(su)/float(size_))
#for i in range(100):
# link_.test_random_move()
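# ---------------------------------------------------------------------------
# model.py (file name inferred from "from model import FlatNet" in linker.py
# above)
# ---------------------------------------------------------------------------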
"""
"""
import torch.nn as nn
import torch
import torch.nn.functional as F
class FlatNet(nn.Module):
"""
Definition of a3c model, forward pass and loss.
"""
def __init__(self,size_x=20,size_y=10,init_weights=True):
super(FlatNet, self).__init__()
self.observations_= 1600 #size_x*size_y*4# 9216# 256
self.num_actions_= 4
self.loss_value_ = 0.01
self.loss_entropy_ = 0.05
self.epsilon_ = 1.e-10
self.avgpool2 = nn.AdaptiveAvgPool2d((5, 5))
self.avgpool3 = nn.AdaptiveAvgPool3d((1, 5, 5))
self.feat_1 = nn.Sequential(
nn.Conv3d(3, 16, kernel_size=5, padding=2),
nn.Conv3d(16, 16, kernel_size=5, padding=2),
nn.Conv3d(16, 16, kernel_size=3, padding=1),
)
self.feat_1b = nn.Sequential(
nn.Conv3d(3, 16, kernel_size=3, padding=1),
)
self.feat_2 = nn.Sequential(
nn.Conv3d(32, 32, kernel_size=5, padding=2),
nn.Conv3d(32, 32, kernel_size=5, padding=2),
)
self.feat_2b = nn.Sequential(
nn.Conv3d(32, 32, kernel_size=1, padding=0),
)
self.classifier_ = nn.Sequential(
nn.Linear(self.observations_, 512),
nn.LeakyReLU(True),
nn.Dropout(),
nn.Linear(512, 256),
nn.LeakyReLU(True),
)
self.gru_ = nn.GRU(256, 256, 5)
self.policy_ = nn.Sequential(
nn.Linear(256, self.num_actions_),
nn.Softmax(dim=1),
)
self.value_ = nn.Sequential(nn.Linear(256,1))
if init_weights==True:
self._initialize_weights(self.classifier_)
self._initialize_weights(self.policy_)
self._initialize_weights(self.value_)
self._initialize_weights(self.feat_1)
self._initialize_weights(self.feat_1b)
self._initialize_weights(self.feat_2)
self._initialize_weights(self.feat_2b)
def loss(self,out_policy, y_policy, out_values, y_values):
log_prob = torch.log(torch.sum(out_policy*y_policy, 1)+self.epsilon_)
advantage = y_values - out_values
#advantage = advantage.detach()#requires_grad_(requires_grad=False)
policy_loss = -log_prob * advantage.detach()
value_loss = self.loss_value_ * advantage.pow(2)
entropy = self.loss_entropy_ * torch.sum(y_policy * torch.log(y_policy+self.epsilon_), 1)
loss = torch.mean(policy_loss + value_loss + entropy)
return loss
def features_1(self,x):
x1 = self.feat_1(x)
x1b = self.feat_1b(x)
return torch.cat([x1,x1b],1)
def features_2(self,x):
x2 = self.feat_2(x)
x2b = self.feat_2b(x)
return torch.cat([x2,x2b],1)
def features(self,x):
x = F.leaky_relu(self.features_1(x),True)
x = F.leaky_relu(self.features_2(x),True)
x = F.max_pool3d(x,kernel_size=2, stride=2)
x = self.avgpool3(x)
return x
def forward(self, x):
x = self.features(x)
x = x.view(x.size(0), -1)
x = self.classifier_(x)
x,_ = self.gru_(x.view(len(x), 1, -1))
x = x.view(len(x), -1)
p = self.policy_(x)
v = self.value_(x)
return p,v
def _initialize_weights(self,modules_):
for i,m in enumerate(modules_):
if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
if isinstance(m, nn.Conv3d) or isinstance(m, nn.ConvTranspose3d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
nn.init.constant_(m.bias, 0)