Commit c6f4a5a5 authored by nilabha

Add mixed trainer and pure IL trainer changes

parent 53012b7c
Pipeline #4961 passed with stage in 23 minutes and 7 seconds
@@ -151,8 +151,9 @@ class ImitationAgent(PPOTrainer):
                             # {"obs": tf.cast(tf.expand_dims(obs[a],0),tf.float32)})
                             # self.model.custom_loss(expert_action,{"obs": np.expand_dims(obs[a],0)},)
-                            logits, _ = policy.model.forward({"obs": np.expand_dims(obs[a],0)}, [], None)
-                            np.vstack([obs[a],obs[a]])
+                            input_dict = {"obs": np.expand_dims(obs[a],0)}
+                            input_dict['obs_flat'] = input_dict['obs']
+                            logits, _ = policy.model.forward(input_dict, [], None)
                             model_logits = tf.squeeze(logits)
                             expert_logits = tf.cast(expert_action, tf.int32)
                             # expert_one_hot = tf.one_hot(expert_logits,num_outputs)
import numpy as np
from ray.rllib.agents.trainer import Trainer, with_common_config
from ray.rllib.agents.dqn import ApexTrainer,DQNTrainer
from ray.rllib.utils.annotations import override
from ray.rllib.agents.ppo.ppo import PPOTrainer
from ray.tune.logger import pretty_print
import os
import ray
import yaml
from pathlib import Path
from ray.cluster_utils import Cluster
from ray.rllib.evaluation import MultiAgentEpisode
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.tune import run_experiments
from ray.tune.logger import TBXLogger
from ray.tune.resources import resources_to_json
from ray.tune.tune import _make_scheduler
from ray.rllib.models.tf.tf_action_dist import Categorical
tf = try_import_tf()
from ray.tune import registry
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.optimizers import SyncSamplesOptimizer
from ray.rllib.models import ModelCatalog
from utils.argparser import create_parser
from utils.loader import load_envs, load_models
# Custom wandb logger with hotfix to allow custom callbacks
from wandblogger import WandbLogger
"""
Note : This implementation has been adapted from various files in :
https://github.com/ray-project/ray/blob/master/rllib/examples
"""
from ray.rllib.policy import Policy,TFPolicy
from ray.rllib.policy.dynamic_tf_policy import DynamicTFPolicy
from ray.rllib.agents.dqn.dqn_tf_policy import DQNTFPolicy
from ray.rllib.policy.tf_policy_template import build_tf_policy
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches
from ray.rllib.execution.train_ops import TrainOneStep
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from flatland.envs.agent_utils import RailAgentStatus
import sys
sys.path.insert(0, os.getcwd() + '/envs/expert')
from libs.cell_graph_dispatcher import CellGraphDispatcher
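# The CellGraphDispatcher acts as the expert policy: the per-step action_dict it
# returns supplies the target actions for the imitation loss in
# ImitationAgent._train below.
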
def adam_optimizer(policy, config):
    # Optimizer for ImitationTFPolicy; learning rate and epsilon are hard-coded.
    return tf.train.AdamOptimizer(
        learning_rate=0.01, epsilon=0.001)

def default_execution_plan(workers: WorkerSet, config):
    # Collect experiences in parallel from multiple RolloutWorker actors.
    rollouts = ParallelRollouts(workers, mode="bulk_sync")

    # Combine experience batches until we hit `train_batch_size` in size.
    # Then, train the policy on those experiences and update the workers.
    train_op = rollouts \
        .combine(ConcatBatches(
            min_batch_size=config["train_batch_size"])) \
        .for_each(TrainOneStep(workers))

    # Add on the standard episode reward, etc. metrics reporting. This returns
    # a LocalIterator[metrics_dict] representing metrics for each train step.
    return StandardMetricsReporting(train_op, workers, config)

def loss_imitation(policy, model, dist_class, train_batch):
    # Placeholder loss: the actual imitation loss is computed manually in
    # ImitationAgent._train below.
    return np.random.randint(5)


# policy = DQNTFPolicy.with_updates(name="ImitPolicy",)
ImitationTFPolicy = build_tf_policy(
    name="ImitationTFPolicy",
    loss_fn=loss_imitation,
    optimizer_fn=adam_optimizer,
)

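# Note: ImitationTFPolicy is only used to construct the rollout workers; the
# imitation updates themselves are applied to the local worker's model in
# ImitationAgent._train and then broadcast to the remote workers.
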
class ImitationAgent(PPOTrainer):
    """Trainer that learns by imitating expert actions from the
    CellGraphDispatcher rather than from environment rewards."""

    _name = "ImitationAgent"

    @override(Trainer)
    def _init(self, config, env_creator):
        self.env = env_creator(config["env_config"])
        self._policy = ImitationTFPolicy
        action_space = self.env.action_space
        dist_class, logit_dim = ModelCatalog.get_action_dist(
            action_space, self.config["model"])
        self.workers = self._make_workers(
            env_creator, self._policy, config, self.config["num_workers"])
        self.execution_plan = default_execution_plan
        self.train_exec_impl = self.execution_plan(self.workers, config)

    @override(Trainer)
    def _train(self):
        import tensorflow as tf
        policy = self.get_policy()
        steps = 0
        for _ in range(1):
            env = self.env._env.rail_env
            obs = self.env.reset()
            num_outputs = env.action_space[0]
            n_agents = env.get_num_agents()
            dispatcher = CellGraphDispatcher(env)

            # TODO: Update max_steps as per latest version
            # https://gitlab.aicrowd.com/flatland/flatland-examples/blob/master/reinforcement_learning/multi_agent_training.py
            # max_steps = int(4 * 2 * (env.height + env.width + (n_agents / n_cities))) - 1
            max_steps = int(4 * 2 * (20 + env.height + env.width))
            episode_steps = 0
            episode_max_steps = 0
            episode_num_agents = 0
            episode_score = 0
            episode_done_agents = 0
            done = {}
            done["__all__"] = False

            # TODO: Support for batch update
            # batch_size = 2
            # logits, _ = policy.model.forward({"obs": np.vstack([obs[a],obs[a]])}, [], None)
            for step in range(max_steps):
                action_dict = dispatcher.step(env._elapsed_steps)
                with tf.GradientTape() as tape:
                    imitation_loss = 0
                    active_agents = 0
                    for a in range(n_agents):
                        if not done.get(a) and obs.get(a) is not None:
                            active_agents += 1
                            expert_action = action_dict[a].value
                            input_dict = {"obs": np.expand_dims(obs[a], 0)}
                            input_dict['obs_flat'] = input_dict['obs']
                            logits, _ = policy.model.forward(input_dict, [], None)
                            model_logits = tf.squeeze(logits)
                            expert_logits = tf.cast(expert_action, tf.int32)
                            action_dist = Categorical(logits, policy.model.model_config)
                            # Cross-entropy imitation loss: negative log-likelihood of
                            # the expert action under the current policy distribution.
                            imitation_loss += tf.reduce_mean(
                                -action_dist.logp(tf.expand_dims(expert_logits, 0)))
                    imitation_loss = imitation_loss / max(active_agents, 1)

                gradients = tape.gradient(
                    imitation_loss, policy.model.trainable_variables())
                self.workers.local_worker().apply_gradients(gradients)

                # Broadcast the updated weights to all remote rollout workers.
                weights = ray.put(self.workers.local_worker().get_weights())
                # print(self.workers.local_worker().get_weights())
                for e in self.workers.remote_workers():
                    e.set_weights.remote(weights)

                obs, all_rewards, done, info = self.env.step(action_dict)
                steps += 1
                # super()._train()
                for agent, agent_info in info.items():
                    if episode_max_steps == 0:
                        episode_max_steps = agent_info["max_episode_steps"]
                        episode_num_agents = agent_info["num_agents"]
                    episode_steps = max(episode_steps, agent_info["agent_step"])
                    episode_score += agent_info["agent_score"]
                    if agent_info["agent_done"]:
                        episode_done_agents += 1

                if done["__all__"]:
                    # Fraction of agents that reached their target this episode.
                    print(float(episode_done_agents) / n_agents)
                    break

        return {
            "episode_reward_mean": episode_score,
            "timesteps_this_iter": steps,
        }


if __name__ == "__main__":
    # Register all necessary assets in tune registries
    load_envs(os.getcwd())    # Load envs
    load_models(os.getcwd())  # Load models

    ppo = True
    if ppo:
        config_file = "small_tree_video/PPO_test.yaml"
    else:
        config_file = "small_tree_video/apex_test.yaml"

    with open(config_file) as f:
        exp = yaml.safe_load(f)

    exp["config"]["eager"] = True
    exp["config"]["use_pytorch"] = False
    exp["config"]["log_level"] = "INFO"
    verbose = 2
    exp["config"]["eager_tracing"] = True
    webui_host = "0.0.0.0"
    # TODO: should be in exp['config'] directly
    exp['config']['env_config']['yaml_config'] = config_file
    exp['loggers'] = [TBXLogger]

    _default_config = with_common_config(exp["config"])

    ray.init(num_cpus=3, num_gpus=0)

    imitation_trainer = ImitationAgent(_default_config, env="flatland_sparse")
    ppo_trainer = PPOTrainer(_default_config, env="flatland_sparse")

    for i in range(5):
        print("== Iteration", i, "==")
        # Randomly pick which trainer runs this iteration (mixed IL + RL training).
        trainer_type = np.random.binomial(size=1, n=1, p=0.5)[0]
        if trainer_type:
            # Improve the imitation policy.
            print("-- Imitation --")
            print(pretty_print(imitation_trainer.train()))
            ppo_trainer.set_weights(imitation_trainer.get_weights())
        else:
            # Improve the PPO policy.
            print("-- PPO --")
            print(pretty_print(ppo_trainer.train()))
            imitation_trainer.set_weights(ppo_trainer.get_weights())

    print("Done: OK")