From faed438e3d249893d975febc4591a1a97cb931bf Mon Sep 17 00:00:00 2001
From: stephan <scarlatti@MSI>
Date: Fri, 17 May 2019 11:20:14 +0200
Subject: [PATCH] First prototype of a3c baseline - experimental!

---
 a3c/README.md |  55 +++++++++
 a3c/fake.py   |  28 +++++
 a3c/linker.py | 307 ++++++++++++++++++++++++++++++++++++++++++++++++++
 a3c/loader.py |  25 ++++
 a3c/main.py   | 239 +++++++++++++++++++++++++++++++++++++++
 a3c/model.py  | 116 +++++++++++++++++++
 6 files changed, 770 insertions(+)
 create mode 100644 a3c/README.md
 create mode 100644 a3c/fake.py
 create mode 100644 a3c/linker.py
 create mode 100644 a3c/loader.py
 create mode 100644 a3c/main.py
 create mode 100644 a3c/model.py

diff --git a/a3c/README.md b/a3c/README.md
new file mode 100644
index 0000000..e2ed683
--- /dev/null
+++ b/a3c/README.md
@@ -0,0 +1,55 @@
+Experimental a3c implementation
+===============================
+
+Dependencies
+------------
+- pytorch 1.0 with GPU support
+- numpy
+
+Installation
+------------
+ 1.) Put all files in this folder into the same folder as flatland
+ **OR**
+ 2.) Add the path to flatland to sys.path in "main.py"
+
+Running
+-------
+* Training: python3 main.py generate
+* Replay (trained model): python3 main.py play
+
+Settings
+--------
+
+**PLEASE NOTE: THIS CODE IS EXPERIMENTAL AND PROVIDED AS IS!**
+
+Tested with the complex rail generator using 5 agents and 40 extra connections. The gridworld is 32x32; the local observation is an 8x8 region cropped out of the gridworld. Each observation consists of 3 temporal steps with 4 channels per step:
+
+* Transitions: 8x8 float tensor
+* Positions (all agents): 8x8 float tensor
+* Position (agent x): 8x8 float tensor
+* Target (agent x): 8x8 float tensor
+
+Fusing the 3 temporal steps gives a state of size 3x4x8x8.
+
+The grid size and the region size are hardcoded in main.py!
+
+Hyperparameters
+---------------
+FlatLink.n_step_return_ = 50, length of the n-step return
+FlatLink.gamma_ = 0.5, discount factor
+FlatLink.num_epochs = 20, epochs used for training the model
+FlatLink.play_epochs = 50, overall number of training epochs (number of play/train-model cycles)
+FlatLink.thres = 2.00, threshold for choosing between a random and a predicted action
+FlatLink.max_iter = 300, maximum number of steps allowed per game
+FlatLink.play_game_number = 10, number of games played per training epoch
+FlatLink.rate_decay = 0.97, decay rate of thres per epoch
+FlatLink.replay_buffer_size = 4000, replay buffer size (balanced buffer)
+FlatLink.min_thres = 0.05, minimum value for thres
+
+FlatNet.loss_value_ = 0.01, weight of the value loss
+FlatNet.loss_entropy_ = 0.05, weight of the entropy loss
+
+Remarks
+-------
+The convergence rate is very slow; no properly trained weights are available at this time.
+
diff --git a/a3c/fake.py b/a3c/fake.py
new file mode 100644
index 0000000..d72400b
--- /dev/null
+++ b/a3c/fake.py
@@ -0,0 +1,28 @@
+import numpy as np
+
+class fakeEnv():
+    """
+    Dummy environment for mock testing of the a3c code.
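+
+    Used by the 'test' action in main.py: reset() returns the raw grid,
+    while step(action) returns an (observation, reward, done, info) tuple,
+    where the observation is the randomly perturbed grid, the reward is
+    derived from the mean of the random field, done becomes True once more
+    than 100 steps have been taken, and info is always None.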
+ """ + def __init__(self): + self.grid_size = tuple([20,10]) + self.num_agents = 2 + self.done = False + self.iterate = 0 + self.grid = np.zeros(self.grid_size) + + def reset(self): + self.done = False + self.iterate = 0 + grid_ = self.grid + return grid_ + + def step(self, action): + grid_ = self.grid + fake_rail = np.random.rand(self.grid_size[0],self.grid_size[1]) + grid_[fake_rail<0.25]=0.25 + grid_[fake_rail>0.45]=0.5 + reward = np.array(np.mean(fake_rail)/self.grid_size[0]/self.grid_size[1]) + self.iterate+=1 + return grid_,reward,self.iterate > 100, None + diff --git a/a3c/linker.py b/a3c/linker.py new file mode 100644 index 0000000..a5c10af --- /dev/null +++ b/a3c/linker.py @@ -0,0 +1,307 @@ +import numpy as np +from model import FlatNet +from loader import FlatData +import torch +import torch.nn as nn +import torch.optim as optim +from scipy import stats +import queue +import random +class FlatLink(): + """ + Linker for environment. + Training: Generate training data, accumulate n-step reward and train model. + Play: Execute trained model. + """ + def __init__(self,size_x=20,size_y=10, gpu_ = 0): + self.n_step_return_ = 50 + self.gamma_ = 0.5 + self.batch_size = 1000 + self.num_epochs = 20 + self.play_epochs = 50 + self.a = 3.57 + self.thres = 2.00 + self.alpha_rv = stats.alpha(self.a) + self.alpha_rv.random_state = np.random.RandomState(seed=342423) + self.model = FlatNet(size_x,size_y) + self.queue = None + self.max_iter = 300 + self.dtype = torch.float + self.play_game_number = 10 + self.rate_decay = 0.97 + self.size_x=size_x + self.size_y=size_y + self.replay_buffer_size = 4000 + self.buffer_multiplicator = 1 + self.best_loss = 100000.0 + self.min_thres = 0.05 + + if gpu_>-1: + self.cuda = True + self.device = torch.device('cuda', gpu_) + self.model.cuda() + else: + self.cuda = False + self.device = torch.device('cpu', 0) + + def load_(self, file='model_a3c_i.pt'): + self.model = torch.load(file) + print(self.model) + + def test_random_move(self): + move = np.argmax(self.alpha_rv.rvs(size=5)) + print("move ",move) + + def get_action(self, policy): + """ + Return either predicted (from policy) or random action + """ + if np.abs(stats.norm(1).rvs()) < self.thres: + move = np.argmax(self.alpha_rv.rvs(size=len(policy))) + return move, 0 + else: + xk = np.arange(len(policy)) + custm = stats.rv_discrete(name='custm', values=(xk, policy)) + return custm.rvs(), 1 + + def accumulate(self, memory, state_value): + + """ + memory has form (state, action, reward, after_state, done) + Accumulate b-step reward and put to training-queue + """ + n = min(len(memory), self.n_step_return_) + curr_state = memory[0][0] + action =memory[0][1] + + n_step_return = 0 + for i in range(n): + reward = memory[i][2] + n_step_return += np.power(self.gamma_,i+1) * reward + + done = memory[-1][4] + + if not done: + n_step_return += (-1)*np.power(self.gamma_,n+1)*np.abs(state_value)[0] + else: + n_step_return += np.abs(state_value)[0] # Not neccessary! + + if len(memory)==1: + memory = [] + else: + memory = memory[1:-1] + n_step_return = np.clip(n_step_return, -10., 1.) 
+ return curr_state, action, n_step_return , memory + + + def training(self,observations,targets,rewards): + """ + Run model training on accumulated training experience + """ + self.model.train() + data_train = FlatData(observations,targets,rewards,self.model.num_actions_,self.device, self.dtype) + dataset_sizes = len(data_train) + dataloader = torch.utils.data.DataLoader(data_train, batch_size=self.batch_size, shuffle=False,num_workers=0) + + optimizer= optim.Adam(self.model.parameters(),lr=0.0001,weight_decay=0.02) + # TODO early stop + for epoch in range(self.num_epochs): + running_loss = 0.0 + for observation, target, reward in dataloader: + optimizer.zero_grad() + policy, value = self.model.forward(observation) + loss = self.model.loss(policy,target,value,reward) + loss.backward() + optimizer.step() + running_loss += torch.abs(loss).item() * observation.size(0) + + epoch_loss = running_loss / dataset_sizes + print('Epoch {}/{}, Loss {}'.format(epoch, self.num_epochs - 1,epoch_loss)) + if epoch_loss < self.best_loss: + self.best_loss = epoch_loss + try: + torch.save(self.model, 'model_a3c_i.pt') + except: + print("Model is not saved!!!") + + + def predict(self, observation): + """ + Forward pass on model + Returns: Policy, Value + """ + val = torch.from_numpy(observation.astype(np.float)).to(device = self.device, dtype=self.dtype) + p, v = self.model.forward(val) + return p.cpu().detach().numpy().reshape(-1), v.cpu().detach().numpy().reshape(-1) + + + def trainer(self,buffer_list = []): + """ + Call training sequence if buffer not emtpy + """ + if buffer_list==[]: + return + curr_state_, action, n_step_return = zip(*buffer_list) + self.training( curr_state_, action, n_step_return) + #self.queue.queue.clear() + try: + torch.save(self.model, 'model_a3c.pt') + except: + print("Model is not saved!!!") + + +class FlatConvert(FlatLink): + def __init__(self,size_x,size_y, gpu_ = 0): + super(FlatConvert, self).__init__(size_x,size_y,gpu_) + + + def perform_training(self,env,env_renderer=None): + """ + Run simulation in training mode + """ + self.queue = queue.Queue() + buffer_list=[] + self.load_('model_a3c.pt') + print("start training...") + for j in range(self.play_epochs): + reward_mean= 0.0 + predicted_mean = 0.0 + steps_mean = 0.0 + for i in range(self.play_game_number): + predicted_, reward_ ,steps_= self.collect_training_data(env,env_renderer) + reward_mean+=reward_ + predicted_mean+=predicted_ + steps_mean += steps_ + list_ = list(self.queue.queue) + + buffer_list = buffer_list + list_ #+ [ [x,y,z] for x,y,z,d in random.sample(list_,count_)] + reward_mean /= float(self.play_game_number) + predicted_mean /= float(self.play_game_number) + steps_mean /= float(self.play_game_number) + print("play epoch: ",j,", total buffer size: ", len(buffer_list)) + if len(buffer_list) > self.replay_buffer_size * 1.2: + list_2 =[ [x,y,z] for x,y,z in sorted(buffer_list, key=lambda pair: pair[2],reverse=True)] + size_ = int(self.replay_buffer_size/4) + buffer_list = random.sample(buffer_list,size_*2) + list_2[:size_] + list_2[-size_:] + random.shuffle(buffer_list) + + print("play epoch: ",j,", buffer size: ", len(buffer_list),", reward (mean): ",reward_mean, + ", steps (mean): ",steps_mean,", predicted (mean): ",predicted_mean) + self.trainer(buffer_list) + self.queue.queue.clear() + self.thres = max( self.thres*self.rate_decay,self.min_thres) + + + def make_state(self,state): + """ + Stack states 3 times for initial call (t1==t2==t3) + """ + return np.stack((state,state,state)) + + def 
update_state(self,states_,state): + """ + update state by removing oldest timestep and adding actual step + """ + states_[:(states_.shape[0]-1),:,:,:] = states_[1:,:,:,:] + states_[(state.shape[0]-1):,:,:,:] = state.reshape((1,state.shape[0],state.shape[1],state.shape[2])) + return states_ + + def init_state_dict(self,state_dict_): + for key in state_dict_.keys(): + state_dict_[key] = self.make_state(state_dict_[key]) + return state_dict_ + + def update_state_dict(self,old_state_dict, new_state_dict): + for key in old_state_dict.keys(): + new_state_dict[key] = self.update_state(old_state_dict[key],new_state_dict[key]) + return new_state_dict + + def play(self,env,env_renderer=None,filename = 'model_a3c.pt',epochs_ =10): + """ + Run simulation on trained model + """ + episode_reward = 0 + self.thres = 0.0 + self.load_(filename) + self.model.eval() + global_reward = 0 + for i in range(epochs_): + policy_dict = {} + value_dict = {} + action_dict = {} + #self.max_iter = 40 + iter_ = 0 + pre_ = 0 + act_ = 0 + done = False + curr_state_dict = self.init_state_dict(env.reset()) + #curr_state_dict = env.reset() + while not done and iter_< self.max_iter: + iter_+=1 + for agent_id, state in curr_state_dict.items(): + policy, value = self.predict(state.astype(np.float).reshape((1,state.shape[0],state.shape[1],state.shape[2],state.shape[3]))) + policy_dict[agent_id] = policy + value_dict[agent_id] = value + action_dict[agent_id], pre = self.get_action(policy) + pre_ += pre + act_+=1 + if env_renderer is not None: + env_renderer.renderEnv(show=True) + next_state_dict, reward_dict, done_dict, _ = env.step(action_dict) + + next_state_dict = self.update_state_dict(curr_state_dict, next_state_dict) + done = done_dict['__all__'] + + curr_state_dict = next_state_dict + + return float(pre_)/float(act_), global_reward, float(iter_)/float(self.max_iter) + + + def collect_training_data(self,env,env_renderer = None): + """ + Run single simualtion szenario for training + """ + done = False + curr_state_dict = self.init_state_dict(env.reset()) + #curr_state_dict = env.reset() + episode_reward = 0 + memory={} + for agent_id, state in curr_state_dict.items(): + memory[agent_id] = [] + self.model.eval() + policy_dict = {} + value_dict = {} + action_dict = {} + iter_ = 0 + pre_ = 0 + act_ = 0 + global_reward = 0.0 + while not done and iter_< self.max_iter: + iter_+=1 + for agent_id, state in curr_state_dict.items(): + policy, value = self.predict(state.astype(np.float).reshape((1,state.shape[0],state.shape[1],state.shape[2],state.shape[3]))) + policy_dict[agent_id] = policy + value_dict[agent_id] = value + action_dict[agent_id], pre = self.get_action(policy) + pre_ += pre + act_+=1 + if env_renderer is not None: + env_renderer.renderEnv(show=True) + next_state_dict, reward_dict, done_dict, _ = env.step(action_dict) + next_state_dict = self.update_state_dict(curr_state_dict, next_state_dict) + #reward_dict = self.modify_reward(reward_dict,done_dict,next_state_dict) + done = done_dict['__all__'] + for agent_id, state in curr_state_dict.items(): + + memory[agent_id].append(tuple([curr_state_dict[agent_id], + action_dict[agent_id], reward_dict[agent_id], next_state_dict[agent_id], done_dict[agent_id]])) + global_reward += reward_dict[agent_id] + while ((len(memory[agent_id]) >= self.n_step_return_) or (done and not memory[agent_id] == [])): + curr_state, action, n_step_return, memory[agent_id] = self.accumulate(memory[agent_id], + value_dict[agent_id]) + self.queue.put([curr_state, action, n_step_return]) + + + 
curr_state_dict = next_state_dict + + return float(pre_)/float(act_), global_reward, float(iter_)/float(self.max_iter) \ No newline at end of file diff --git a/a3c/loader.py b/a3c/loader.py new file mode 100644 index 0000000..589f9a4 --- /dev/null +++ b/a3c/loader.py @@ -0,0 +1,25 @@ +import torch +import torch.nn as nn +from torchvision import utils +import numpy as np + +class FlatData(torch.utils.data.dataset.Dataset): + """ + Prepare data to read batch-wise, create dataset + + """ + def __init__(self,states,targets,rewards,num_action,device,dtype): + self.states_ = [torch.from_numpy(x.astype(np.float)).to(device = device, dtype=dtype) for x in states] + self.targets_ = [torch.from_numpy(np.eye(num_action)[x]).to(device = device, dtype=dtype) for x in targets] + self.rewards_ = [torch.from_numpy(np.array([x]).reshape(-1)).to(device = device, dtype=dtype) for x in rewards] + + + # Override to give PyTorch access to any image on the dataset + def __getitem__(self, index): + + return self.states_[index], self.targets_[index], self.rewards_[index] + + # Override to give PyTorch size of dataset + def __len__(self): + return len(self.states_) + diff --git a/a3c/main.py b/a3c/main.py new file mode 100644 index 0000000..24e27a6 --- /dev/null +++ b/a3c/main.py @@ -0,0 +1,239 @@ +import argparse +from linker import FlatLink, FlatConvert +from fake import fakeEnv +import sys +import copy +import numpy as np +sys.path.append("D:\\ZHAW Master\\FS19\VT2\\flatland-master_6") +#sys.path.insert(0, "D:\\ZHAW Master\\FS19\VT2\\flatland-master\\flatland") +from flatland.envs.rail_env import RailEnv +from flatland.envs.generators import random_rail_generator, complex_rail_generator,rail_from_GridTransitionMap_generator +from flatland.core.env_observation_builder import ObservationBuilder +from flatland.core.transition_map import GridTransitionMap +from flatland.utils.rendertools import RenderTool +import time +import pickle +size_x=8 +size_y=8 + + +def save_map(rail): + np.save("my_rail_grid.npy", rail.grid) + np.save("my_rail_trans_list.npy",np.array(rail.transitions.transition_list)) + np.save("my_rail_trans.npy",np.array(rail.transitions.transitions)) + +def load_map(): + lmap = GridTransitionMap(12,12) + lmap.grid = np.load("grid/my_rail_grid.npy") + lmap.transitions.transition_list =list( np.load("grid/my_rail_trans_list.npy")) + lmap.transitions.transitions = list(np.load("grid/my_rail_trans.npy")) + return lmap +""" +def generate_map(rail_generator=random_rail_generator()): + + t = time.time() + geny = rail_generator + rail = geny(size_x,size_y) + elapsed = time.time() - t + print("initialization time: ", elapsed) + return rail +""" + +class RawObservation(ObservationBuilder): + """ + ObservationBuilder for raw observations. + """ + def __init__(self, size_): + self.reset() + self.size_ = size_ + + def _set_env(self, env): + self.env = env + + def reset(self): + """ + Called after each environment reset. + """ + self.map_ = None + self.agent_positions_ = None + self.agent_handles_ = None + + def search_grid(self,x_indices,y_indices): + b_ = None + if x_indices!=[] and y_indices!=[]: + for i in x_indices: + for j in y_indices: + if int(self.env.rail.grid[i][j]) != 0: + b_ = [i,j] + break + return b_ + + def get_nearest_in_scope(self,position_,size_,target_): + """ + Crop region of size x,y around position and set a intemediate target if + real target is not in local window. 
+ """ + shape_ = self.env.rail.grid.shape + x_ = position_[0] - target_[0] + y_ = position_[1] - target_[1] + a_ = None + b_ = None + + if np.abs(x_) >= size_[0] or np.abs(y_) >= size_[1]: + xd_ = x_ + yd_ = y_ + if np.abs(x_) >= size_[0]: + xd_ = int(xd_ * size_[0]/np.abs(x_)) + if np.abs(y_) >= size_[1]: + yd_ = int(yd_ * size_[1]/np.abs(y_)) + #bug here! + x_indices = [] + y_indices = [] + if xd_ < 0: + x_indices = list(reversed(range(position_[0] ,position_[0] - xd_ -1))) + elif xd_ > 0: + x_indices = list(range(position_[0] -xd_ + 1,position_[0])) + if yd_ < 0: + y_indices = list(reversed(range(position_[1] ,position_[1] - yd_ -1))) + elif yd_ > 0: + y_indices = list(range(position_[1] -yd_ +1 ,position_[1])) + + # cheat + x_indices.append(position_[0]) + y_indices.append(position_[1]) + # "nearest point in new area" + b_ = self.search_grid(x_indices, y_indices) + + + if b_ is not None: + target_ = b_ + else: + # take any existing point + if len (x_indices) < size_[0]/2 -1: + x_indices = list(range(position_[0] - size_x, position_[0]+size_x)) + if len (y_indices) < size_[y]/2 -1: + y_indices = list(range(position_[1] - size_y,position_[1]+size_y)) + + b_ = self.search_grid(x_indices, y_indices) + if b_ is None: + target_ = position_ + + x_ = position_[0] - target_[0] + y_ = position_[1] - target_[1] + + + if x_ > 0 and y_ > 0: + a_ = [target_[0], target_[1]] + elif x_<= 0 and y_ <= 0: + a_ = [position_[0], position_[1]] + elif x_> 0 and y_ <= 0: + a_ = [target_[0], position_[1]] + elif x_<= 0 and y_ > 0: + a_ = [position_[0], target_[1]] + + if a_[0] + size_[0] >= shape_[0]: + a_[0]-= a_[0] + size_[0] - shape_[0] + + if a_[1] + size_[1] >= shape_[1]: + a_[1]-= a_[1] +size_[1] - shape_[1] + + if position_[0]-a_[0] >= size_[0] or position_[1]-a_[1] >= size_[1]: + b_ = target_ + + + return a_, target_ + + def get(self, handle=0): + """ + Called whenever an observation has to be computed for the `env' environment, possibly + for each agent independently (agent id `handle'). + + Parameters + ------- + handle : int (optional) + Handle of the agent for which to compute the observation vector. + + Returns + ------- + function + Transition map as local window of size x,y , agent-positions if in window and target. 
+ """ + + map_ = self.env.rail.grid.astype(np.float)/65535.0 + target = self.env.agents_target[handle] + position = self.env.agents_position[handle] + a_, target = self.get_nearest_in_scope(position,self.size_, target) + map_ = map_[a_[0]:a_[0]+self.size_[0],a_[1]:a_[1]+self.size_[1]] + + agent_positions_ = np.zeros_like(map_) + agent_handles = self.env.get_agent_handles() + for handle_ in agent_handles: + direction_ = float(self.env.agents_direction[handle_] + 1)/5.0 + position_ = self.env.agents_position[handle_] + if position_[0]>= a_[0] and position_[0]<a_[0]+self.size_[0] \ + and position_[1]>= a_[1] and position_[1]<a_[1]+self.size_[1]: + agent_positions_[position_[0]-a_[0]][position_[1]-a_[1]] = direction_ + + my_position_ = np.zeros_like(map_) + + direction = float(self.env.agents_direction[handle] + 1)/5.0 + my_position_[position[0]-a_[0]][position[1]-a_[1]] = direction + + my_target_ = np.zeros_like(map_) + + my_target_[target[0]-a_[0]][target[1]-a_[1]] = 1 + + return np.stack(( map_,agent_positions_,my_position_,my_target_)) + + +def fake_generator(rail_l_map): + def generator(width, height, num_resets=0): + + return copy.deepcopy(rail_l_map) + return generator + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('action', choices=['generate', 'test','play','manual']) + args = parser.parse_args() + rt = None + if args.action == 'generate': + #rail_l_map_ = generate_map() + #save_map(rail_l_map_) + link_ = FlatConvert(size_x,size_y) + link_.play_epochs = 1 + link_.play_game_number = 1 + link_.replay_buffer_size = 1000 + env = RailEnv(32,32,rail_generator=complex_rail_generator(5,40,10),number_of_agents=5,obs_builder_object=RawObservation([size_x,size_y])) + rt = None#RenderTool(env,gl="QT") + link_.perform_training(env,rt) + + elif args.action == 'test': + """ + primitive test + """ + + link_ = FlatLink() + link_.collect_training_data(fakeEnv()) + link_.trainer() + elif args.action == 'play': + link_ = FlatConvert(64,64) + rail_l_map = load_map() + env = RailEnv(64,64,rail_generator=complex_rail_generator(5,10,5),number_of_agents=5,obs_builder_object=RawObservation([size_x,size_y])) + rt = RenderTool(env,gl="QT") + link_.play(env,rt,"models/model_a3c_i.pt") + + elif args.action == 'manual': + + link_ = FlatLink() + link_.load_("D:\\ZHAW Master\\FS19\\VT2\\flatland-master_6\\flatland\\baselines\\Nets\\avoid_checkpoint15000.pth") + #from scipy import stats + #for j in range(10): + # su = 0 + # size_ = 100 + #print("stats.norm(1).rvs() ",stats.norm(j).rvs()) + # for i in range(size_): + # su += int(abs(stats.norm(1).rvs()) < 0.1) + # print("hits ",float(su)/float(size_)) + #for i in range(100): + # link_.test_random_move() \ No newline at end of file diff --git a/a3c/model.py b/a3c/model.py new file mode 100644 index 0000000..6980ed9 --- /dev/null +++ b/a3c/model.py @@ -0,0 +1,116 @@ +""" + +""" +import torch.nn as nn +import torch + +import torch.nn.functional as F +class FlatNet(nn.Module): + """ + Definition of a3c model, forward pass and loss. 
+ """ + def __init__(self,size_x=20,size_y=10,init_weights=True): + super(FlatNet, self).__init__() + self.observations_= 1600 #size_x*size_y*4# 9216# 256 + self.num_actions_= 4 + self.loss_value_ = 0.01 + self.loss_entropy_ = 0.05 + self.epsilon_ = 1.e-10 + + self.avgpool2 = nn.AdaptiveAvgPool2d((5, 5)) + self.avgpool3 = nn.AdaptiveAvgPool3d((1, 5, 5)) + self.feat_1 = nn.Sequential( + nn.Conv3d(3, 16, kernel_size=5, padding=2), + nn.Conv3d(16, 16, kernel_size=5, padding=2), + nn.Conv3d(16, 16, kernel_size=3, padding=1), + ) + self.feat_1b = nn.Sequential( + nn.Conv3d(3, 16, kernel_size=3, padding=1), + ) + self.feat_2 = nn.Sequential( + nn.Conv3d(32, 32, kernel_size=5, padding=2), + nn.Conv3d(32, 32, kernel_size=5, padding=2), + ) + self.feat_2b = nn.Sequential( + nn.Conv3d(32, 32, kernel_size=1, padding=0), + ) + self.classifier_ = nn.Sequential( + nn.Linear(self.observations_, 512), + nn.LeakyReLU(True), + nn.Dropout(), + nn.Linear(512, 256), + nn.LeakyReLU(True), + ) + + self.gru_ = nn.GRU(256, 256, 5) + self.policy_ = nn.Sequential( + nn.Linear(256, self.num_actions_), + nn.Softmax(), + ) + self.value_ = nn.Sequential(nn.Linear(256,1)) + + + if init_weights==True: + self._initialize_weights(self.classifier_) + self._initialize_weights(self.policy_) + self._initialize_weights(self.value_) + self._initialize_weights(self.feat_1) + self._initialize_weights(self.feat_1b) + self._initialize_weights(self.feat_2) + self._initialize_weights(self.feat_2b) + + + def loss(self,out_policy, y_policy, out_values, y_values): + log_prob = torch.log(torch.sum(out_policy*y_policy, 1)+self.epsilon_); + advantage = y_values - out_values + #advantage = advantage.detach()#requires_grad_(requires_grad=False) + policy_loss = -log_prob * advantage.detach() + value_loss = self.loss_value_ * advantage.pow(2) + entropy = self.loss_entropy_ * torch.sum(y_policy * torch.log(y_policy+self.epsilon_), 1) + loss = torch.mean(policy_loss + value_loss + entropy) + return loss + + def features_1(self,x): + x1 = self.feat_1(x) + x1b = self.feat_1b(x) + return torch.cat([x1,x1b],1) + + def features_2(self,x): + x2 = self.feat_2(x) + x2b = self.feat_2b(x) + return torch.cat([x2,x2b],1) + + def features(self,x): + x = F.leaky_relu(self.features_1(x),True) + x = F.leaky_relu(self.features_2(x),True) + x = F.max_pool3d(x,kernel_size=2, stride=2) + x = self.avgpool3(x) + return x + + def forward(self, x): + x = self.features(x) + x = x.view(x.size(0), -1) + x = self.classifier_(x) + x,_ = self.gru_(x.view(len(x), 1, -1)) + x = x.view(len(x), -1) + p = self.policy_(x) + v = self.value_(x) + return p,v + + def _initialize_weights(self,modules_): + for i,m in enumerate(modules_): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + if m.bias is not None: + nn.init.constant_(m.bias, 0) + if isinstance(m, nn.Conv3d) or isinstance(m, nn.ConvTranspose3d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + nn.init.constant_(m.bias, 0) + \ No newline at end of file -- GitLab