import numpy as np
from model import FlatNet
from loader import FlatData
import torch
import torch.nn as nn
import torch.optim as optim
from scipy import stats
import queue
import random
class FlatLink():
    """
    Linker for environment.
    Training: Generate training data, accumulate n-step reward and train model.
    Play: Execute trained model.
    """
    def __init__(self,size_x=20,size_y=10, gpu_ = 0):
        self.n_step_return_ = 50        # horizon for the n-step return
        self.gamma_ = 0.5               # discount factor
        self.batch_size = 1000          # mini-batch size for training
        self.num_epochs = 20            # training epochs per call to training()
        self.play_epochs = 50           # outer play/train iterations
        self.a = 3.57                   # shape parameter of the alpha distribution used for exploration noise
        self.thres = 2.00               # exploration threshold, decayed towards min_thres
        self.alpha_rv = stats.alpha(self.a)
        self.alpha_rv.random_state = np.random.RandomState(seed=342423)
        self.model = FlatNet(size_x, size_y)
        self.queue = None               # filled with (state, action, n_step_return) samples during play
        self.max_iter = 300             # maximum environment steps per episode
        self.dtype = torch.float
        self.play_game_number = 10      # episodes collected per play epoch
        self.rate_decay = 0.97          # multiplicative decay applied to thres after each play epoch
        self.size_x = size_x
        self.size_y = size_y
        self.replay_buffer_size = 4000  # target size of the replay buffer
        self.buffer_multiplicator = 1   # not used in this module
        self.best_loss = 100000.0       # best epoch loss seen so far (used for checkpointing)
        self.min_thres = 0.05           # lower bound for the exploration threshold

        if gpu_>-1:
            self.cuda = True
            self.device = torch.device('cuda', gpu_)
            self.model.cuda()
        else:
            self.cuda = False
            self.device = torch.device('cpu', 0)

    def load_(self, file='model_a3c_i.pt'):
        self.model = torch.load(file, map_location=self.device)
        print(self.model)
    
    def test_random_move(self):
        move =  np.argmax(self.alpha_rv.rvs(size=5))
        print("move ",move)
   
    def get_action(self, policy):
        """
        Return either predicted (from policy) or random action
        """
        if np.abs(stats.norm(1).rvs()) < self.thres:
            move =  np.argmax(self.alpha_rv.rvs(size=len(policy)))
            return move, 0
        else:
            xk = np.arange(len(policy))
            custm = stats.rv_discrete(name='custm', values=(xk, policy))
            return custm.rvs(), 1

    def accumulate(self, memory, state_value):

        """
        memory has form (state, action, reward, after_state, done)
        Accumulate b-step reward and put to training-queue
        """
        n = min(len(memory), self.n_step_return_)
        curr_state = memory[0][0]
        action = memory[0][1]
        
        n_step_return = 0
        for i in range(n):
            reward  = memory[i][2]
            n_step_return += np.power(self.gamma_,i+1) * reward
        
        done = memory[-1][4]

        if not done:
            # Episode not finished: add a bootstrap term based on the critic's
            # value estimate for the remaining return.
            n_step_return += (-1)*np.power(self.gamma_,n+1)*np.abs(state_value)[0]
        else:
            n_step_return += np.abs(state_value)[0]  # Not necessary!

        if len(memory) == 1:
            memory = []
        else:
            # Drop only the consumed head transition; the remainder is processed
            # in subsequent calls.
            memory = memory[1:]
        n_step_return = np.clip(n_step_return, -10., 1.)
        return curr_state, action, n_step_return , memory

   
    def training(self,observations,targets,rewards):
        """
        Run model training on accumulated training experience
        """
        self.model.train()
        data_train = FlatData(observations,targets,rewards,self.model.num_actions_,self.device, self.dtype)
        dataset_sizes = len(data_train)
        dataloader = torch.utils.data.DataLoader(data_train, batch_size=self.batch_size, shuffle=False,num_workers=0)
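        # Note: a fresh Adam optimizer (and fresh moment estimates) is created on
        # every call to training(); only the model weights persist across calls.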
 
        optimizer = optim.Adam(self.model.parameters(), lr=0.0001, weight_decay=0.02)
        # TODO early stop
        for epoch in range(self.num_epochs):
            running_loss = 0.0
            for observation, target, reward in dataloader:
                optimizer.zero_grad()
                policy, value = self.model(observation)
                loss = self.model.loss(policy,target,value,reward)
                loss.backward()
                optimizer.step()
                running_loss += torch.abs(loss).item() * observation.size(0)
               
            epoch_loss = running_loss / dataset_sizes
            print('Epoch {}/{}, Loss {}'.format(epoch, self.num_epochs - 1,epoch_loss))
            if epoch_loss < self.best_loss:
                self.best_loss = epoch_loss
                try:
                    torch.save(self.model, 'model_a3c_i.pt')
                except Exception:
                    print("Model could not be saved!")


    def predict(self, observation):
        """
        Forward pass on model
        Returns: Policy, Value
        """
        val = torch.from_numpy(observation.astype(np.float32)).to(device=self.device, dtype=self.dtype)
        p, v = self.model(val)
        return p.cpu().detach().numpy().reshape(-1), v.cpu().detach().numpy().reshape(-1)

    
    def trainer(self,buffer_list = []):
        """
        Call training sequence if buffer not emtpy
        """
        if buffer_list==[]:
            return
        curr_state_, action, n_step_return = zip(*buffer_list)
        self.training( curr_state_, action, n_step_return)
        #self.queue.queue.clear()
        try:
            torch.save(self.model, 'model_a3c.pt')
        except Exception:
            print("Model could not be saved!")


class FlatConvert(FlatLink):
    def __init__(self,size_x,size_y, gpu_ = 0):
        super(FlatConvert, self).__init__(size_x,size_y,gpu_)


    def perform_training(self,env,env_renderer=None):
        """
        Run simulation in training mode
        """
        self.queue = queue.Queue()
        buffer_list=[]
        #self.load_('model_a3c.pt')
        print("start training...")
        for j in range(self.play_epochs):
            reward_mean = 0.0
            predicted_mean = 0.0
            steps_mean = 0.0
            for i in range(self.play_game_number):
                predicted_, reward_, steps_ = self.collect_training_data(env, env_renderer)
                reward_mean += reward_
                predicted_mean += predicted_
                steps_mean += steps_
                # Drain the queue into the replay buffer; clearing it here avoids
                # re-adding samples from earlier games of this epoch on every pass.
                list_ = list(self.queue.queue)
                self.queue.queue.clear()
                buffer_list = buffer_list + list_ #+ [ [x,y,z] for x,y,z,d in random.sample(list_,count_)]
            reward_mean /= float(self.play_game_number)
            predicted_mean /= float(self.play_game_number)
            steps_mean /= float(self.play_game_number)
            print("play epoch: ",j,", total buffer size: ", len(buffer_list))
            if len(buffer_list) > self.replay_buffer_size * 1.2:
                # Trim the buffer: keep the samples with the highest and the lowest
                # n-step returns plus a random subset, then shuffle.
                list_2 = [[x, y, z] for x, y, z in sorted(buffer_list, key=lambda pair: pair[2], reverse=True)]
                size_ = int(self.replay_buffer_size / 4)
                buffer_list = random.sample(buffer_list, size_ * 2) + list_2[:size_] + list_2[-size_:]
                random.shuffle(buffer_list)

            print("play epoch: ",j,", buffer size: ", len(buffer_list),", reward (mean): ",reward_mean,
            ", steps (mean): ",steps_mean,", predicted (mean): ",predicted_mean)
            self.trainer(buffer_list)
            self.queue.queue.clear()
            self.thres = max( self.thres*self.rate_decay,self.min_thres)
    

    def make_state(self,state):
        """
        Stack states 3 times for initial call (t1==t2==t3)
        """
        return np.stack((state,state,state))

    def update_state(self,states_,state):
        """
        Update the stacked state: drop the oldest timestep and append the current one
        """
        states_[:-1, :, :, :] = states_[1:, :, :, :]
        states_[-1:, :, :, :] = state.reshape((1, state.shape[0], state.shape[1], state.shape[2]))
        return states_

    def init_state_dict(self,state_dict_):
        for key in state_dict_.keys():
            state_dict_[key] = self.make_state(state_dict_[key])
        return state_dict_

    def update_state_dict(self,old_state_dict, new_state_dict):
        for key in old_state_dict.keys():
            new_state_dict[key] = self.update_state(old_state_dict[key],new_state_dict[key])
        return new_state_dict

    def play(self,env,env_renderer=None,filename = 'model_a3c.pt',epochs_ =10):
        """
        Run simulation on trained model
        """
        episode_reward = 0
        self.thres = 0.0
        self.load_(filename)
        self.model.eval()
        global_reward = 0
        for i in range(epochs_):
            policy_dict = {}
            value_dict = {}
            action_dict = {}
            #self.max_iter = 40
            iter_ = 0
            pre_ = 0
            act_ = 0
            done = False
            curr_state_dict = self.init_state_dict(env.reset())
            #curr_state_dict = env.reset()
            while not done and iter_< self.max_iter:
                iter_+=1
                for agent_id, state in curr_state_dict.items():
                    policy, value = self.predict(state.astype(np.float32).reshape((1, state.shape[0], state.shape[1], state.shape[2], state.shape[3])))
                    policy_dict[agent_id] = policy
                    value_dict[agent_id] = value
                    action_dict[agent_id], pre = self.get_action(policy)
                    pre_ += pre
                    act_+=1 
                if env_renderer is not None:    
                    env_renderer.renderEnv(show=True)
                next_state_dict, reward_dict, done_dict, _ = env.step(action_dict)
                next_state_dict = self.update_state_dict(curr_state_dict, next_state_dict)
                done = done_dict['__all__']
                # accumulate the global reward over all agents
                for agent_id in curr_state_dict:
                    global_reward += reward_dict[agent_id]
                curr_state_dict = next_state_dict

        return  float(pre_)/float(act_), global_reward, float(iter_)/float(self.max_iter)


    def collect_training_data(self,env,env_renderer = None):
        """
        Run single simualtion szenario for training
        """
        done = False
        curr_state_dict = self.init_state_dict(env.reset())
        #curr_state_dict = env.reset()
        episode_reward = 0
        memory={}
        for agent_id, state in curr_state_dict.items():
            memory[agent_id] = []
        self.model.eval()
        policy_dict = {}
        value_dict = {}
        action_dict = {}
        iter_ = 0
        pre_ = 0
        act_ = 0
        global_reward = 0.0
        while not done and iter_< self.max_iter:
            iter_+=1
            for agent_id, state in curr_state_dict.items():
                policy, value = self.predict(state.astype(np.float32).reshape((1, state.shape[0], state.shape[1], state.shape[2], state.shape[3])))
                policy_dict[agent_id] = policy
                value_dict[agent_id] = value
                action_dict[agent_id], pre = self.get_action(policy)
                pre_ += pre
                act_+=1 
            if env_renderer is not None:    
                env_renderer.renderEnv(show=True)
            next_state_dict, reward_dict, done_dict, _ = env.step(action_dict)
            next_state_dict = self.update_state_dict(curr_state_dict, next_state_dict)
            #reward_dict = self.modify_reward(reward_dict,done_dict,next_state_dict)
            done = done_dict['__all__']
            for agent_id, state in curr_state_dict.items():
                memory[agent_id].append((curr_state_dict[agent_id], action_dict[agent_id],
                                         reward_dict[agent_id], next_state_dict[agent_id], done_dict[agent_id]))
                global_reward += reward_dict[agent_id]
                # Flush transitions once the n-step window is full, or drain the
                # remaining memory when the episode has ended.
                while len(memory[agent_id]) >= self.n_step_return_ or (done and memory[agent_id]):
                    curr_state, action, n_step_return, memory[agent_id] = self.accumulate(memory[agent_id],
                                                                                          value_dict[agent_id])
                    self.queue.put([curr_state, action, n_step_return])


            curr_state_dict = next_state_dict

        return  float(pre_)/float(act_), global_reward, float(iter_)/float(self.max_iter)
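

# Minimal usage sketch, assuming a Flatland-style multi-agent environment
# (dict observations from reset(), step(action_dict) returning
# (obs, rewards, dones, info) with dones['__all__']); `make_env()` and its
# arguments are hypothetical placeholders for the actual environment setup.
#
#     env = make_env(size_x=20, size_y=10)            # hypothetical env factory
#     link = FlatConvert(size_x=20, size_y=10, gpu_=-1)
#     link.perform_training(env)                      # collect data and train
#     link.play(env, filename='model_a3c.pt')         # run the trained model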