global_obs_model.py 3.23 KB
Newer Older
metataro's avatar
metataro committed
1
import gym
2
import numpy as np
metataro's avatar
metataro committed
3
4
5
6
7
8
9
10
11
12
import tensorflow as tf
from flatland.core.grid import grid4
from ray.rllib.models.tf.tf_modelv2 import TFModelV2

from models.common.models import NatureCNN, ImpalaCNN


class GlobalObsModel(TFModelV2):
    """Policy/value network over Flatland global observations.

    Feeds the (optionally dict-wrapped) image-like observation through a
    convolutional torso (Nature CNN or IMPALA CNN) and produces per-action
    logits plus a scalar value-function baseline.  When the custom option
    ``mask_unavailable_actions`` is set, logits of unavailable actions are
    pushed towards -inf via an additive log-mask in ``forward``.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        super().__init__(obs_space, action_space, num_outputs, model_config, name)

        assert isinstance(action_space, gym.spaces.Discrete), \
            "Currently, only 'gym.spaces.Discrete' action spaces are supported."
        self._action_space = action_space
        self._options = model_config['custom_options']
        self._mask_unavailable_actions = self._options.get("mask_unavailable_actions", False)

        # With action masking enabled the observation space is a dict that
        # holds the actual image under 'obs' (plus an 'available_actions'
        # vector consumed in forward()); unwrap it for the input layer.
        # (The original also had a no-op `obs_space = obs_space` else branch.)
        if self._mask_unavailable_actions:
            obs_space = obs_space['obs']

        observations = tf.keras.layers.Input(shape=obs_space.shape)
        # NOTE(review): preprocess_obs is currently bypassed — the raw
        # observation is fed straight into the CNN torso.
        processed_observations = observations

        architecture = self._options['architecture']
        if architecture == 'nature':
            conv_out = NatureCNN(activation_out=True, **self._options['architecture_options'])(processed_observations)
        elif architecture == 'impala':
            conv_out = ImpalaCNN(activation_out=True, **self._options['architecture_options'])(processed_observations)
        else:
            raise ValueError(f"Invalid architecture: {architecture}.")

        # Separate linear heads: per-action policy logits and a scalar
        # value baseline, both built on the shared convolutional features.
        logits = tf.keras.layers.Dense(units=action_space.n)(conv_out)
        baseline = tf.keras.layers.Dense(units=1)(conv_out)
        self._model = tf.keras.Model(inputs=observations, outputs=[logits, baseline])
        self.register_variables(self._model.variables)

    def forward(self, input_dict, state, seq_lens):
        """Run the model; returns (logits, state) and stashes the baseline.

        The baseline is flattened to shape [batch] and kept on ``self`` so
        that ``value_function()`` can return it afterwards (standard RLlib
        TFModelV2 pattern).
        """
        if self._mask_unavailable_actions:
            obs = input_dict['obs']['obs']
        else:
            obs = input_dict['obs']

        logits, baseline = self._model(obs)
        self.baseline = tf.reshape(baseline, [-1])

        if self._mask_unavailable_actions:
            # log(0) = -inf for unavailable actions, clamped to the float32
            # minimum so the subsequent softmax gives them ~zero probability
            # without producing NaNs in the gradient.
            inf_mask = tf.maximum(tf.math.log(input_dict['obs']['available_actions']), tf.float32.min)
            logits = logits + inf_mask

        return logits, state

    def variables(self):
        # Expose the underlying Keras model's variables to RLlib.
        return self._model.variables

    def value_function(self):
        # Baseline computed (and cached) during the most recent forward().
        return self.baseline


def preprocess_obs(obs) -> tf.Tensor:
    """Assemble a dense float feature map from a raw global observation.

    Unpacks ``obs`` into (transition_map, agents_state, targets), encodes
    each agents-state feature layer according to its kind, and concatenates
    everything along the channel (last) axis.
    """
    transition_map, agents_state, targets = obs

    # One-hot depth for the categorical direction layers: the four grid
    # directions plus one extra slot.
    direction_depth = len(grid4.Grid4TransitionsEnum) + 1

    encoded_layers = []
    for idx, layer in enumerate(tf.unstack(agents_state, axis=-1)):
        if idx in (0, 1):
            # Categorical agent direction -> one-hot channels.
            encoded = tf.one_hot(tf.cast(layer, tf.int32), depth=direction_depth,
                                 dtype=tf.float32)
        elif idx in (2, 4):
            # Count-valued feature -> log-compressed single channel.
            encoded = tf.expand_dims(tf.math.log(layer + 1), axis=-1)
        else:
            # Already well-behaved scalar -> single channel as-is.
            encoded = tf.expand_dims(layer, axis=-1)
        encoded_layers.append(encoded)

    channels = [tf.cast(transition_map, tf.float32), tf.cast(targets, tf.float32)]
    channels.extend(encoded_layers)
    return tf.concat(channels, axis=-1)