Skip to content
Snippets Groups Projects
Commit 1550c6ff authored by gmollard's avatar gmollard
Browse files

grid search up to date

parent 7573dc58
No related branches found
No related tags found
No related merge requests found
from baselines.RailEnvRLLibWrapper import RailEnvRLLibWrapper
import random
import gym
from flatland.envs.generators import complex_rail_generator
import ray.rllib.agents.ppo.ppo as ppo
from ray.rllib.agents.ppo.ppo import PPOAgent
from ray.rllib.agents.ppo.ppo import PPOTrainer
from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph
from ray.tune.registry import register_env
from ray.rllib.models import ModelCatalog
from ray.tune.logger import pretty_print
from ray.rllib.models.preprocessors import Preprocessor
......@@ -38,8 +36,6 @@ ray.init()
def train(config, reporter):
print('Init Env')
env_name = f"rail_env_{config['n_agents']}" # To modify if different environments configs are explored.
transition_probability = [15, # empty cell - Case 0
5, # Case 1 - straight
5, # Case 2 - simple switch
......@@ -59,10 +55,10 @@ def train(config, reporter):
rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
number_of_agents=1)
"""
env = RailEnv(width=config['map_width'],
height=config['map_height'],
rail_generator=complex_rail_generator(nr_start_goal=50, min_dist=5, max_dist=99999, seed=0),
number_of_agents=config['n_agents'])
env_config = {"width":config['map_width'],
"height":config['map_height'],
"rail_generator":complex_rail_generator(nr_start_goal=50, min_dist=5, max_dist=99999, seed=0),
"number_of_agents":config['n_agents']}
"""
env = RailEnv(width=20,
height=20,
......@@ -79,8 +75,6 @@ def train(config, reporter):
# rail_generator=complex_rail_generator(nr_start_goal=config["n_agents"], nr_extra=20, min_dist=12),
# number_of_agents=config["n_agents"])
register_env(env_name, lambda _: env)
obs_space = gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(105,))
act_space = gym.spaces.Discrete(4)
......@@ -99,15 +93,16 @@ def train(config, reporter):
"policies_to_train": list(policy_graphs.keys())}
agent_config["horizon"] = config['horizon']
agent_config["num_workers"] = 0
agent_config["num_cpus_per_worker"] = 10
agent_config["num_gpus"] = 0.5
agent_config["num_gpus_per_worker"] = 0.5
agent_config["num_cpus_for_driver"] = 1
agent_config["num_envs_per_worker"] = 10
# agent_config["num_workers"] = 0
# agent_config["num_cpus_per_worker"] = 10
# agent_config["num_gpus"] = 0.5
# agent_config["num_gpus_per_worker"] = 0.5
# agent_config["num_cpus_for_driver"] = 1
# agent_config["num_envs_per_worker"] = 10
agent_config["env_config"] = env_config
agent_config["batch_mode"] = "complete_episodes"
ppo_trainer = PPOAgent(env=env_name, config=agent_config)
ppo_trainer = PPOTrainer(env=RailEnvRLLibWrapper, config=agent_config)
for i in range(100000 + 2):
print("== Iteration", i, "==")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment