Commit 673f3e52 authored by metataro's avatar metataro

action masking, skip "no choice cells", fixed global obs for apex

parent 43a6537c
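A minimal sketch (not part of the diff) of how the pieces this commit introduces are meant to compose, assuming the `env_config` keys used further below; `rail_env` stands for any already constructed RailEnv and `make_obs` for the existing observation factory:
obs_builder = make_obs(env_config['observation'], env_config.get('observation_config'))
env = FlatlandGymEnv(
    rail_env=rail_env,
    observation_space=obs_builder.observation_space(),
    regenerate_rail_on_reset=True,
    regenerate_schedule_on_reset=True,
)
if env_config.get('skip_no_choice_cells', False):
    env = SkipNoChoiceCellsWrapper(env)   # only emit observations on decision cells
if env_config.get('available_actions_obs', False):
    env = AvailableActionsWrapper(env)    # adds a per-agent action mask to each observation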
......@@ -12,35 +12,35 @@ class StepOutput(NamedTuple):
info: Dict[int, Dict[str, Any]]
class FlatlandRllibWrapper(object):
def __init__(self, rail_env: RailEnv, render: bool = False, regenerate_rail_on_reset: bool = True,
class FlatlandGymEnv(gym.Env):
def __init__(self,
rail_env: RailEnv,
observation_space: gym.spaces.Space,
render: bool = False,
regenerate_rail_on_reset: bool = True,
regenerate_schedule_on_reset: bool = True) -> None:
super().__init__()
self._env = rail_env
self._agents_done = []
self._agent_scores = defaultdict(float)
self._agent_steps = defaultdict(int)
self._regenerate_rail_on_reset = regenerate_rail_on_reset
self._regenerate_schedule_on_reset = regenerate_schedule_on_reset
self._action_space = gym.spaces.Discrete(5)
self.rail_env = rail_env
self.action_space = gym.spaces.Discrete(5)
self.observation_space = observation_space
if render:
from flatland.utils.rendertools import RenderTool
self.renderer = RenderTool(self._env, gl="PILSVG")
self.renderer = RenderTool(self.rail_env, gl="PILSVG")
else:
self.renderer = None
@property
def action_space(self) -> gym.spaces.Discrete:
return self._action_space
def step(self, action_dict: Dict[int, RailEnvActions]) -> StepOutput:
d, r, o = None, None, None
obs_or_done = False
while not obs_or_done:
# Step the env until at least one agent returns an observation or all agents are done.
# The observation is `None` if an agent is done or malfunctioning.
obs, rewards, dones, infos = self._env.step(action_dict)
obs, rewards, dones, infos = self.rail_env.step(action_dict)
if self.renderer is not None:
self.renderer.render_env(show=True, show_predictions=True, show_observations=False)
......@@ -57,7 +57,6 @@ class FlatlandRllibWrapper(object):
if agent != '__all__':
if done:
self._agents_done.append(agent)
# if infos['action_required'][agent] or done:
o[agent] = obs[agent]
r[agent] = rewards[agent]
self._agent_scores[agent] += rewards[agent]
......@@ -70,9 +69,9 @@ class FlatlandRllibWrapper(object):
assert all([x is not None for x in (d, r, o)])
return StepOutput(obs=o, reward=r, done=d, info={agent: {
'max_episode_steps': self._env._max_episode_steps,
'num_agents': self._env.get_num_agents(),
'agent_done': d[agent] and agent not in self._env.active_agents,
'max_episode_steps': self.rail_env._max_episode_steps,
'num_agents': self.rail_env.get_num_agents(),
'agent_done': d[agent] and agent not in self.rail_env.active_agents,
'agent_score': self._agent_scores[agent],
'agent_step': self._agent_steps[agent],
} for agent in o.keys()})
......@@ -81,9 +80,12 @@ class FlatlandRllibWrapper(object):
self._agents_done = []
self._agent_scores = defaultdict(float)
self._agent_steps = defaultdict(int)
obs, infos = self._env.reset(regenerate_rail=self._regenerate_rail_on_reset,
regenerate_schedule=self._regenerate_schedule_on_reset,
random_seed=random_seed)
obs, infos = self.rail_env.reset(regenerate_rail=self._regenerate_rail_on_reset,
regenerate_schedule=self._regenerate_schedule_on_reset,
random_seed=random_seed)
if self.renderer is not None:
self.renderer.reset()
return {k: o for k, o in obs.items() if not k == '__all__'}
def render(self, mode='human'):
raise NotImplementedError
from typing import Dict, Any, Optional, Set, List
import gym
import numpy as np
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.rail_env import RailEnv, RailEnvActions
from envs.flatland.utils.gym_env import StepOutput
def available_actions(env: RailEnv, agent: EnvAgent, allow_noop=True) -> List[int]:
if agent.position is None:
return [1] * len(RailEnvActions)
else:
possible_transitions = env.rail.get_transitions(*agent.position, agent.direction)
# some actions are always available:
available_acts = [0] * len(RailEnvActions)
available_acts[RailEnvActions.MOVE_FORWARD] = 1
available_acts[RailEnvActions.STOP_MOVING] = 1
if allow_noop:
available_acts[RailEnvActions.DO_NOTHING] = 1
# check whether turning left/right is available:
for movement in range(4):
if possible_transitions[movement]:
if movement == (agent.direction + 1) % 4:
available_acts[RailEnvActions.MOVE_RIGHT] = 1
elif movement == (agent.direction - 1) % 4:
available_acts[RailEnvActions.MOVE_LEFT] = 1
return available_acts
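# Illustration only (not part of this commit). Assuming `rail_env` is an existing RailEnv,
# the mask for agent 0 follows the RailEnvActions order
# (DO_NOTHING, MOVE_LEFT, MOVE_FORWARD, MOVE_RIGHT, STOP_MOVING); on a plain straight
# cell it would come out as [1, 0, 1, 0, 1]: only forward, stop and the no-op are available.
mask = available_actions(rail_env, rail_env.agents[0], allow_noop=True)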
class AvailableActionsWrapper(gym.Wrapper):
def __init__(self, env, allow_noop=True) -> None:
super().__init__(env)
self._allow_noop = allow_noop
self.observation_space = gym.spaces.Dict({
'obs': self.env.observation_space,
'available_actions': gym.spaces.Box(low=0, high=1, shape=(self.action_space.n,), dtype=np.int32)
})
def step(self, action_dict: Dict[int, RailEnvActions]) -> StepOutput:
obs, reward, done, info = self.env.step(action_dict)
return StepOutput(self._transform_obs(obs), reward, done, info)
def reset(self, random_seed: Optional[int] = None) -> Dict[int, Any]:
return self._transform_obs(self.env.reset(random_seed))
def _transform_obs(self, obs):
rail_env = self.unwrapped.rail_env
return {
agent_id: {
'obs': agent_obs,
'available_actions': np.asarray(available_actions(rail_env, rail_env.agents[agent_id], self._allow_noop))
} for agent_id, agent_obs in obs.items()
}
def find_all_cells_where_agent_can_choose(rail_env: RailEnv):
switches = []
switches_neighbors = []
directions = list(range(4))
for h in range(rail_env.height):
for w in range(rail_env.width):
pos = (h, w)  # RailEnv positions are (row, column)
is_switch = False
# Check for switch: if there is more than one outgoing transition
for orientation in directions:
possible_transitions = rail_env.rail.get_transitions(*pos, orientation)
num_transitions = np.count_nonzero(possible_transitions)
if num_transitions > 1:
switches.append(pos)
is_switch = True
break
if is_switch:
# Add all neighbouring rails, if pos is a switch
for orientation in directions:
possible_transitions = rail_env.rail.get_transitions(*pos, orientation)
for movement in directions:
if possible_transitions[movement]:
switches_neighbors.append(get_new_position(pos, movement))
decision_cells = switches + switches_neighbors
return tuple(map(set, (switches, switches_neighbors, decision_cells)))
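# Illustration only (not part of this commit). Assuming `rail_env` is an existing RailEnv,
# the three sets computed above can be inspected directly; SkipNoChoiceCellsWrapper below
# recomputes them on every reset() and uses `decision_cells` to decide which agents
# get an observation (and therefore a fresh action) on a given step.
switches, switches_neighbors, decision_cells = find_all_cells_where_agent_can_choose(rail_env)
print(f"{len(switches)} switches, {len(switches_neighbors)} switch neighbours, "
      f"{len(decision_cells)} decision cells")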
class SkipNoChoiceCellsWrapper(gym.Wrapper):
def __init__(self, env) -> None:
super().__init__(env)
self._switches = None
self._switches_neighbors = None
self._decision_cells = None
def _on_decision_cell(self, agent: EnvAgent):
return agent.position is None or agent.position in self._decision_cells
def _on_switch(self, agent: EnvAgent):
return agent.position in self._switches
def _next_to_switch(self, agent: EnvAgent):
return agent.position in self._switches_neighbors
def step(self, action_dict: Dict[int, RailEnvActions]) -> StepOutput:
o, r, d, i = {}, {}, {}, {}
while len(o) == 0:
obs, reward, done, info = self.env.step(action_dict)
for agent_id, agent_obs in obs.items():
if done[agent_id] or self._on_decision_cell(self.unwrapped.rail_env.agents[agent_id]):
o[agent_id] = agent_obs
r[agent_id] = reward[agent_id]
d[agent_id] = done[agent_id]
i[agent_id] = info[agent_id]
d['__all__'] = done['__all__']
action_dict = {}
return StepOutput(o, r, d, i)
def reset(self, random_seed: Optional[int] = None) -> Dict[int, Any]:
obs = self.env.reset(random_seed)
self._switches, self._switches_neighbors, self._decision_cells = \
find_all_cells_where_agent_can_choose(self.unwrapped.rail_env)
return obs
......@@ -5,7 +5,9 @@ from ray.rllib import MultiAgentEnv
from envs.flatland.utils.env_generators import random_sparse_env_small
from envs.flatland.observations import make_obs
from envs.flatland.utils.rllib_wrapper import FlatlandRllibWrapper
from envs.flatland.utils.gym_env import FlatlandGymEnv
from envs.flatland.utils.gym_env_wrappers import SkipNoChoiceCellsWrapper, AvailableActionsWrapper
class FlatlandRandomSparseSmall(MultiAgentEnv):
......@@ -22,9 +24,17 @@ class FlatlandRandomSparseSmall(MultiAgentEnv):
self._next_test_seed = self._min_test_seed
self._num_resets = 0
self._observation = make_obs(env_config['observation'], env_config.get('observation_config'))
self._env = FlatlandRllibWrapper(rail_env=self._launch(), render=env_config['render'],
regenerate_rail_on_reset=env_config['regenerate_rail_on_reset'],
regenerate_schedule_on_reset=env_config['regenerate_schedule_on_reset'])
self._env = FlatlandGymEnv(
rail_env=self._launch(),
observation_space=self._observation.observation_space(),
# render=env_config['render'], # TODO need to fix gl compatibility first
regenerate_rail_on_reset=self._config['regenerate_rail_on_reset'],
regenerate_schedule_on_reset=self._config['regenerate_schedule_on_reset']
)
if env_config.get('skip_no_choice_cells', False):
self._env = SkipNoChoiceCellsWrapper(self._env)
if env_config.get('available_actions_obs', False):
self._env = AvailableActionsWrapper(self._env)
@property
def observation_space(self) -> gym.spaces.Space:
......
......@@ -9,7 +9,9 @@ from flatland.envs.schedule_generators import sparse_schedule_generator
from envs.flatland import get_generator_config
from envs.flatland.observations import make_obs
from envs.flatland.utils.rllib_wrapper import FlatlandRllibWrapper, StepOutput
from envs.flatland.utils.gym_env import FlatlandGymEnv, StepOutput
from envs.flatland.utils.gym_env_wrappers import SkipNoChoiceCellsWrapper, AvailableActionsWrapper
class FlatlandSingle(gym.Env):
......@@ -19,12 +21,16 @@ class FlatlandSingle(gym.Env):
def __init__(self, env_config):
self._observation = make_obs(env_config['observation'], env_config.get('observation_config'))
self._config = get_generator_config(env_config['generator_config'])
self._env = FlatlandRllibWrapper(
self._env = FlatlandGymEnv(
rail_env=self._launch(),
observation_space=self._observation.observation_space(),
regenerate_rail_on_reset=self._config['regenerate_rail_on_reset'],
regenerate_schedule_on_reset=self._config['regenerate_schedule_on_reset']
)
if env_config.get('skip_no_choice_cells', False):
self._env = SkipNoChoiceCellsWrapper(self._env)
if env_config.get('available_actions_obs', False):
self._env = AvailableActionsWrapper(self._env)
def _launch(self):
rail_generator = sparse_rail_generator(
......
......@@ -10,7 +10,8 @@ from ray.rllib import MultiAgentEnv
from envs.flatland import get_generator_config
from envs.flatland.observations import make_obs
from envs.flatland.utils.rllib_wrapper import FlatlandRllibWrapper
from envs.flatland.utils.gym_env import FlatlandGymEnv
from envs.flatland.utils.gym_env_wrappers import AvailableActionsWrapper, SkipNoChoiceCellsWrapper
class FlatlandSparse(MultiAgentEnv):
......@@ -28,16 +29,22 @@ class FlatlandSparse(MultiAgentEnv):
pprint(self._config)
print("=" * 50)
self._env = FlatlandRllibWrapper(
self._env = FlatlandGymEnv(
rail_env=self._launch(),
observation_space=self._observation.observation_space(),
# render=env_config['render'], # TODO need to fix gl compatibility first
regenerate_rail_on_reset=self._config['regenerate_rail_on_reset'],
regenerate_schedule_on_reset=self._config['regenerate_schedule_on_reset']
)
if env_config.get('skip_no_choice_cells', False):
self._env = SkipNoChoiceCellsWrapper(self._env)
if env_config.get('available_actions_obs', False):
self._env = AvailableActionsWrapper(self._env)
@property
def observation_space(self) -> gym.spaces.Space:
return self._observation.observation_space()
print(self._env.observation_space)
return self._env.observation_space
@property
def action_space(self) -> gym.spaces.Space:
......
flatland-sparse-small-action-mask-tree-fc-apex:
run: APEX
env: flatland_sparse
stop:
timesteps_total: 100000000 # 1e8
checkpoint_freq: 10
checkpoint_at_end: True
# keep_checkpoints_num: 5
checkpoint_score_attr: episode_reward_mean
config:
num_workers: 15
num_envs_per_worker: 5
num_gpus: 1
hiddens: []
dueling: False
env_config:
skip_no_choice_cells: True
available_actions_obs: True
observation: new_tree
observation_config:
max_depth: 2
shortest_path_max_depth: 30
generator: sparse_rail_generator
generator_config: small_v0
wandb:
project: flatland
entity: masterscrat
tags: ["small_v0", "new_tree_obs", "apex", "skip_no_choice_cells",
"action_mask"] # TODO should be set programmatically
model:
custom_model: fully_connected_model
custom_options:
layers: [256, 256]
activation: relu
layer_norm: False
vf_share_layers: True # False
mask_unavailable_actions: True
flatland-sparse-small-action-mask-tree-fc-ppo:
run: PPO
env: flatland_sparse
stop:
timesteps_total: 10000000 # 1e7
checkpoint_freq: 10
checkpoint_at_end: True
checkpoint_score_attr: episode_reward_mean
config:
clip_rewards: False
# clip_param: 0.1
vf_clip_param: 500.0
entropy_coeff: 0.01
# effective batch_size: train_batch_size * num_agents_in_each_environment
# see https://github.com/ray-project/ray/issues/4628
train_batch_size: 1000 # 5000
rollout_fragment_length: 50 # 100
sgd_minibatch_size: 100 # 500
num_sgd_iter: 10
num_workers: 15
num_envs_per_worker: 5
batch_mode: truncate_episodes
observation_filter: NoFilter
vf_share_layers: True
vf_loss_coeff: 0.05
num_gpus: 0
env_config:
skip_no_choice_cells: True
available_actions_obs: True
observation: new_tree
observation_config:
max_depth: 2
shortest_path_max_depth: 30
generator: sparse_rail_generator
generator_config: small_v0
wandb:
project: flatland
entity: masterscrat
tags: ["small_v0", "new_tree_obs", "ppo", "skip_no_choice_cells",
"action_mask"] # TODO should be set programmatically
model:
custom_model: fully_connected_model
custom_options:
layers: [256, 256]
activation: relu
layer_norm: False
vf_share_layers: True # False
mask_unavailable_actions: True
......@@ -12,6 +12,9 @@ flatland-random-sparse-small-tree-fc-ppo:
num_envs_per_worker: 5
num_gpus: 1
hiddens: []
dueling: False
env_config:
observation: global
observation_config:
......@@ -24,11 +27,11 @@ flatland-random-sparse-small-tree-fc-ppo:
wandb:
project: flatland
entity: masterscrat
tags: ["small_v0", "global_obs", "apex"] # TODO should be set programmatically
tags: ["small_v0", "global_obs", "apex"] # TODO should be set programmatically
model:
custom_model: global_obs_model
custom_options:
architecture: impala
architecture_options:
residual_layers: [[16, 2], [32, 4]]
custom_options:
architecture: impala
architecture_options:
residual_layers: [[16, 2], [32, 4]]
......@@ -4,9 +4,10 @@ import tensorflow as tf
class FullyConnected(tf.Module):
def __init__(self, layers: List[int] = None, activation=tf.tanh, layer_norm=False, activation_out=True, name=None):
def __init__(self, layers: List[int] = None, activation=tf.tanh, layer_norm=False, activation_out=True,
name="fully_connected_net"):
super().__init__(name)
self.layers = []
self.layers = [tf.keras.layers.Flatten()]
with self.name_scope:
for i, num_hidden in enumerate(layers):
self.layers.append(tf.keras.layers.Dense(units=num_hidden))
......@@ -24,7 +25,7 @@ class FullyConnected(tf.Module):
class NatureCNN(tf.Module):
def __init__(self, activation_out=True, name=None):
def __init__(self, activation_out=True, name="nature_cnn_net"):
super().__init__(name)
with self.name_scope:
self.layers = [
......@@ -55,7 +56,7 @@ class ImpalaResidualLayer(NamedTuple):
class ImpalaCNN(tf.Module):
def __init__(self, residual_layers: Iterable[Tuple[int, int]] = None, activation_out=True, name=None):
def __init__(self, residual_layers: Iterable[Tuple[int, int]] = None, activation_out=True, name="impala_cnn_net"):
super().__init__(name)
residual_layers = residual_layers or [(16, 2), (32, 2), (32, 2)]
self._residual_layers = []
......
import sys
import gym
import tensorflow as tf
import numpy as np
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from models.common.models import FullyConnected
class FullyConnectedModel(TFModelV2):
def __init__(self, obs_space, action_space, num_outputs, model_config, name):
super().__init__(obs_space, action_space, num_outputs, model_config, name)
assert isinstance(action_space, gym.spaces.Discrete), \
"Currently, only 'gym.spaces.Discrete' action spaces are supported."
self._action_space = action_space
self._options = model_config['custom_options']
self._mask_unavailable_actions = self._options.get("mask_unavailable_actions", False)
if self._mask_unavailable_actions:
observations = tf.keras.layers.Input(shape=obs_space.original_space['obs'].shape)
else:
observations = tf.keras.layers.Input(shape=obs_space.shape)
activation = tf.keras.activations.deserialize(self._options['activation'])
fc_out = FullyConnected(layers=self._options['layers'], activation=activation,
layer_norm=self._options['layer_norm'], activation_out=True)(observations)
logits = tf.keras.layers.Dense(units=action_space.n)(fc_out)
baseline = tf.keras.layers.Dense(units=1)(fc_out)
self._model = tf.keras.Model(inputs=[observations], outputs=[logits, baseline])
self.register_variables(self._model.variables)
self._model.summary()
def forward(self, input_dict, state, seq_lens):
if self._mask_unavailable_actions:
obs = input_dict['obs']['obs']
else:
obs = input_dict['obs']
logits, baseline = self._model(obs)
self.baseline = tf.reshape(baseline, [-1])
if self._mask_unavailable_actions:
inf_mask = tf.maximum(tf.math.log(input_dict['obs']['available_actions']), tf.float32.min)
logits = logits + inf_mask
return logits, state
def value_function(self):
return self.baseline
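# Illustration only (not part of this commit) of the masking trick used in forward() above:
# log(1) == 0 leaves the logits of available actions untouched, while log(0) == -inf
# (clamped to tf.float32.min) drives masked actions to effectively zero probability
# after the softmax.
import tensorflow as tf
logits = tf.constant([[2.0, 1.0, 0.5, -1.0, 0.0]])        # raw per-action logits
mask = tf.constant([[1.0, 0.0, 1.0, 0.0, 1.0]])           # e.g. from available_actions()
inf_mask = tf.maximum(tf.math.log(mask), tf.float32.min)  # 0 for allowed, huge negative for masked
masked_logits = logits + inf_mask
probs = tf.nn.softmax(masked_logits)                      # masked actions get ~0 probability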
import gym
import numpy as np
import tensorflow as tf
from flatland.core.grid import grid4
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
......@@ -9,14 +10,43 @@ from models.common.models import NatureCNN, ImpalaCNN
class GlobalObsModel(TFModelV2):
def __init__(self, obs_space, action_space, num_outputs, model_config, name):
super().__init__(obs_space, action_space, num_outputs, model_config, name)
assert isinstance(action_space, gym.spaces.Discrete), \
"Currently, only 'gym.spaces.Discrete' action spaces are supported."
self._action_space = action_space
self._options = model_config['custom_options']
self._model = GlobalObsModule(action_space=action_space, architecture=self._options['architecture'],
name="global_obs_model", **self._options['architecture_options'])
self._mask_unavailable_actions = self._options.get("mask_unavailable_actions", False)
if self._mask_unavailable_actions:
obs_space = obs_space.original_space['obs']
else:
obs_space = obs_space.original_space
observations = [tf.keras.layers.Input(shape=o.shape) for o in obs_space]
processed_observations = preprocess_obs(tuple(observations))
if self._options['architecture'] == 'nature':
conv_out = NatureCNN(activation_out=True, **self._options['architecture_options'])(processed_observations)
elif self._options['architecture'] == 'impala':
conv_out = ImpalaCNN(activation_out=True, **self._options['architecture_options'])(processed_observations)
else:
raise ValueError(f"Invalid architecture: {self._options['architecture']}.")
logits = tf.keras.layers.Dense(units=action_space.n)(conv_out)
baseline = tf.keras.layers.Dense(units=1)(conv_out)
self._model = tf.keras.Model(inputs=observations, outputs=[logits, baseline])
self.register_variables(self._model.variables)
self._model.summary()
def forward(self, input_dict, state, seq_lens):
obs = preprocess_obs(input_dict['obs'])
# obs = preprocess_obs(input_dict['obs'])
if self._mask_unavailable_actions:
obs = input_dict['obs']['obs']
else:
obs = input_dict['obs']
logits, baseline = self._model(obs)
self.baseline = tf.reshape(baseline, [-1])
if self._mask_unavailable_actions:
inf_mask = tf.maximum(tf.math.log(input_dict['obs']['available_actions']), tf.float32.min)
logits = logits + inf_mask
return logits, state
def variables(self):
......@@ -42,25 +72,3 @@ def preprocess_obs(obs) -> tf.Tensor:
return tf.concat(
[tf.cast(transition_map, tf.float32), tf.cast(targets, tf.float32)] + processed_agents_state_layers, axis=-1)
class GlobalObsModule(tf.Module):
def __init__(self, action_space, architecture: str, name=None, **kwargs):
super().__init__(name=name)
assert isinstance(action_space, gym.spaces.Discrete), \
"Currently, only 'gym.spaces.Discrete' action spaces are supported."
with self.name_scope:
if architecture == 'nature':
self._cnn = NatureCNN(activation_out=True, **kwargs)
elif architecture == 'impala':
self._cnn = ImpalaCNN(activation_out=True, **kwargs)
else:
raise ValueError(f"Invalid architecture: {architecture}.")
self._logits_layer = tf.keras.layers.Dense(units=action_space.n)
self._baseline_layer = tf.keras.layers.Dense(units=1)
def __call__(self, spatial_obs, non_spatial_obs=None):
latent_repr = self._cnn(spatial_obs)
logits = self._logits_layer(latent_repr)
baseline = self._baseline_layer(latent_repr)
return logits, baseline
\ No newline at end of file