diff --git a/grid_search_train.py b/grid_search_train.py
index 0689ec03c86b178fda92974b846ffd7261587084..0f04d1c512b96c07988b0bb324e57aab6a90bef4 100644
--- a/grid_search_train.py
+++ b/grid_search_train.py
@@ -1,15 +1,13 @@
 from baselines.RailEnvRLLibWrapper import RailEnvRLLibWrapper
-import random
 import gym
 
 
 from flatland.envs.generators import complex_rail_generator
 
 import ray.rllib.agents.ppo.ppo as ppo
-from ray.rllib.agents.ppo.ppo import PPOAgent
+from ray.rllib.agents.ppo.ppo import PPOTrainer
 from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph
 
-from ray.tune.registry import register_env
 from ray.rllib.models import ModelCatalog
 from ray.tune.logger import pretty_print
 from ray.rllib.models.preprocessors import Preprocessor
@@ -38,8 +36,6 @@ ray.init()
 def train(config, reporter):
     print('Init Env')
 
-    env_name = f"rail_env_{config['n_agents']}"  # To modify if different environments configs are explored.
-
     transition_probability = [15,  # empty cell - Case 0
                               5,  # Case 1 - straight
                               5,  # Case 2 - simple switch
@@ -59,10 +55,10 @@ def train(config, reporter):
                   rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
                   number_of_agents=1)
     """
-    env = RailEnv(width=config['map_width'],
-                  height=config['map_height'],
-                  rail_generator=complex_rail_generator(nr_start_goal=50, min_dist=5, max_dist=99999, seed=0),
-                  number_of_agents=config['n_agents'])
+    env_config = {"width":config['map_width'],
+                  "height":config['map_height'],
+                  "rail_generator":complex_rail_generator(nr_start_goal=50, min_dist=5, max_dist=99999, seed=0),
+                  "number_of_agents":config['n_agents']}
     """
     env = RailEnv(width=20,
                   height=20,
@@ -79,8 +75,6 @@ def train(config, reporter):
     #               rail_generator=complex_rail_generator(nr_start_goal=config["n_agents"], nr_extra=20, min_dist=12),
     #               number_of_agents=config["n_agents"])
 
-    register_env(env_name, lambda _: env)
-
     obs_space = gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(105,))
     act_space = gym.spaces.Discrete(4)
 
@@ -99,15 +93,16 @@ def train(config, reporter):
                                   "policies_to_train": list(policy_graphs.keys())}
     agent_config["horizon"] = config['horizon']
 
-    agent_config["num_workers"] = 0
-    agent_config["num_cpus_per_worker"] = 10
-    agent_config["num_gpus"] = 0.5
-    agent_config["num_gpus_per_worker"] = 0.5
-    agent_config["num_cpus_for_driver"] = 1
-    agent_config["num_envs_per_worker"] = 10
+    # agent_config["num_workers"] = 0
+    # agent_config["num_cpus_per_worker"] = 10
+    # agent_config["num_gpus"] = 0.5
+    # agent_config["num_gpus_per_worker"] = 0.5
+    # agent_config["num_cpus_for_driver"] = 1
+    # agent_config["num_envs_per_worker"] = 10
+    agent_config["env_config"] = env_config
     agent_config["batch_mode"] = "complete_episodes"
 
-    ppo_trainer = PPOAgent(env=env_name, config=agent_config)
+    ppo_trainer = PPOTrainer(env=RailEnvRLLibWrapper, config=agent_config)
 
     for i in range(100000 + 2):
         print("== Iteration", i, "==")