From 1c6fe09fdd5fda93300e947d81deb74af670ec0d Mon Sep 17 00:00:00 2001
From: stephan <scarlatti@MSI>
Date: Wed, 17 Jul 2019 17:23:14 +0200
Subject: [PATCH] code for testing grid-based local observations
 (LocalObsForRailEnv)

---
 a3c/linker.py | 192 +++++++++++++++++++++++++++++-------------
 a3c/loader.py |   9 +-
 a3c/main.py   | 229 ++++++--------------------------------------------
 a3c/model.py  |  51 ++++++-----
 4 files changed, 196 insertions(+), 285 deletions(-)

diff --git a/a3c/linker.py b/a3c/linker.py
index e694aed..315a30e 100644
--- a/a3c/linker.py
+++ b/a3c/linker.py
@@ -7,6 +7,7 @@ import torch.optim as optim
 from scipy import stats
 import queue
 import random
+import copy
 class FlatLink():
     """
     Linker for environment.
@@ -14,28 +15,28 @@ class FlatLink():
     Play: Execute trained model.
     """
     def __init__(self,size_x=20,size_y=10, gpu_ = 0):
-        self.n_step_return_ = 50
-        self.gamma_ = 0.5
-        self.batch_size = 1000
-        self.num_epochs = 20
-        self.play_epochs = 50
+        self.n_step_return_ = 100
+        self.gamma_ = 0.99
+        self.batch_size = 4000
+        self.num_epochs = 100
+        self.play_epochs = 1000
         self.a = 3.57
-        self.thres = 2.00
+        self.thres = 1.0
         self.alpha_rv = stats.alpha(self.a)
         self.alpha_rv.random_state = np.random.RandomState(seed=342423)
         self.model = FlatNet(size_x,size_y)
         self.queue = None
-        self.max_iter = 300
+        self.max_iter =  self.n_step_return_
         self.dtype = torch.float
-        self.play_game_number = 10
-        self.rate_decay = 0.97
+        self.play_game_number = 100
+        self.rate_decay = 0.8
         self.size_x=size_x
         self.size_y=size_y
-        self.replay_buffer_size = 4000
+        self.replay_buffer_size = 10000
         self.buffer_multiplicator = 1
         self.best_loss = 100000.0
-        self.min_thres = 0.05
-
+        self.min_thres = 0.03
+        self.normalize_reward = 100
         if gpu_>-1:
             self.cuda = True
             self.device = torch.device('cuda', gpu_)
@@ -64,7 +65,7 @@ class FlatLink():
             custm = stats.rv_discrete(name='custm', values=(xk, policy))
             return custm.rvs(), 1
 
-    def accumulate(self, memory, state_value):
+    def accumulate(self, memory, state_value, done_):
 
         """
         memory has form (state, action, reward, after_state, done)
@@ -77,20 +78,32 @@ class FlatLink():
         n_step_return = 0
         for i in range(n):
             reward  = memory[i][2]
-            n_step_return += np.power(self.gamma_,i+1) * reward
+            n_step_return += np.power(self.gamma_,i) * reward
         
         done = memory[-1][4]
-
+        #if done:
+        #    print("n_step_return: ",n_step_return," done: ", done)
+        
         if not done:
-            n_step_return += (-1)*np.power(self.gamma_,n+1)*np.abs(state_value)[0]
+            n_step_return += np.power(self.gamma_,n)*state_value[0]
         else:
-            n_step_return += np.abs(state_value)[0] # Not neccessary!
+            n_step_return += 1.0
+        
+        if done_:
+            n_step_return += 10.0
+        #if not done:
+        #    n_step_return += (-1)*np.power(self.gamma_,n+1)*np.abs(state_value)[0]
+        #else:
+        #    n_step_return += np.abs(state_value)[0] # Not neccessary!
 
-        if len(memory)==1:
+        if len(memory)==1 or done:
             memory = []
         else:
             memory = memory[1:-1]
-        n_step_return = np.clip(n_step_return, -10., 1.)
+        n_step_return /= self.normalize_reward
+        #print("n_step_return: ",n_step_return," n: ", n)
+        n_step_return = np.clip(n_step_return, -1., 1.)
+        #print("n_step_return: ",n_step_return," 4")
         return curr_state, action, n_step_return , memory
 
    
@@ -101,28 +114,34 @@ class FlatLink():
         self.model.train()
         data_train = FlatData(observations,targets,rewards,self.model.num_actions_,self.device, self.dtype)
         dataset_sizes = len(data_train)
-        dataloader = torch.utils.data.DataLoader(data_train, batch_size=self.batch_size, shuffle=False,num_workers=0)
- 
+        dataloader = torch.utils.data.DataLoader(data_train, batch_size=self.batch_size, shuffle=True,num_workers=0)
+        best_model_wts = None
         optimizer= optim.Adam(self.model.parameters(),lr=0.0001,weight_decay=0.02)
+        best_loss  = 10000000
         # TODO early stop
         for epoch in range(self.num_epochs):
             running_loss = 0.0
-            for observation, target, reward in dataloader:
+            for observation, obs_2, target, reward in dataloader:
                 optimizer.zero_grad()
-                policy, value = self.model.forward(observation)
-                loss = self.model.loss(policy,target,value,reward)
+                policy, value = self.model.forward(observation.to(device = self.device, dtype=self.dtype),obs_2.to(device = self.device, dtype=self.dtype))
+                loss = self.model.loss(policy,target.to(device = self.device, dtype=self.dtype),value,reward.to(device = self.device, dtype=self.dtype))
                 loss.backward()
                 optimizer.step()
                 running_loss += torch.abs(loss).item() * observation.size(0)
                
             epoch_loss = running_loss / dataset_sizes
-            print('Epoch {}/{}, Loss {}'.format(epoch, self.num_epochs - 1,epoch_loss))
+           
-            if epoch_loss < self.best_loss:
+            if epoch_loss < best_loss:
-                self.best_loss = epoch_loss
+                best_loss = epoch_loss
+                best_model_wts = copy.deepcopy(self.model.state_dict())
                 try:
                     torch.save(self.model, 'model_a3c_i.pt')
+                    
                 except:
                     print("Model is not saved!!!")
+        print('Last Loss {}, Best Loss {}'.format(epoch_loss, best_loss))
+        if best_model_wts is not None:
+            self.model.load_state_dict(best_model_wts)
 
 
     def predict(self, observation):
@@ -130,8 +149,13 @@ class FlatLink():
         Forward pass on model
         Returns: Policy, Value
         """
-        val = torch.from_numpy(observation.astype(np.float)).to(device = self.device, dtype=self.dtype)
-        p, v = self.model.forward(val)
+        state = observation[0]
+        state2 = observation[1]
+        observation1 = state.astype(np.float).reshape((1,state.shape[0],state.shape[1],state.shape[2]))
+        observation2 = state2.astype(np.float).reshape((1,state2.shape[0]))
+        val1 = torch.from_numpy(observation1.astype(np.float)).to(device = self.device, dtype=self.dtype)
+        val2 = torch.from_numpy(observation2.astype(np.float)).to(device = self.device, dtype=self.dtype)
+        p, v = self.model.forward(val1,val2)
         return p.cpu().detach().numpy().reshape(-1), v.cpu().detach().numpy().reshape(-1)
 
     
@@ -144,10 +168,10 @@ class FlatLink():
         curr_state_, action, n_step_return = zip(*buffer_list)
         self.training( curr_state_, action, n_step_return)
         #self.queue.queue.clear()
-        try:
-            torch.save(self.model, 'model_a3c.pt')
-        except:
-            print("Model is not saved!!!")
+        #try:
+        #    torch.save(self.model, 'model_a3c.pt')
+        #except:
+        #    print("Model is not saved!!!")
 
 
 class FlatConvert(FlatLink):
@@ -161,59 +185,101 @@ class FlatConvert(FlatLink):
         """
         self.queue = queue.Queue()
         buffer_list=[]
+        best_reward = -999999999
         #self.load_('model_a3c.pt')
         print("start training...")
+        number_of_agents = len(env.agents)
         for j in range(self.play_epochs):
             reward_mean= 0.0
             predicted_mean = 0.0
-            steps_mean = 0.0
+            steps_mean = np.zeros((number_of_agents))
+            solved_mean = 0.0
+            b_list=[]
             for i in range(self.play_game_number):
-                predicted_, reward_ ,steps_= self.collect_training_data(env,env_renderer)
+                predicted_, reward_ ,steps_, done_= self.collect_training_data(env,env_renderer)
                 reward_mean+=reward_
                 predicted_mean+=predicted_
-                steps_mean += steps_
-                list_ = list(self.queue.queue) 
-
-                buffer_list = buffer_list + list_ #+ [ [x,y,z] for x,y,z,d in random.sample(list_,count_)]
+                solved_mean += done_
+                steps_mean += np.floor(steps_)
+ 
+            b_list = list(self.queue.queue) 
             reward_mean /= float(self.play_game_number)
             predicted_mean /= float(self.play_game_number)
             steps_mean /= float(self.play_game_number)
-            print("play epoch: ",j,", total buffer size: ", len(buffer_list))
+            solved_mean /= float(self.play_game_number)
+            print("play epoch: ",j,", total buffer size: ", len(b_list))
+            self.trainer(buffer_list + b_list)
+            buffer_list =  buffer_list  + b_list
             if len(buffer_list) > self.replay_buffer_size * 1.2:
                 list_2 =[ [x,y,z] for x,y,z in sorted(buffer_list, key=lambda pair: pair[2],reverse=True)] 
-                size_  = int(self.replay_buffer_size/4)
-                buffer_list = random.sample(buffer_list,size_*2) + list_2[:size_] + list_2[-size_:]
-                random.shuffle(buffer_list)
+                size_  = int(self.replay_buffer_size/2)
+                buffer_list =  list_2[:size_] + list_2[-size_:]
+           
+            if reward_mean > best_reward:
+                best_reward = reward_mean
+                try:
+                    torch.save(self.model, 'model_a3c.pt')
+                except:
+                    print("Model is not saved!!!")
 
             print("play epoch: ",j,", buffer size: ", len(buffer_list),", reward (mean): ",reward_mean,
-            ", steps (mean): ",steps_mean,", predicted (mean): ",predicted_mean)
-            self.trainer(buffer_list)
+            ", solved (mean): ",solved_mean, ", steps (mean): ",steps_mean,", predicted (mean): ",predicted_mean)
+           
             self.queue.queue.clear()
             self.thres = max( self.thres*self.rate_decay,self.min_thres)
+            env.num_resets = 0
     
+    def shape_reward(self,states_,reward):
+        """
+        Additional penalty for not moving (e.g. when in a deadlock), only applied on the last time step (currently disabled).
+        """
+        #if np.array_equal(states_[0,2,:,:],states_[1,2,:,:]):
+        #    reward += -0.5
+        #states_=states_[0].reshape((4,int(states_[0].shape[0]/4),states_[0].shape[1],states_[0].shape[2]))
+        #x = states_.shape[0]
+        #if np.array_equal(states_[x-2,2,:,:],states_[x-1,2,:,:]):
+        #reward += -0.5
+        return reward
 
-    def make_state(self,state):
+    def make_state(self,state_):
         """
-        Stack states 3 times for initial call (t1==t2==t3)
+        Stack states 4 times for the initial call (t1==t2==t3==t4)
         """
-        return np.stack((state,state,state))
+        state = state_[0] #target does not move!
+        return [np.stack((state,state,state,state)).reshape((4*state.shape[0],state.shape[1],state.shape[2])),state_[1]]
 
-    def update_state(self,states_,state):
+    def update_state(self,states_a,state):
         """
         update state by removing the oldest timestep and adding the current one
         """
+        states_ = states_a[0]
+        states_=states_.reshape((4,state[0].shape[0],state[0].shape[1],state[0].shape[2]))
         states_[:(states_.shape[0]-1),:,:,:] =  states_[1:,:,:,:]
-        states_[(state.shape[0]-1):,:,:,:] = state.reshape((1,state.shape[0],state.shape[1],state.shape[2]))
-        return states_
+        states_[(state[0].shape[0]-1):,:,:,:] = state[0].reshape((1,state[0].shape[0],state[0].shape[1],state[0].shape[2]))
+        return [states_.reshape((4*state[0].shape[0],state[0].shape[1],state[0].shape[2])),state[1]]
+
+
+    def stack_observations(self,state):
+        local_rail_obs=state[0] 
+        obs_map_state=state[1] 
+        obs_other_agents_state=state[2] 
+        direction=state[3] 
+        #final_size = local_rail_obs.shape[2] + obs_map_state.shape[2] + obs_other_agents_state.shape[2]
+        #new_array = np.empty((local_rail_obs.shape[0],local_rail_obs.shape[1],final_size),dtype = local_rail_obs.dtype)
+        # new_array[:,:,:local_rail_obs.shape[2]]=local_rail_obs
+        #new_array[:,:,local_rail_obs.shape[2]:local_rail_obs.shape[2] + obs_map_state.shape[2]]=obs_map_state
+        new_array = np.dstack((local_rail_obs, obs_map_state, obs_other_agents_state))
+        new_array = np.transpose( new_array, (2, 0, 1))
+        return [new_array,direction]
 
     def init_state_dict(self,state_dict_):
         for key in state_dict_.keys():
-            state_dict_[key] = self.make_state(state_dict_[key])
+            state_dict_[key] = self.stack_observations(state_dict_[key])
         return state_dict_
 
     def update_state_dict(self,old_state_dict, new_state_dict):
         for key in old_state_dict.keys():
-            new_state_dict[key] = self.update_state(old_state_dict[key],new_state_dict[key])
+            new_state_dict[key] = self.stack_observations(new_state_dict[key])
         return new_state_dict
 
     def play(self,env,env_renderer=None,filename = 'model_a3c.pt',epochs_ =10):
@@ -257,6 +323,8 @@ class FlatConvert(FlatLink):
         return  float(pre_)/float(act_), global_reward, float(iter_)/float(self.max_iter)
 
 
+
+
     def collect_training_data(self,env,env_renderer = None):
         """
         Run a single simulation scenario for training
@@ -272,6 +340,7 @@ class FlatConvert(FlatLink):
         policy_dict = {}
         value_dict = {}
         action_dict = {}
+        iter_counts = [0] * len(curr_state_dict)
         iter_ = 0
         pre_ = 0
         act_ = 0
@@ -279,29 +348,36 @@ class FlatConvert(FlatLink):
         while not done and iter_< self.max_iter:
             iter_+=1
             for agent_id, state in curr_state_dict.items():
-                policy, value = self.predict(state.astype(np.float).reshape((1,state.shape[0],state.shape[1],state.shape[2],state.shape[3])))
+                policy, value = self.predict(state)
                 policy_dict[agent_id] = policy
                 value_dict[agent_id] = value
                 action_dict[agent_id], pre = self.get_action(policy)
                 pre_ += pre
                 act_+=1 
-            if env_renderer is not None:    
-                env_renderer.renderEnv(show=True)
+
             next_state_dict, reward_dict, done_dict, _ = env.step(action_dict)
             next_state_dict = self.update_state_dict(curr_state_dict, next_state_dict)
             #reward_dict = self.modify_reward(reward_dict,done_dict,next_state_dict)
             done = done_dict['__all__']
             for agent_id, state in curr_state_dict.items():
-            
+                if not done_dict[agent_id]:
+                    iter_counts[agent_id] += 1
+                if memory[agent_id] == [] and done_dict[agent_id]:
+                    continue
                 memory[agent_id].append(tuple([curr_state_dict[agent_id], 
                 action_dict[agent_id], reward_dict[agent_id], next_state_dict[agent_id], done_dict[agent_id]]))
+                reward_dict[agent_id] = self.shape_reward(next_state_dict[agent_id], reward_dict[agent_id])
                 global_reward += reward_dict[agent_id]
-                while ((len(memory[agent_id]) >= self.n_step_return_) or (done and not memory[agent_id] == [])):
+                while ((len(memory[agent_id]) >= self.n_step_return_) or (done_dict[agent_id] and not memory[agent_id] == [])):
                     curr_state, action, n_step_return, memory[agent_id] = self.accumulate(memory[agent_id], 
-                    value_dict[agent_id])
+                    value_dict[agent_id],done)
                     self.queue.put([curr_state, action, n_step_return])
 
 
             curr_state_dict = next_state_dict
 
-        return  float(pre_)/float(act_), global_reward, float(iter_)/float(self.max_iter)
\ No newline at end of file
+            if env_renderer is not None:
+                # does not work with the current flatland master
+                env_renderer.renderEnv(show=True,show_observations=False)
+
+        return  float(pre_)/float(act_), global_reward, np.array(iter_counts).astype(np.float)/float(self.max_iter), 1 if done else 0
\ No newline at end of file
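
Note: the reworked accumulate() above computes a discounted n-step return, bootstraps with the critic value when the agent is not yet done, adds fixed bonuses otherwise, and finally normalizes and clips. A minimal standalone sketch of just that computation (the helper name and the toy rollout are illustrative only; the memory-trimming part of accumulate() is omitted):

import numpy as np

GAMMA = 0.99            # self.gamma_ in linker.py
NORMALIZE_REWARD = 100  # self.normalize_reward

def n_step_return(memory, state_value, all_done,
                  gamma=GAMMA, normalize=NORMALIZE_REWARD):
    # memory: list of (state, action, reward, next_state, done) tuples
    # state_value: critic estimate, indexed with [0] as in the patch
    # all_done: episode-level flag passed to accumulate() as done_
    ret = sum(np.power(gamma, i) * m[2] for i, m in enumerate(memory))
    if not memory[-1][4]:
        ret += np.power(gamma, len(memory)) * state_value[0]  # bootstrap with V(s_t+n)
    else:
        ret += 1.0       # fixed bonus once this agent is done
    if all_done:
        ret += 10.0      # extra bonus when the whole episode is solved
    return float(np.clip(ret / normalize, -1.0, 1.0))

# hypothetical usage: a 3-step rollout with reward -1 per step
rollout = [(None, 2, -1.0, None, False)] * 3
print(n_step_return(rollout, state_value=np.array([0.5]), all_done=False))
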
diff --git a/a3c/loader.py b/a3c/loader.py
index 589f9a4..d742e5e 100644
--- a/a3c/loader.py
+++ b/a3c/loader.py
@@ -9,15 +9,16 @@ class FlatData(torch.utils.data.dataset.Dataset):
 
     """
     def __init__(self,states,targets,rewards,num_action,device,dtype):
-        self.states_ = [torch.from_numpy(x.astype(np.float)).to(device = device, dtype=dtype)  for x in states]
-        self.targets_ = [torch.from_numpy(np.eye(num_action)[x]).to(device = device, dtype=dtype)  for x in targets]
-        self.rewards_ = [torch.from_numpy(np.array([x]).reshape(-1)).to(device = device, dtype=dtype) for x in rewards]
+        self.states_ = [torch.from_numpy(x[0].astype(np.float))  for x in states]
+        self.states_2_ = [torch.from_numpy(x[1].astype(np.float))  for x in states]
+        self.targets_ = [torch.from_numpy(np.eye(num_action)[x])  for x in targets]
+        self.rewards_ = [torch.from_numpy(np.array([x]).reshape(-1)) for x in rewards]
        
 
     # Override to give PyTorch access to any image on the dataset
     def __getitem__(self, index):
         
-        return self.states_[index], self.targets_[index], self.rewards_[index]
+        return self.states_[index], self.states_2_[index], self.targets_[index], self.rewards_[index]
 
     # Override to give PyTorch size of dataset
     def __len__(self):
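
Note: FlatData now keeps its tensors on the CPU and returns four items per sample, so the training loop in linker.py unpacks (observation, obs_2, target, reward) and moves them to the device itself. A small usage sketch; the observation shapes are purely illustrative assumptions:

import numpy as np
import torch
from loader import FlatData  # a3c/loader.py

# each state is now a pair (stacked grid observation, extra vector)
states  = [(np.zeros((22, 11, 11)), np.zeros(4)) for _ in range(8)]
targets = [np.random.randint(0, 5) for _ in range(8)]
rewards = [0.0] * 8

data = FlatData(states, targets, rewards, num_action=5,
                device=torch.device('cpu'), dtype=torch.float)
loader = torch.utils.data.DataLoader(data, batch_size=4, shuffle=True)

# the training loop unpacks four tensors per batch and moves them to the device itself
for obs, obs_2, target, reward in loader:
    print(obs.shape, obs_2.shape, target.shape, reward.shape)
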
diff --git a/a3c/main.py b/a3c/main.py
index 24e27a6..f94ecd5 100644
--- a/a3c/main.py
+++ b/a3c/main.py
@@ -4,193 +4,24 @@ from fake import fakeEnv
 import sys
 import copy
 import numpy as np
-sys.path.append("D:\\ZHAW Master\\FS19\VT2\\flatland-master_6")
-#sys.path.insert(0, "D:\\ZHAW Master\\FS19\VT2\\flatland-master\\flatland")
+#set path to flatland here
+sys.path.append("flatland")
+
 from flatland.envs.rail_env import RailEnv
-from flatland.envs.generators import random_rail_generator, complex_rail_generator,rail_from_GridTransitionMap_generator
+from flatland.envs.generators import random_rail_generator, complex_rail_generator
+from flatland.envs.observations import LocalObsForRailEnv
 from flatland.core.env_observation_builder import ObservationBuilder
 from flatland.core.transition_map import GridTransitionMap
 from flatland.utils.rendertools import RenderTool
 import time
 import pickle
-size_x=8
-size_y=8
-
-
-def save_map(rail):
-    np.save("my_rail_grid.npy", rail.grid)
-    np.save("my_rail_trans_list.npy",np.array(rail.transitions.transition_list))
-    np.save("my_rail_trans.npy",np.array(rail.transitions.transitions))
-
-def load_map():
-    lmap = GridTransitionMap(12,12)
-    lmap.grid = np.load("grid/my_rail_grid.npy")
-    lmap.transitions.transition_list =list( np.load("grid/my_rail_trans_list.npy"))
-    lmap.transitions.transitions = list(np.load("grid/my_rail_trans.npy"))
-    return lmap
-"""
-def generate_map(rail_generator=random_rail_generator()):
-
-    t = time.time()
-    geny = rail_generator
-    rail = geny(size_x,size_y)
-    elapsed = time.time() - t
-    print("initialization time: ", elapsed)
-    return rail
-"""
-
-class RawObservation(ObservationBuilder):
-    """
-    ObservationBuilder for raw observations.
-    """
-    def __init__(self, size_):
-        self.reset()
-        self.size_ = size_
-
-    def _set_env(self, env):
-        self.env = env
-
-    def reset(self):
-        """
-        Called after each environment reset.
-        """
-        self.map_ = None
-        self.agent_positions_ = None
-        self.agent_handles_ = None
-
-    def search_grid(self,x_indices,y_indices):
-        b_ = None
-        if x_indices!=[] and y_indices!=[]:
-            for i in x_indices:
-                for j in y_indices:
-                    if int(self.env.rail.grid[i][j]) != 0:
-                        b_ = [i,j]
-                        break
-        return b_
-
-    def get_nearest_in_scope(self,position_,size_,target_):
-        """
-        Crop region of size x,y around position and set a intemediate target if
-        real target is not in local window. 
-        """
-        shape_ = self.env.rail.grid.shape
-        x_ = position_[0] - target_[0]
-        y_ = position_[1] - target_[1]
-        a_ = None
-        b_ = None
-
-        if np.abs(x_) >= size_[0] or np.abs(y_) >= size_[1]:
-            xd_ = x_
-            yd_ = y_
-            if np.abs(x_) >= size_[0]:
-                xd_ = int(xd_ * size_[0]/np.abs(x_))
-            if np.abs(y_) >= size_[1]:
-                yd_ = int(yd_ * size_[1]/np.abs(y_))
-            #bug here!
-            x_indices = []
-            y_indices = []
-            if xd_ < 0:
-                x_indices  = list(reversed(range(position_[0] ,position_[0] - xd_ -1)))
-            elif xd_ > 0:
-                x_indices  = list(range(position_[0] -xd_ + 1,position_[0]))
-            if yd_ < 0: 
-                y_indices  = list(reversed(range(position_[1] ,position_[1] - yd_ -1)))
-            elif yd_ > 0:
-                y_indices  = list(range(position_[1] -yd_ +1 ,position_[1]))
-            
-            # cheat
-            x_indices.append(position_[0])
-            y_indices.append(position_[1])
-            # "nearest point in new area"
-            b_ = self.search_grid(x_indices, y_indices)
-    
-           
-            if b_ is not None:
-                target_ = b_
-            else:
-                # take any existing point
-                if len (x_indices) < size_[0]/2 -1:
-                    x_indices = list(range(position_[0] - size_x, position_[0]+size_x))
-                if len (y_indices) < size_[y]/2 -1:
-                    y_indices = list(range(position_[1] - size_y,position_[1]+size_y))
-
-                b_ = self.search_grid(x_indices, y_indices)
-                if b_ is  None:
-                    target_ = position_
-
-            x_ = position_[0] - target_[0]
-            y_ = position_[1] - target_[1]
-        
-        
-        if x_ > 0 and y_ > 0:
-            a_ =  [target_[0], target_[1]]
-        elif x_<= 0 and y_ <= 0:
-            a_ = [position_[0], position_[1]]
-        elif x_> 0 and y_ <= 0:
-            a_ = [target_[0], position_[1]]
-        elif x_<= 0 and y_ > 0:
-            a_ = [position_[0], target_[1]]
-    
-        if a_[0] + size_[0] >= shape_[0]:
-            a_[0]-= a_[0] + size_[0] - shape_[0]
-
-        if a_[1] + size_[1] >= shape_[1]:
-            a_[1]-= a_[1] +size_[1] - shape_[1]
-        
-        if position_[0]-a_[0] >= size_[0] or position_[1]-a_[1] >= size_[1]:
-             b_ = target_
-
-
-        return a_, target_
-
-    def get(self, handle=0):
-        """
-        Called whenever an observation has to be computed for the `env' environment, possibly
-        for each agent independently (agent id `handle').
-
-        Parameters
-        -------
-        handle : int (optional)
-            Handle of the agent for which to compute the observation vector.
-
-        Returns
-        -------
-        function
-        Transition map as local window of size x,y , agent-positions if in window and target.
-        """
-    
-        map_ = self.env.rail.grid.astype(np.float)/65535.0
-        target = self.env.agents_target[handle]
-        position = self.env.agents_position[handle]
-        a_, target = self.get_nearest_in_scope(position,self.size_, target)
-        map_ = map_[a_[0]:a_[0]+self.size_[0],a_[1]:a_[1]+self.size_[1]]
-
-        agent_positions_ = np.zeros_like(map_)
-        agent_handles = self.env.get_agent_handles()
-        for handle_ in agent_handles:
-                direction_ = float(self.env.agents_direction[handle_] + 1)/5.0
-                position_ = self.env.agents_position[handle_]
-                if position_[0]>= a_[0] and position_[0]<a_[0]+self.size_[0] \
-                and position_[1]>= a_[1] and position_[1]<a_[1]+self.size_[1]:
-                    agent_positions_[position_[0]-a_[0]][position_[1]-a_[1]] = direction_
-
-        my_position_ = np.zeros_like(map_)
-        
-        direction = float(self.env.agents_direction[handle] + 1)/5.0
-        my_position_[position[0]-a_[0]][position[1]-a_[1]] = direction
-
-        my_target_ = np.zeros_like(map_)
-        
-        my_target_[target[0]-a_[0]][target[1]-a_[1]] = 1
-
-        return np.stack(( map_,agent_positions_,my_position_,my_target_))
-
-
-def fake_generator(rail_l_map):
-    def generator(width, height, num_resets=0):
-
-        return copy.deepcopy(rail_l_map)
-    return generator
+#set parameters here
+size_x=20
+size_y=20
+fov_size_x= 10
+fov_size_y =10
+radius = 5
+num_agents = 10
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
@@ -198,35 +29,31 @@ if __name__ == '__main__':
     args = parser.parse_args()
     rt = None
     if args.action == 'generate':
-        #rail_l_map_ = generate_map()
-        #save_map(rail_l_map_)
-        link_ = FlatConvert(size_x,size_y)
+
+        link_ = FlatConvert(fov_size_x,fov_size_y)
+        # some parameters; these can also be set directly in linker.py
+        # the values here are for quick testing (fast on a weak machine)
         link_.play_epochs = 1
-        link_.play_game_number = 1
-        link_.replay_buffer_size = 1000
-        env = RailEnv(32,32,rail_generator=complex_rail_generator(5,40,10),number_of_agents=5,obs_builder_object=RawObservation([size_x,size_y]))
-        rt = None#RenderTool(env,gl="QT")
+        link_.n_step_return_ = 50
+        link_.play_game_number = 2
+        link_.replay_buffer_size = 100
+        env = RailEnv(size_x,size_y,rail_generator=complex_rail_generator(num_agents,nr_extra=100, min_dist=20,seed=12345),number_of_agents=num_agents,obs_builder_object=LocalObsForRailEnv(radius))
+        # instantiate renderer here
+        rt = None #RenderTool(env, gl="QT")
         link_.perform_training(env,rt)
 
-    elif args.action == 'test':
-        """
-        primitive test
-        """
-       
-        link_ = FlatLink()
-        link_.collect_training_data(fakeEnv())
-        link_.trainer()
     elif args.action == 'play':
-        link_ = FlatConvert(64,64)
-        rail_l_map = load_map()
-        env = RailEnv(64,64,rail_generator=complex_rail_generator(5,10,5),number_of_agents=5,obs_builder_object=RawObservation([size_x,size_y]))
-        rt = RenderTool(env,gl="QT")
+        
+        link_ = FlatConvert(fov_size_x,fov_size_y)
+  
+        env = RailEnv(size_x,size_y,rail_generator=complex_rail_generator(num_agents,nr_extra=100, min_dist=20,seed=12345),number_of_agents=num_agents,obs_builder_object=LocalObsForRailEnv(radius))
+        rt = None# RenderTool(env,gl="QT")
         link_.play(env,rt,"models/model_a3c_i.pt")
 
     elif args.action == 'manual':
 
         link_ = FlatLink()
-        link_.load_("D:\\ZHAW Master\\FS19\\VT2\\flatland-master_6\\flatland\\baselines\\Nets\\avoid_checkpoint15000.pth")
+        link_.load_()
         #from scipy import stats
         #for j in range(10):
         #    su = 0
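
Note: each per-agent observation from LocalObsForRailEnv is a tuple (local_rail_obs, obs_map_state, obs_other_agents_state, direction); FlatConvert.stack_observations() in linker.py stacks the first three into a channels-first grid and keeps the direction vector separate. A hedged sketch of that reshaping; the view size and the 16+2+4 channel split are assumptions chosen only to line up with Conv2d(22, ...) in model.py:

import numpy as np

radius = 5
view = 2 * radius + 1                       # assumed local view size
local_rail_obs         = np.zeros((view, view, 16))
obs_map_state          = np.zeros((view, view, 2))
obs_other_agents_state = np.zeros((view, view, 4))
direction              = np.zeros(4)

grid = np.dstack((local_rail_obs, obs_map_state, obs_other_agents_state))
grid = np.transpose(grid, (2, 0, 1))        # (22, view, view), channels first
obs = [grid, direction]                     # what predict()/FlatData consume
print(grid.shape)
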
diff --git a/a3c/model.py b/a3c/model.py
index 6980ed9..d164578 100644
--- a/a3c/model.py
+++ b/a3c/model.py
@@ -11,43 +11,47 @@ class FlatNet(nn.Module):
     """
     def __init__(self,size_x=20,size_y=10,init_weights=True):
         super(FlatNet, self).__init__()
-        self.observations_=  1600 #size_x*size_y*4# 9216# 256
-        self.num_actions_= 4
-        self.loss_value_ = 0.01
-        self.loss_entropy_ = 0.05
-        self.epsilon_ = 1.e-10
+        self.observations_= 12804# 25600 #size_x*size_y*4# 9216# 256
+        self.num_actions_= 5
+        self.loss_value_ = 0.5
+        self.loss_entropy_ = 0.01
+        self.epsilon_ = 1.e-14
  
         self.avgpool2 = nn.AdaptiveAvgPool2d((5, 5))
-        self.avgpool3 = nn.AdaptiveAvgPool3d((1, 5, 5))
+        
         self.feat_1 = nn.Sequential(
-            nn.Conv3d(3, 16, kernel_size=5, padding=2),
-            nn.Conv3d(16, 16, kernel_size=5, padding=2),
-            nn.Conv3d(16, 16, kernel_size=3, padding=1),
+            nn.Conv2d(22, 128, kernel_size=3, padding=1),
+            nn.Conv2d(128, 128, kernel_size=3, padding=1),
+            nn.Conv2d(128, 128, kernel_size=3, padding=1),
+            #nn.MaxPool2d(kernel_size=2, stride=2)
         )
         self.feat_1b = nn.Sequential(
-            nn.Conv3d(3, 16, kernel_size=3, padding=1),
+            nn.Conv2d(22, 128, kernel_size=1, padding=0),
+            #nn.MaxPool2d(kernel_size=2, stride=2)
         )
         self.feat_2 = nn.Sequential(
-            nn.Conv3d(32, 32, kernel_size=5, padding=2),
-            nn.Conv3d(32, 32, kernel_size=5, padding=2),
+            nn.Conv2d(256, 256, kernel_size=3, padding=1),
+            nn.Conv2d(256, 256, kernel_size=3, padding=1),
+           # nn.MaxPool2d(kernel_size=2, stride=2)
         )
         self.feat_2b = nn.Sequential(
-            nn.Conv3d(32, 32, kernel_size=1, padding=0),
+            nn.Conv2d(256, 256, kernel_size=1, padding=0),
+            #nn.MaxPool2d(kernel_size=2, stride=2)
         )
-        self.classifier_ = nn.Sequential(
+        self.classifier_ = nn.Sequential( 
             nn.Linear(self.observations_, 512),
             nn.LeakyReLU(True),
             nn.Dropout(),
-            nn.Linear(512, 256),
+            nn.Linear(512, 512),
             nn.LeakyReLU(True),
         )
 
-        self.gru_ = nn.GRU(256, 256, 5) 
+        self.lstm_ = nn.LSTM(512, 512, 1) 
         self.policy_ = nn.Sequential(
-            nn.Linear(256, self.num_actions_),
+            nn.Linear(512, self.num_actions_),
             nn.Softmax(),
         )
-        self.value_ = nn.Sequential(nn.Linear(256,1))
+        self.value_ = nn.Sequential(nn.Linear(512,1))
 
 
         if init_weights==True:
@@ -82,16 +86,19 @@ class FlatNet(nn.Module):
  
     def features(self,x):
         x = F.leaky_relu(self.features_1(x),True)
+        x = F.max_pool2d(x,kernel_size=2, stride=2)
         x = F.leaky_relu(self.features_2(x),True)
-        x = F.max_pool3d(x,kernel_size=2, stride=2)
-        x = self.avgpool3(x)
+        x = F.max_pool2d(x,kernel_size=2, stride=2)
+        x = self.avgpool2(x)
         return x
 
-    def forward(self, x):
+    def forward(self, x, x2):
         x = self.features(x)
         x = x.view(x.size(0), -1)
+        x2 = x2.view(x2.size(0),-1)
+        x = torch.cat([x,x2],1)
         x = self.classifier_(x)
-        x,_ = self.gru_(x.view(len(x), 1, -1))
+        x,_ = self.lstm_(x.view(len(x), 1, -1))
         x = x.view(len(x), -1)
         p = self.policy_(x)
         v = self.value_(x)
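
Note: FlatNet.loss() itself is not part of this diff. With loss_value_ = 0.5 and loss_entropy_ = 0.01 it presumably weights the usual three A3C terms; the sketch below is a generic assumption of such a loss, not the file's actual implementation:

import torch

def a3c_loss(policy, target_onehot, value, n_step_return,
             loss_value=0.5, loss_entropy=0.01, eps=1e-14):
    # policy: (N, 5) softmax output, target_onehot: (N, 5) chosen actions,
    # value: (N, 1) critic output, n_step_return: (N,) accumulated returns
    advantage = n_step_return - value.squeeze(-1)
    log_prob = torch.log((policy * target_onehot).sum(dim=1) + eps)
    policy_loss = -(log_prob * advantage.detach()).mean()
    value_loss = advantage.pow(2).mean()
    entropy = -(policy * torch.log(policy + eps)).sum(dim=1).mean()
    return policy_loss + loss_value * value_loss - loss_entropy * entropy
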
-- 
GitLab