diff --git a/train.py b/train.py
index 320e4947ada6909cd29447e04df7667644063c54..2f49720f4fe943554cc843f48a1d2f8696b82b49 100644
--- a/train.py
+++ b/train.py
@@ -12,7 +12,7 @@ from flatland.envs.generators import complex_rail_generator
 import ray.rllib.agents.ppo.ppo as ppo
 import ray.rllib.agents.dqn.dqn as dqn
 from ray.rllib.agents.ppo.ppo import PPOTrainer
-from ray.rllib.agents.dqn.dqn import DQNAgent
+from ray.rllib.agents.dqn.dqn import DQNTrainer
 from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph
 from ray.rllib.agents.dqn.dqn_policy_graph import DQNPolicyGraph
 
@@ -101,13 +101,13 @@ def train(config):
     #agent_config["num_workers"] = 0
     #agent_config["num_cpus_per_worker"] = 40
     #agent_config["num_gpus"] = 2.0
-    # agent_config["num_gpus_per_worker"] = 2.0
-    agent_config["num_cpus_for_driver"] = 5
-    agent_config["num_envs_per_worker"] = 15
+    #agent_config["num_gpus_per_worker"] = 2.0
+    #agent_config["num_cpus_for_driver"] = 5
+    #agent_config["num_envs_per_worker"] = 15
     agent_config["env_config"] = env_config
-    #agent_config["batch_mode"] = "complete_episodes"
+    agent_config["batch_mode"] = "complete_episodes"
 
-    ppo_trainer = PPOTrainer(env=f"railenv", config=agent_config)
+    ppo_trainer = PPOTrainer(env=RailEnvRLLibWrapper, config=agent_config)
 
     for i in range(100000 + 2):
         print("== Iteration", i, "==")