Commit 546f171e authored by metataro

flatland_sparse env (with fixed parameters) and experiment with 32x32 grids

parent a40e67f1
import gym
from flatland.envs.malfunction_generators import malfunction_from_params
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.schedule_generators import sparse_schedule_generator
from ray.rllib import MultiAgentEnv

from envs.flatland.observations import make_obs
from envs.flatland.rllib_wrapper import FlatlandRllibWrapper


class FlatlandSparse(MultiAgentEnv):

    def __init__(self, env_config) -> None:
        super().__init__()
        self._config = env_config
        self._observation = make_obs(env_config['observation'], env_config.get('observation_config'))
        self._env = FlatlandRllibWrapper(
            rail_env=self._launch(),
            render=env_config['render'],
            regenerate_rail_on_reset=env_config['regenerate_rail_on_reset'],
            regenerate_schedule_on_reset=env_config['regenerate_schedule_on_reset']
        )

    @property
    def observation_space(self) -> gym.spaces.Space:
        return self._observation.observation_space()

    @property
    def action_space(self) -> gym.spaces.Space:
        return self._env.action_space

    def _launch(self):
        rail_generator = sparse_rail_generator(
            seed=self._config['seed'],
            max_num_cities=self._config['max_num_cities'],
            grid_mode=self._config['grid_mode'],
            max_rails_between_cities=self._config['max_rails_between_cities'],
            max_rails_in_city=self._config['max_rails_in_city']
        )

        # Stochastic malfunction parameters, passed to flatland's
        # malfunction_from_params below.
        stochastic_data = {
            'malfunction_rate': self._config['malfunction_rate'],
            'min_duration': self._config['malfunction_min_duration'],
            'max_duration': self._config['malfunction_max_duration']
        }

        # Speed-ratio-map keys arrive as strings from the YAML config and are
        # converted to floats here.
        schedule_generator = sparse_schedule_generator({
            float(k): float(v) for k, v in self._config['speed_ratio_map'].items()
        })

        env = RailEnv(
            width=self._config['width'],
            height=self._config['height'],
            rail_generator=rail_generator,
            schedule_generator=schedule_generator,
            number_of_agents=self._config['number_of_agents'],
            malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
            obs_builder_object=self._observation.builder(),
            remove_agents_at_target=False
        )
        return env

    def step(self, action_dict):
        return self._env.step(action_dict)

    def reset(self):
        return self._env.reset()
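A minimal usage sketch (not part of this commit): constructing `FlatlandSparse` directly with an `env_config` dict whose keys match the ones read above and whose values mirror the 32x32 experiment YAML further down. Running the env standalone like this, outside RLlib, is an assumption.

```python
# Hypothetical standalone usage; values mirror the 32x32 experiment config.
env_config = {
    'observation': 'global',
    'observation_config': {'max_width': 32, 'max_height': 32},
    'width': 32, 'height': 32, 'number_of_agents': 10,
    'max_num_cities': 3, 'grid_mode': False,
    'max_rails_between_cities': 2, 'max_rails_in_city': 3,
    'malfunction_rate': 8000, 'malfunction_min_duration': 15,
    'malfunction_max_duration': 50,
    'speed_ratio_map': {'1.0': 0.25, '0.5': 0.25, '0.3333333': 0.25, '0.25': 0.25},
    'seed': 0, 'regenerate_rail_on_reset': True,
    'regenerate_schedule_on_reset': True, 'render': False,
}

env = FlatlandSparse(env_config)
obs = env.reset()                                            # dict: agent handle -> observation
obs, rewards, dones, infos = env.step({h: 2 for h in obs})   # 2 == MOVE_FORWARD in flatland
```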
# Global observation convnet experiments
W&B report: [Flatland-Sparse-32x32](https://app.wandb.ai/masterscrat/flatland/reports/Flatland-Sparse-32x32--Vmlldzo5MzQ3Nw)
## Method
In this experiment, we compare the performance of two established CNN architectures on the global
observations. In the first case, agents are based on the Nature-CNN architecture [2], which
consists of 3 convolutional layers followed by a dense layer. In the second case, agents are
based on the IMPALA-CNN [1], a 15-layer residual network followed by a dense layer. In both
cases, all agents share a single centralized policy network.
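For concreteness, below is a minimal sketch of the Nature-CNN trunk from [2] (32/64/64 filters with 8x8/4x4/3x3 kernels and a 512-unit dense layer, as in the paper); whether this repo's `NatureCNN` uses exactly these sizes is an assumption:

```python
import tensorflow as tf

def nature_cnn_trunk(dense_units: int = 512) -> tf.keras.Sequential:
    # 3 conv layers + dense, following Mnih et al. [2].
    return tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 8, strides=4, activation='relu'),
        tf.keras.layers.Conv2D(64, 4, strides=2, activation='relu'),
        tf.keras.layers.Conv2D(64, 3, strides=1, activation='relu'),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(dense_units, activation='relu'),
    ])
```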
## Results
TODO
## Plots
TODO
## Conclusion
TODO
## References
[1] Lasse Espeholt et al. "IMPALA: Scalable Distributed Deep-RL with Importance Weighted
Actor-Learner Architectures". In: Proceedings of the 35th International Conference on
Machine Learning. Vol. 80. 2018, pp. 1407–1416. URL: [https://arxiv.org/abs/1802.01561](https://arxiv.org/abs/1802.01561)

[2] Volodymyr Mnih et al. "Human-level control through deep reinforcement learning". In:
Nature 518.7540 (2015), pp. 529–533. ISSN: 1476-4687. DOI: 10.1038/nature14236. URL:
[https://www.nature.com/articles/nature14236](https://www.nature.com/articles/nature14236)
flatland-sparse-global-conv-ppo:
  run: PPO
  env: flatland_sparse
  stop:
    timesteps_total: 10000000  # 1e7
  checkpoint_freq: 10
  checkpoint_at_end: True
  keep_checkpoints_num: 5
  checkpoint_score_attr: episode_reward_mean
  config:
    clip_rewards: True
    clip_param: 0.1
    vf_clip_param: 500.0
    entropy_coeff: 0.01
    # effective batch_size: train_batch_size * num_agents_in_each_environment [5, 10]
    # see https://github.com/ray-project/ray/issues/4628
    train_batch_size: 1000  # 5000
    rollout_fragment_length: 50  # 100
    sgd_minibatch_size: 100  # 500
    num_sgd_iter: 10
    num_workers: 7
    num_envs_per_worker: 5
    batch_mode: truncate_episodes
    observation_filter: NoFilter
    vf_share_layers: True
    vf_loss_coeff: 0.5
    num_gpus: 1
    env_config:
      observation: global
      observation_config:
        max_width: 32
        max_height: 32
      width: 32
      height: 32
      number_of_agents: 10
      max_num_cities: 3
      grid_mode: False
      max_rails_between_cities: 2
      max_rails_in_city: 3
      malfunction_rate: 8000
      malfunction_min_duration: 15
      malfunction_max_duration: 50
      # speed ratio map: keys must be strings but will be converted to float later!
      speed_ratio_map:
        '1.0': 0.25
        '0.5': 0.25
        '0.3333333': 0.25
        '0.25': 0.25
      seed: 0
      regenerate_rail_on_reset: True
      regenerate_schedule_on_reset: True
      render: False
      wandb:
        project: flatland
        entity: masterscrat
        tags: ["32x32", "global_obs"]
    model:
      custom_model: global_obs_model
      custom_options:
        architecture: impala
        architecture_options:
          residual_layers: [[16, 2], [32, 4]]
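The `architecture_options` block is what this commit wires through: as the diffs below show, it is forwarded as `**kwargs` into the `NatureCNN`/`ImpalaCNN` constructors. Reading `residual_layers: [[16, 2], [32, 4]]` as a list of `(filters, num_blocks)` pairs for the IMPALA residual sections [1] is an assumption about this repo's `ImpalaCNN`; a sketch under that assumption:

```python
import tensorflow as tf

def impala_section(x, filters: int, num_blocks: int):
    # One IMPALA section [1]: 3x3 conv + 3x3 max-pool (stride 2),
    # followed by num_blocks residual blocks of two 3x3 convs each.
    x = tf.keras.layers.Conv2D(filters, 3, padding='same')(x)
    x = tf.keras.layers.MaxPool2D(pool_size=3, strides=2, padding='same')(x)
    for _ in range(num_blocks):
        skip = x
        y = tf.keras.layers.ReLU()(x)
        y = tf.keras.layers.Conv2D(filters, 3, padding='same')(y)
        y = tf.keras.layers.ReLU()(y)
        y = tf.keras.layers.Conv2D(filters, 3, padding='same')(y)
        x = tf.keras.layers.Add()([skip, y])
    return x

# residual_layers=[[16, 2], [32, 4]] would then stack two such sections:
# 16 filters with 2 residual blocks, followed by 32 filters with 4 blocks.
```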
@@ -11,7 +11,7 @@ class GlobalObsModel(TFModelV2):
         super().__init__(obs_space, action_space, num_outputs, model_config, name)
         self._options = model_config['custom_options']
         self._model = GlobalObsModule(action_space=action_space, architecture=self._options['architecture'],
-                                      name="global_obs_model")
+                                      name="global_obs_model", **self._options['architecture_options'])

     def forward(self, input_dict, state, seq_lens):
         obs = preprocess_obs(input_dict['obs'])
@@ -45,15 +45,15 @@ def preprocess_obs(obs) -> tf.Tensor:

 class GlobalObsModule(tf.Module):
-    def __init__(self, action_space, architecture: str, name=None):
+    def __init__(self, action_space, architecture: str, name=None, **kwargs):
         super().__init__(name=name)
         assert isinstance(action_space, gym.spaces.Discrete), \
             "Currently, only 'gym.spaces.Discrete' action spaces are supported."
         with self.name_scope:
             if architecture == 'nature':
-                self._cnn = NatureCNN(activation_out=True)
+                self._cnn = NatureCNN(activation_out=True, **kwargs)
             elif architecture == 'impala':
-                self._cnn = ImpalaCNN(activation_out=True)
+                self._cnn = ImpalaCNN(activation_out=True, **kwargs)
             else:
                 raise ValueError(f"Invalid architecture: {architecture}.")
         self._logits_layer = tf.keras.layers.Dense(units=action_space.n)
@@ -9,6 +9,7 @@ from ray.cluster_utils import Cluster
 from ray.rllib.evaluation import MultiAgentEpisode
 from ray.rllib.utils.framework import try_import_tf, try_import_torch
 from ray.tune import tune, run_experiments
+from ray.tune.logger import TBXLogger
 from ray.tune.resources import resources_to_json
 from ray.tune.tune import _make_scheduler
@@ -120,7 +121,7 @@ def run(args, parser):
         exp['config']['callbacks'] = {
             'on_episode_end': on_episode_end,
         }
-        exp['loggers'] = [WandbLogger]
+        exp['loggers'] = [WandbLogger, TBXLogger]
         if args.ray_num_nodes:
             cluster = Cluster()
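With `TBXLogger` added alongside `WandbLogger`, each trial's metrics are now written both to Weights & Biases and to local TensorBoard event files (by default under `~/ray_results/`, viewable with `tensorboard --logdir ~/ray_results`).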