diff --git a/baselines/train.py b/baselines/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..5879052ffa9876a8ed11224e179d329bed2cd979
--- /dev/null
+++ b/baselines/train.py
@@ -0,0 +1,110 @@
+from flatland.envs import rail_env
+from flatland.envs.rail_env import random_rail_generator
+from baselines.RailEnvRLLibWrapper import RailEnvRLLibWrapper
+from flatland.utils.rendertools import RenderTool
+import random
+import gym
+
+import matplotlib.pyplot as plt
+
+from flatland.envs.generators import complex_rail_generator
+
+import ray.rllib.agents.ppo.ppo as ppo
+import ray.rllib.agents.dqn.dqn as dqn
+from ray.rllib.agents.ppo.ppo import PPOAgent
+from ray.rllib.agents.dqn.dqn import DQNAgent
+from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph
+from ray.rllib.agents.dqn.dqn_policy_graph import DQNPolicyGraph
+
+from ray.tune.registry import register_env
+from ray.rllib.models import ModelCatalog
+from ray.tune.logger import pretty_print
+from ray.rllib.models.preprocessors import Preprocessor
+
+import ray
+import numpy as np
+
+from ray.rllib.env.multi_agent_env import MultiAgentEnv
+
+# RailEnv.__bases__ = (RailEnv.__bases__[0], MultiAgentEnv)
+
+
+class MyPreprocessorClass(Preprocessor):
+    # Custom preprocessor: declares a fixed (105,) output shape and passes the observation through unchanged
+    def _init_shape(self, obs_space, options):
+        return (105,)
+
+    def transform(self, observation):
+        return observation  # return the preprocessed observation
+
+
+ModelCatalog.register_custom_preprocessor("my_prep", MyPreprocessorClass)
+ray.init()
+
+
+def train(config):
+    print('Init Env')
+    random.seed(1)
+    np.random.seed(1)
+
+    # Example transition probabilities for generating a random rail (one entry per cell type);
+    # unused below, where complex_rail_generator is used instead.
+    transition_probability = [0.5,  # empty cell - Case 0
+                              1.0,  # Case 1 - straight
+                              1.0,  # Case 2 - simple switch
+                              0.3,  # Case 3 - diamond crossing
+                              0.5,  # Case 4 - single slip
+                              0.5,  # Case 5 - double slip
+                              0.2,  # Case 6 - symmetrical
+                              0.0]  # Case 7 - dead end
+
+    # Generate a random rail and wrap it for RLlib
+    env = RailEnvRLLibWrapper(width=15, height=15,
+                              rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=20, min_dist=12),
+                              number_of_agents=10)
+
+    register_env("railenv", lambda _: env)
+    # if config['render']:
+    #     env_renderer = RenderTool(env, gl="QT")
+    #     plt.figure(figsize=(5, 5))
+
+    obs_space = gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(105,))
+    act_space = gym.spaces.Discrete(4)
+
+    # Dict with the different policies to train
+    policy_graphs = {
+        "ppo_policy": (PPOPolicyGraph, obs_space, act_space, {})
+    }
+
+    def policy_mapping_fn(agent_id):
+        # All agents share the single PPO policy
+        return "ppo_policy"
+
+    agent_config = ppo.DEFAULT_CONFIG.copy()
+    agent_config['model'] = {"fcnet_hiddens": [32, 32], "custom_preprocessor": "my_prep"}
+    agent_config['multiagent'] = {"policy_graphs": policy_graphs,
+                                  "policy_mapping_fn": policy_mapping_fn,
+                                  "policies_to_train": list(policy_graphs.keys())}
+    agent_config["horizon"] = 50
+    agent_config["num_workers"] = 0
+    agent_config["num_cpus_per_worker"] = 40
+    agent_config["num_gpus"] = 2.0
+    agent_config["num_gpus_per_worker"] = 2.0
+    agent_config["num_cpus_for_driver"] = 5
+    agent_config["num_envs_per_worker"] = 20
+    # agent_config["batch_mode"] = "complete_episodes"
+
+    ppo_trainer = PPOAgent(env="railenv", config=agent_config)
+
+    for i in range(100000 + 2):
+        print("== Iteration", i, "==")
+
+        print("-- PPO --")
+        print(pretty_print(ppo_trainer.train()))
+
+        # if i % config['save_every'] == 0:
+        #     checkpoint = ppo_trainer.save()
+        #     print("checkpoint saved at", checkpoint)
+
+
+train({})
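
For context, a rough evaluation sketch of how a checkpoint saved by this script could be used. It is not part of the patch: it reuses `env` and `agent_config` from inside `train()`, the checkpoint path and the name `eval_trainer` are placeholders, and it assumes the wrapper follows RLlib's MultiAgentEnv convention (dict observations and dones keyed by agent id) and that this RLlib version supports `compute_action(..., policy_id=...)`.

    # Hypothetical evaluation loop -- names and checkpoint path are illustrative only
    eval_trainer = PPOAgent(env="railenv", config=agent_config)
    eval_trainer.restore("/tmp/ray/checkpoint_100/checkpoint-100")  # placeholder path

    obs = env.reset()
    done = {"__all__": False}
    while not done["__all__"]:
        # One action per agent, all produced by the shared "ppo_policy"
        actions = {agent_id: eval_trainer.compute_action(agent_obs, policy_id="ppo_policy")
                   for agent_id, agent_obs in obs.items()}
        obs, rewards, done, infos = env.step(actions)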