diff --git a/a3c/linker.py b/a3c/linker.py
index 315a30ecb2a750a8fb27ea0fa85722cccd338f18..9c3a543aa05e5bc394b8b336dc8c9e31cf89e374 100644
--- a/a3c/linker.py
+++ b/a3c/linker.py
@@ -14,7 +14,8 @@ class FlatLink():
     Training: Generate training data, accumulate n-step reward and train model.
     Play: Execute trained model.
     """
-    def __init__(self,size_x=20,size_y=10, gpu_ = 0):
+    def __init__(self,size_x=20,size_y=10, gpu_ = -1):
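+        # gpu_ = -1 means CPU; main.py passes CUDA device 0 when torch.cuda.is_available()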
         self.n_step_return_ = 100
         self.gamma_ = 0.99
         self.batch_size = 4000
@@ -175,7 +176,7 @@ class FlatLink():
 
 
 class FlatConvert(FlatLink):
-    def __init__(self,size_x,size_y, gpu_ = 0):
+    def __init__(self,size_x,size_y, gpu_ = -1):
         super(FlatConvert, self).__init__(size_x,size_y,gpu_)
 
 
@@ -305,14 +306,16 @@ class FlatConvert(FlatLink):
             while not done and iter_< self.max_iter:
                 iter_+=1
                 for agent_id, state in curr_state_dict.items():
-                    policy, value = self.predict(state.astype(np.float).reshape((1,state.shape[0],state.shape[1],state.shape[2],state.shape[3])))
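+                    # predict() now takes the local observation directly; the float cast and manual reshape are gone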
+                    policy, value = self.predict(state)
                     policy_dict[agent_id] = policy
                     value_dict[agent_id] = value
                     action_dict[agent_id], pre = self.get_action(policy)
                     pre_ += pre
                     act_+=1 
                 if env_renderer is not None:    
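+                    # renderEnv was renamed to render_env; skip drawing agent observations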
-                    env_renderer.renderEnv(show=True)
+                    env_renderer.render_env(show=True, show_observations=False)
                 next_state_dict, reward_dict, done_dict, _ = env.step(action_dict)
                 
                 next_state_dict = self.update_state_dict(curr_state_dict, next_state_dict)
@@ -378,6 +381,6 @@ class FlatConvert(FlatLink):
 
             if env_renderer is not None:
                 # does not work on current master
-                env_renderer.renderEnv(show=True,show_observations=False)
+                env_renderer.render_env(show=True, show_observations=False)
 
         return  float(pre_)/float(act_), global_reward, np.array(iter_counts).astype(np.float)/float(self.max_iter), 1 if done else 0
\ No newline at end of file
diff --git a/a3c/main.py b/a3c/main.py
index f94ecd566535e02eb71a87de81828829eba14bbd..a0b832e091643cf35cb05196f79be7b969c01c17 100644
--- a/a3c/main.py
+++ b/a3c/main.py
@@ -15,6 +15,7 @@ from flatland.core.transition_map import GridTransitionMap
 from flatland.utils.rendertools import RenderTool
 import time
 import pickle
+import torch.cuda
 #set parameters here
 size_x=20
 size_y=20
@@ -27,10 +28,14 @@ if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('action', choices=['generate', 'test','play','manual'])
     args = parser.parse_args()
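+    # use CUDA device 0 when available, otherwise fall back to CPU (-1)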
+    have_cuda = torch.cuda.is_available()
+    gpu = 0 if have_cuda else -1
+
     rt = None
     if args.action == 'generate':
 
-        link_ = FlatConvert(fov_size_x,fov_size_y)
+        link_ = FlatConvert(fov_size_x,fov_size_y, gpu_=gpu)
         # some parameters, can be set in linker.py directly
         # this one here are for testing (fast on weak machine)
         link_.play_epochs = 1
@@ -39,20 +44,20 @@ if __name__ == '__main__':
         link_.replay_buffer_size = 100
         env = RailEnv(size_x,size_y,rail_generator=complex_rail_generator(num_agents,nr_extra=100, min_dist=20,seed=12345),number_of_agents=num_agents,obs_builder_object=LocalObsForRailEnv(radius))
         # instantiate renderer here
-        rt = None #RenderTool(env, gl="QT")
+        rt = RenderTool(env, gl="PILSVG")
         link_.perform_training(env,rt)
 
     elif args.action == 'play':
         
-        link_ = FlatConvert(fov_size_x,fov_size_y)
+        link_ = FlatConvert(fov_size_x,fov_size_y, gpu_=gpu)
   
         env = RailEnv(size_x,size_y,rail_generator=complex_rail_generator(num_agents,nr_extra=100, min_dist=20,seed=12345),number_of_agents=num_agents,obs_builder_object=LocalObsForRailEnv(radius))
-        rt = None# RenderTool(env,gl="QT")
+        rt = RenderTool(env, gl="PILSVG")
         link_.play(env,rt,"models/model_a3c_i.pt")
 
     elif args.action == 'manual':
 
-        link_ = FlatLink()
+        link_ = FlatLink(gpu_=gpu)
         link_.load_()
         #from scipy import stats
         #for j in range(10):