diff --git a/grid_search_train.py b/grid_search_train.py index 16b0a421bf119ff75bd3afa6c1a273b0e169ca24..919c012d4cfb5fd81b02ce704a2aa8ecca1c74b5 100644 --- a/grid_search_train.py +++ b/grid_search_train.py @@ -78,6 +78,14 @@ def train(config, reporter): "policies_to_train": list(policy_graphs.keys())} agent_config["horizon"] = config['horizon'] + agent_config["num_workers"] = 0 + agent_config["num_cpus_per_worker"] = 10 + agent_config["num_gpus"] = 0.5 + agent_config["num_gpus_per_worker"] = 0.5 + agent_config["num_cpus_for_driver"] = 1 + agent_config["num_envs_per_worker"] = 10 + agent_config["batch_mode"] = "complete_episodes" + ppo_trainer = PPOAgent(env=env_name, config=agent_config) for i in range(100000 + 2):