diff --git a/train.py b/train.py index 59db8d025c9942a8a2cf0c80058cb63aee89a369..4e887c8137f1afaba8673dcc15909e0560d6db07 100644 --- a/train.py +++ b/train.py @@ -86,7 +86,7 @@ def train(config): "policies_to_train": list(policy_graphs.keys())} agent_config["horizon"] = 50 - ppo_trainer = PPOAgent(env=f"railenv_", config=agent_config) + ppo_trainer = PPOAgent(env=f"railenv", config=agent_config) for i in range(100000 + 2): print("== Iteration", i, "==")