diff --git a/torch_training/bla.py b/torch_training/bla.py index cd1aa13a824a58815778f0779e64739093a5a095..698b4fe9e67a3ee0de1759bc1409f379a11b20a8 100644 --- a/torch_training/bla.py +++ b/torch_training/bla.py @@ -82,6 +82,12 @@ def main(argv): agent_obs = [None] * env.get_num_agents() agent_next_obs = [None] * env.get_num_agents() agent = Agent(state_size, action_size, "FC", 0) + with path(torch_training.Nets, "avoid_checkpoint30000.pth") as file_in: + agent.qnetwork_local.load_state_dict(torch.load(file_in)) + + demo = False + record_images = False + frame_step = 0 print("multi_agent_trainging.py (2)")