diff --git a/grid_search_train.py b/grid_search_train.py
index 0689ec03c86b178fda92974b846ffd7261587084..0f04d1c512b96c07988b0bb324e57aab6a90bef4 100644
--- a/grid_search_train.py
+++ b/grid_search_train.py
@@ -1,15 +1,13 @@
 from baselines.RailEnvRLLibWrapper import RailEnvRLLibWrapper
-import random
 import gym
 
 from flatland.envs.generators import complex_rail_generator
 
 import ray.rllib.agents.ppo.ppo as ppo
-from ray.rllib.agents.ppo.ppo import PPOAgent
+from ray.rllib.agents.ppo.ppo import PPOTrainer
 from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph
-from ray.tune.registry import register_env
 from ray.rllib.models import ModelCatalog
 from ray.tune.logger import pretty_print
 from ray.rllib.models.preprocessors import Preprocessor
@@ -38,8 +36,6 @@ ray.init()
 def train(config, reporter):
     print('Init Env')
 
-    env_name = f"rail_env_{config['n_agents']}"  # To modify if different environments configs are explored.
-
     transition_probability = [15,  # empty cell - Case 0
                               5,  # Case 1 - straight
                               5,  # Case 2 - simple switch
@@ -59,10 +55,10 @@ def train(config, reporter):
                   rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
                   number_of_agents=1)
     """
-    env = RailEnv(width=config['map_width'],
-                  height=config['map_height'],
-                  rail_generator=complex_rail_generator(nr_start_goal=50, min_dist=5, max_dist=99999, seed=0),
-                  number_of_agents=config['n_agents'])
+    env_config = {"width": config['map_width'],
+                  "height": config['map_height'],
+                  "rail_generator": complex_rail_generator(nr_start_goal=50, min_dist=5, max_dist=99999, seed=0),
+                  "number_of_agents": config['n_agents']}
     """
     env = RailEnv(width=20,
                   height=20,
@@ -79,8 +75,6 @@ def train(config, reporter):
     # rail_generator=complex_rail_generator(nr_start_goal=config["n_agents"], nr_extra=20, min_dist=12),
     # number_of_agents=config["n_agents"])
 
-    register_env(env_name, lambda _: env)
-
     obs_space = gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(105,))
     act_space = gym.spaces.Discrete(4)
 
@@ -99,15 +93,16 @@ def train(config, reporter):
               "policies_to_train": list(policy_graphs.keys())}
 
     agent_config["horizon"] = config['horizon']
-    agent_config["num_workers"] = 0
-    agent_config["num_cpus_per_worker"] = 10
-    agent_config["num_gpus"] = 0.5
-    agent_config["num_gpus_per_worker"] = 0.5
-    agent_config["num_cpus_for_driver"] = 1
-    agent_config["num_envs_per_worker"] = 10
+    # agent_config["num_workers"] = 0
+    # agent_config["num_cpus_per_worker"] = 10
+    # agent_config["num_gpus"] = 0.5
+    # agent_config["num_gpus_per_worker"] = 0.5
+    # agent_config["num_cpus_for_driver"] = 1
+    # agent_config["num_envs_per_worker"] = 10
+    agent_config["env_config"] = env_config
     agent_config["batch_mode"] = "complete_episodes"
 
-    ppo_trainer = PPOAgent(env=env_name, config=agent_config)
+    ppo_trainer = PPOTrainer(env=RailEnvRLLibWrapper, config=agent_config)
 
     for i in range(100000 + 2):
         print("== Iteration", i, "==")
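
Note on the env registration change (not part of the patch itself): with register_env removed, the trainer receives the environment class directly, and RLlib forwards the dict stored under agent_config["env_config"] to that class's constructor on each rollout worker. Below is a minimal sketch of that flow, assuming RailEnvRLLibWrapper (defined in baselines/RailEnvRLLibWrapper.py) builds the underlying RailEnv from the config dict; the constructor and class body shown here are illustrative assumptions, not the actual wrapper code.

    # Illustrative sketch only -- the real wrapper lives in baselines/RailEnvRLLibWrapper.py;
    # its constructor signature is assumed here, not copied from that file.
    from flatland.envs.generators import complex_rail_generator
    from flatland.envs.rail_env import RailEnv
    from ray.rllib.env.multi_agent_env import MultiAgentEnv


    class RailEnvWrapperSketch(MultiAgentEnv):
        def __init__(self, env_config):
            # RLlib calls env_class(env_config) on every worker, so these keys
            # match the env_config dict assembled in train().
            self.env = RailEnv(width=env_config["width"],
                               height=env_config["height"],
                               rail_generator=env_config["rail_generator"],
                               number_of_agents=env_config["number_of_agents"])


    # Usage mirroring the updated script: the dict travels inside agent_config.
    env_config = {"width": 20,
                  "height": 20,
                  "rail_generator": complex_rail_generator(nr_start_goal=50, min_dist=5, max_dist=99999, seed=0),
                  "number_of_agents": 1}
    # agent_config["env_config"] = env_config
    # ppo_trainer = PPOTrainer(env=RailEnvWrapperSketch, config=agent_config)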