diff --git a/a3c/linker.py b/a3c/linker.py
index 315a30ecb2a750a8fb27ea0fa85722cccd338f18..9c3a543aa05e5bc394b8b336dc8c9e31cf89e374 100644
--- a/a3c/linker.py
+++ b/a3c/linker.py
@@ -14,7 +14,7 @@ class FlatLink():
     Training: Generate training data, accumulate n-step reward and train model.
     Play: Execute trained model.
     """
-    def __init__(self,size_x=20,size_y=10, gpu_ = 0):
+    def __init__(self,size_x=20,size_y=10, gpu_ = -1):
         self.n_step_return_ = 100
         self.gamma_ = 0.99
         self.batch_size = 4000
@@ -175,7 +175,7 @@ class FlatLink():
 
 
 class FlatConvert(FlatLink):
-    def __init__(self,size_x,size_y, gpu_ = 0):
+    def __init__(self,size_x,size_y, gpu_ = -1):
         super(FlatConvert, self).__init__(size_x,size_y,gpu_)
 
 
@@ -305,14 +305,21 @@ class FlatConvert(FlatLink):
         while not done and iter_< self.max_iter:
             iter_+=1
             for agent_id, state in curr_state_dict.items():
-                policy, value = self.predict(state.astype(np.float).reshape((1,state.shape[0],state.shape[1],state.shape[2],state.shape[3])))
+                # print ("agent: ", agent_id, " state: ", type(state), len(state))
+                # for iState, elState in enumerate(state):
+                #     print("state element:", iState, type(elState), "shape", elState.shape)
+                # arrState = np.array(state)
+                # print(arrState.size, arrState.shape)
+                #policy, value = self.predict(state.astype(np.float).reshape((1,state.shape[0],state.shape[1],state.shape[2],state.shape[3])))
+                policy, value = self.predict(state)
+
                 policy_dict[agent_id] = policy
                 value_dict[agent_id] = value
                 action_dict[agent_id], pre = self.get_action(policy)
                 pre_ += pre
                 act_+=1
 
             if env_renderer is not None:
-                env_renderer.renderEnv(show=True)
+                env_renderer.render_env(show=True, show_observations=False)
             next_state_dict, reward_dict, done_dict, _ = env.step(action_dict)
             next_state_dict = self.update_state_dict(curr_state_dict, next_state_dict)
@@ -378,6 +385,6 @@ class FlatConvert(FlatLink):
 
         if env_renderer is not None:
             #does not work for actual masterè
-            env_renderer.renderEnv(show=True,show_observations=False)
+            env_renderer.render_env(show=True,show_observations=False)
 
         return float(pre_)/float(act_), global_reward, np.array(iter_counts).astype(np.float)/float(self.max_iter), 1 if done else 0
\ No newline at end of file
diff --git a/a3c/main.py b/a3c/main.py
index f94ecd566535e02eb71a87de81828829eba14bbd..a0b832e091643cf35cb05196f79be7b969c01c17 100644
--- a/a3c/main.py
+++ b/a3c/main.py
@@ -15,6 +15,7 @@ from flatland.core.transition_map import GridTransitionMap
 from flatland.utils.rendertools import RenderTool
 import time
 import pickle
+import torch.cuda
 #set parameters here
 size_x=20
 size_y=20
@@ -27,10 +28,13 @@ if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('action', choices=['generate', 'test','play','manual'])
     args = parser.parse_args()
+    have_cuda = torch.cuda.is_available()
+    gpu = 0 if have_cuda else -1
+    rt = None
 
 
     if args.action == 'generate':
-        link_ = FlatConvert(fov_size_x,fov_size_y)
+        link_ = FlatConvert(fov_size_x,fov_size_y, gpu_=gpu)
         # some parameters, can be set in linker.py directly
         # this one here are for testing (fast on weak machine)
         link_.play_epochs = 1
@@ -39,20 +43,20 @@ if __name__ == '__main__':
         link_.replay_buffer_size = 100
 
         env = RailEnv(size_x,size_y,rail_generator=complex_rail_generator(num_agents,nr_extra=100, min_dist=20,seed=12345),number_of_agents=num_agents,obs_builder_object=LocalObsForRailEnv(radius))
         # instantiate renderer here
-        rt = None #RenderTool(env, gl="QT")
+        rt = RenderTool(env, gl="PILSVG")
 
         link_.perform_training(env,rt)
 
     elif args.action == 'play':
-        link_ = FlatConvert(fov_size_x,fov_size_y)
+        link_ = FlatConvert(fov_size_x,fov_size_y, gpu_=gpu)
         env = RailEnv(size_x,size_y,rail_generator=complex_rail_generator(num_agents,nr_extra=100, min_dist=20,seed=12345),number_of_agents=num_agents,obs_builder_object=LocalObsForRailEnv(radius))
-        rt = None# RenderTool(env,gl="QT")
+        rt = RenderTool(env,gl="PILSVG")
         link_.play(env,rt,"models/model_a3c_i.pt")
 
     elif args.action == 'manual':
-        link_ = FlatLink()
+        link_ = FlatLink(gpu_=gpu)
         link_.load_()
 
         #from scipy import stats
         #for j in range(10):
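
For reference, the CPU/GPU fallback this patch wires through a3c/main.py and the FlatLink/FlatConvert constructors can be exercised in isolation. A minimal sketch, assuming only that PyTorch is installed; the print line is illustrative and not part of the patch:

import torch.cuda

# Mirrors the logic added to a3c/main.py: use GPU index 0 when CUDA is
# available, otherwise fall back to -1, which the patched FlatLink/FlatConvert
# constructors now treat as their default (gpu_ = -1, i.e. CPU-only).
have_cuda = torch.cuda.is_available()
gpu = 0 if have_cuda else -1
print("CUDA available:", have_cuda, "-> passing gpu_ =", gpu)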