import random

import gym
import matplotlib.pyplot as plt
import numpy as np

import ray
import ray.rllib.agents.dqn.dqn as dqn
import ray.rllib.agents.ppo.ppo as ppo
from ray.rllib.agents.dqn.dqn import DQNTrainer
from ray.rllib.agents.dqn.dqn_policy_graph import DQNPolicyGraph
from ray.rllib.agents.ppo.ppo import PPOTrainer
from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.models import ModelCatalog
from ray.tune.logger import pretty_print
from ray.tune.registry import register_env

from flatland.envs import rail_env
from flatland.envs.generators import complex_rail_generator
from flatland.envs.rail_env import random_rail_generator
from flatland.utils.rendertools import RenderTool

from baselines.CustomPreprocessor import CustomPreprocessor
from baselines.RailEnvRLLibWrapper import RailEnvRLLibWrapper
# RailEnv.__bases__ = (RailEnv.__bases__[0], MultiAgentEnv)
ModelCatalog.register_custom_preprocessor("my_prep", CustomPreprocessor)
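# The name "my_prep" is referenced again below via the model option
# "custom_preprocessor". ray.init() starts a local Ray instance for the trainer.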
ray.init()
def train(config):
print('Init Env')
random.seed(1)
np.random.seed(1)
transition_probability = [15, # empty cell - Case 0
5, # Case 1 - straight
5, # Case 2 - simple switch
1, # Case 3 - diamond crossing
1, # Case 4 - single slip
1, # Case 5 - double slip
1, # Case 6 - symmetrical
0, # Case 7 - dead end
1, # Case 1b (8) - simple turn right
1, # Case 1c (9) - simple turn left
1] # Case 2b (10) - simple switch mirrored
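    # Note: transition_probability only parametrises the commented-out
    # random_rail_generator example below; the env_config actually used for
    # training relies on complex_rail_generator instead.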
    # Example: generate a random rail
"""
env = RailEnv(width=10,
height=10,
rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
number_of_agents=1)
"""
    env_config = {"width": 20,
                  "height": 20,
                  "rail_generator": complex_rail_generator(nr_start_goal=5, min_dist=5, max_dist=99999, seed=0),
                  "number_of_agents": 5}
"""
env = RailEnv(width=20,
height=20,
rail_generator=rail_from_list_of_saved_GridTransitionMap_generator(
['../notebooks/temp.npy']),
number_of_agents=3)
"""
# if config['render']:
# env_renderer = RenderTool(env, gl="QT")
# plt.figure(figsize=(5,5))
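    # Optional sanity check: RLlib instantiates the environment as
    # RailEnvRLLibWrapper(env_config), so the same dict can be used to build one
    # by hand (sketch only, assuming the wrapper accepts the plain config dict):
    # env = RailEnvRLLibWrapper(env_config)
    # obs = env.reset()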
obs_space = gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(105,))
act_space = gym.spaces.Discrete(4)
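    # The fixed observation shape (105) is assumed to match the flattened output
    # produced by the registered CustomPreprocessor for the tree observation;
    # the 4 discrete actions correspond to this flatland version's agent actions.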
    # Policies to train: every agent is mapped onto the single shared PPO policy
policy_graphs = {
"ppo_policy": (PPOPolicyGraph, obs_space, act_space, {})
}
    def policy_mapping_fn(agent_id):
        return "ppo_policy"
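    # The DQN imports above are unused in this script; a second policy could be
    # mixed in roughly as follows (sketch only, not exercised here):
    #
    # policy_graphs["dqn_policy"] = (DQNPolicyGraph, obs_space, act_space, {})
    #
    # def policy_mapping_fn(agent_id):
    #     return "ppo_policy" if int(agent_id) % 2 == 0 else "dqn_policy"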
agent_config = ppo.DEFAULT_CONFIG.copy()
agent_config['model'] = {"fcnet_hiddens": [32, 32], "custom_preprocessor": "my_prep"}
agent_config['multiagent'] = {"policy_graphs": policy_graphs,
"policy_mapping_fn": policy_mapping_fn,
"policies_to_train": list(policy_graphs.keys())}
    agent_config["horizon"] = 50  # hard limit on episode length, in steps
    agent_config["num_workers"] = 0  # 0: collect samples in the driver process, no remote rollout workers
    # agent_config["sample_batch_size"] = 1000
    # agent_config["num_cpus_per_worker"] = 40
    # agent_config["num_gpus"] = 2.0
    # agent_config["num_gpus_per_worker"] = 2.0
    # agent_config["num_cpus_for_driver"] = 5
    # agent_config["num_envs_per_worker"] = 15
    agent_config["env_config"] = env_config
    # agent_config["batch_mode"] = "complete_episodes"
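    # Any other PPO hyperparameter from ppo.DEFAULT_CONFIG can be overridden the
    # same way, e.g. (sketch, values are illustrative only):
    # agent_config["lr"] = 1e-4
    # agent_config["train_batch_size"] = 4000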
ppo_trainer = PPOTrainer(env=RailEnvRLLibWrapper, config=agent_config)
for i in range(100000 + 2):
print("== Iteration", i, "==")
print("-- PPO --")
print(pretty_print(ppo_trainer.train()))
# if i % config['save_every'] == 0:
# checkpoint = ppo_trainer.save()
# print("checkpoint saved at", checkpoint)
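        # A checkpoint written this way could later be reloaded with
        # ppo_trainer.restore(checkpoint) before resuming training or evaluation.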
train({})