Commit f4d4d567 authored by manuschn's avatar manuschn
Browse files

global observation based on projected agent density

parent 872cddb8
import gym
import numpy as np
from typing import Optional, List, Dict
from flatland.core.env import Environment
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.rail_env import RailEnv
from flatland.core.env_prediction_builder import PredictionBuilder
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from envs.flatland.observations import Observation, register_obs
class ProjectedDensityObservation(Observation):
def __init__(self, config) -> None:
self._builder = ProjectedDensityForRailEnv(config['height'], config['width'], config['encoding'], config['max_t'])
def builder(self) -> ObservationBuilder:
return self._builder
def observation_space(self) -> gym.Space:
obs_shape = self._builder.observation_shape
return gym.spaces.Tuple([
gym.spaces.Box(low=0, high=1, shape=obs_shape, dtype=np.float32),
gym.spaces.Box(low=0, high=1, shape=obs_shape, dtype=np.float32),
class ProjectedDensityForRailEnv(ObservationBuilder):
def __init__(self, height, width, encoding='exp_decay', max_t=10):
self._height = height
self._width = width
self._depth = max_t + 1 if encoding == 'series' else 1
if encoding == 'exp_decay':
self._encode = lambda t: np.exp(-t / np.sqrt(max_t))
elif encoding == 'lin_decay':
self._encode = lambda t: (max_t - t) / max_t
self._encode = lambda t: 1
self._predictor = ShortestPathPredictorForRailEnv(max_t)
def observation_shape(self):
return (self._height, self._width, self._depth)
def get_many(self, handles: Optional[List[int]] = None) -> Dict[int, np.ndarray]:
get density maps for agents and compose the observation with agent's and other's density maps
self._predictions = self._predictor.get()
density_maps = dict()
for handle in handles:
density_maps[handle] = self.get(handle)
obs = dict()
for handle in handles:
other_dens_maps = [density_maps[key] for key in density_maps if key != handle]
others_density = np.mean(np.array(other_dens_maps), axis=0)
obs[handle] = [density_maps[handle], others_density]
return obs
def get(self, handle: int = 0):
compute density map for agent: a value is asigned to every cell along the shortest path between
the agent and its target based on the distance to the agent, i.e. the number of time steps the
agent needs to reach the cell, encoding the time information.
density_map = np.zeros(shape=(self._height, self._width, self._depth), dtype=np.float32)
agent = self.env.agents[handle]
if self._predictions[handle] is not None:
for t, prediction in enumerate(self._predictions[handle]):
p = tuple(np.array(prediction[1:3]).astype(int))
d = t if self._depth > 1 else 0
density_map[p][d] = self._encode(t)
return density_map
def set_env(self, env: Environment):
self.env: RailEnv = env
def reset(self):
run: APEX
env: flatland_sparse
timesteps_total: 15000000 # 1e8
checkpoint_freq: 10
checkpoint_at_end: True
keep_checkpoints_num: 5
checkpoint_score_attr: episode_reward_mean
num_workers: 15
num_envs_per_worker: 5
num_gpus: 0
hiddens: []
dueling: False
observation: density
width: 25
height: 25
max_t: 100
encoding: exp_decay
generator: sparse_rail_generator
generator_config: small_v0
resolve_deadlocks: False
deadlock_reward: 0
density_reward_factor: 0
project: flatland
entity: masterscrat
tags: ["small_v0", "global_dens_obs_conv", "apex"] # TODO should be set programmatically
custom_model: global_dens_obs_model
architecture: impala
residual_layers: [[16, 2], [32, 4]]
import gym
import numpy as np
import tensorflow as tf
from flatland.core.grid import grid4
from import TFModelV2
from models.common.models import NatureCNN, ImpalaCNN
class GlobalDensObsModel(TFModelV2):
def __init__(self, obs_space, action_space, num_outputs, model_config, name):
super().__init__(obs_space, action_space, num_outputs, model_config, name)
assert isinstance(action_space, gym.spaces.Discrete), \
"Currently, only 'gym.spaces.Discrete' action spaces are supported."
self._action_space = action_space
self._options = model_config['custom_options']
self._mask_unavailable_actions = self._options.get("mask_unavailable_actions", False)
if self._mask_unavailable_actions:
obs_space = obs_space.original_space['obs']
obs_space = obs_space.original_space
observations = [tf.keras.layers.Input(shape=o.shape) for o in obs_space]
comp_obs = tf.concat(observations, axis=-1)
if self._options['architecture'] == 'nature':
conv_out = NatureCNN(activation_out=True, **self._options['architecture_options'])(comp_obs)
elif self._options['architecture'] == 'impala':
conv_out = ImpalaCNN(activation_out=True, **self._options['architecture_options'])(comp_obs)
raise ValueError(f"Invalid architecture: {self._options['architecture']}.")
logits = tf.keras.layers.Dense(units=action_space.n)(conv_out)
baseline = tf.keras.layers.Dense(units=1)(conv_out)
self._model = tf.keras.Model(inputs=observations, outputs=[logits, baseline])
def forward(self, input_dict, state, seq_lens):
if self._mask_unavailable_actions:
obs = input_dict['obs']['obs']
obs = input_dict['obs']
logits, baseline = self._model(obs)
self.baseline = tf.reshape(baseline, [-1])
if self._mask_unavailable_actions:
inf_mask = tf.maximum(tf.math.log(input_dict['obs']['available_actions']), tf.float32.min)
logits = logits + inf_mask
return logits, state
def variables(self):
return self._model.variables
def value_function(self):
return self.baseline
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment