diff --git a/RailEnvRLLibWrapper.py b/RailEnvRLLibWrapper.py
index 1b537a6ae324d1ba34c62894c8407ce4f4b3ba70..fef562deaf36be9b4a637f8ba3bad6f90f002ceb 100644
--- a/RailEnvRLLibWrapper.py
+++ b/RailEnvRLLibWrapper.py
@@ -6,15 +6,15 @@ from flatland.envs.generators import random_rail_generator
 
 class RailEnvRLLibWrapper(RailEnv, MultiAgentEnv):
 
-    def __init__(self,
-                 width,
-                 height,
-                 rail_generator=random_rail_generator(),
-                 number_of_agents=1,
-                 obs_builder_object=TreeObsForRailEnv(max_depth=2)):
+    def __init__(self, config):
 
-        super(RailEnvRLLibWrapper, self).__init__(width=width, height=height, rail_generator=rail_generator,
-                number_of_agents=number_of_agents, obs_builder_object=obs_builder_object)
+        # All construction parameters arrive via a single RLlib-style config dict.
+        super(RailEnvRLLibWrapper, self).__init__(
+            width=config["width"],
+            height=config["height"],
+            rail_generator=config["rail_generator"],
+            number_of_agents=config["number_of_agents"],
+            obs_builder_object=config.get("obs_builder_object", TreeObsForRailEnv(max_depth=2)))
 
     def reset(self, regen_rail=True, replace_agents=True):
         self.agents_done = []
diff --git a/train.py b/train.py
index 686ae5d0ea8c95e0b55dc663f6adcc90b99b72fe..320e4947ada6909cd29447e04df7667644063c54 100644
--- a/train.py
+++ b/train.py
@@ -11,7 +11,7 @@ from flatland.envs.generators import complex_rail_generator
 
 import ray.rllib.agents.ppo.ppo as ppo
 import ray.rllib.agents.dqn.dqn as dqn
-from ray.rllib.agents.ppo.ppo import PPOAgent
+from ray.rllib.agents.ppo.ppo import PPOTrainer
-from ray.rllib.agents.dqn.dqn import DQNAgent
+from ray.rllib.agents.dqn.dqn import DQNTrainer
 from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph
 from ray.rllib.agents.dqn.dqn_policy_graph import DQNPolicyGraph
@@ -64,10 +64,10 @@ def train(config):
                   rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
                   number_of_agents=1)
     """
-    env = RailEnvRLLibWrapper(width=20,
-                  height=20,
-                  rail_generator=complex_rail_generator(nr_start_goal=5, min_dist=5, max_dist=99999, seed=0),
-                  number_of_agents=5)
+    env_config = {"width": 20,
+                  "height":20,
+                  "rail_generator":complex_rail_generator(nr_start_goal=5, min_dist=5, max_dist=99999, seed=0),
+                  "number_of_agents":5}
     """
     env = RailEnv(width=20,
                   height=20,
@@ -77,7 +77,7 @@ def train(config):
 
     """
 
-    register_env("railenv", lambda _: env)
+    register_env("railenv", lambda cfg: RailEnvRLLibWrapper(cfg))
     # if config['render']:
     #     env_renderer = RenderTool(env, gl="QT")
     # plt.figure(figsize=(5,5))
@@ -105,9 +105,10 @@ def train(config):
    # agent_config["num_gpus_per_worker"] = 2.0
     agent_config["num_cpus_for_driver"] = 5
     agent_config["num_envs_per_worker"] = 15
+    agent_config["env_config"] = env_config
     #agent_config["batch_mode"] = "complete_episodes"
 
-    ppo_trainer = PPOAgent(env=f"railenv", config=agent_config)
+    ppo_trainer = PPOTrainer(env="railenv", config=agent_config)
 
     for i in range(100000 + 2):
         print("== Iteration", i, "==")