"""
ObservationBuilder objects are objects that can be passed to environments designed for customizability.
The ObservationBuilder-derived custom classes implement 2 functions, reset() and get() or get(handle).

u214892's avatar
u214892 committed
5 6 7 8
+ `reset()` is called after each environment reset, to allow for pre-computing relevant data.

+ `get()` is called whenever an observation has to be computed, potentially for each agent independently in case of \
multi-agent environments.
9 10

"""
from typing import Optional, List

import numpy as np

from flatland.core.env import Environment


class ObservationBuilder:
    """
    ObservationBuilder base class.

    Subclasses must implement `reset()` and `get(handle)`. `get_many(handles)` is provided here and
    builds its result by calling `get()` once per requested handle.
    """

    def __init__(self):
        # No environment is attached at construction time; `set_env()` attaches one.
        self.env = None

    def set_env(self, env: Environment):
        """
        Attach the environment for which observations will be computed.

        Parameters
        ----------
        env : Environment
            The environment instance; it is stored as `self.env`.
        """
        self.env: Environment = env

    def reset(self):
        """
        Called after each environment reset.

        Raises
        ------
        NotImplementedError
            Always, in this base class; subclasses must override it.
        """
        raise NotImplementedError()

    def get_many(self, handles: Optional[List[int]] = None):
        """
        Called whenever an observation has to be computed for the `env` environment, for each agent with handle
        in the `handles` list.

        Parameters
        ----------
        handles : list of handles, optional
            List with the handles of the agents for which to compute the observation vector.
            If None, no observations are computed.

        Returns
        -------
        dict
            A dictionary of observation structures, specific to the corresponding environment, with handles from
            `handles` as keys and the result of `self.get(handle)` as values. Empty when `handles` is None or empty.
        """
        # Explicit `is None` check (rather than truthiness) so array-like handle collections keep working.
        if handles is None:
            return {}
        return {h: self.get(h) for h in handles}

    def get(self, handle: int = 0):
        """
        Called whenever an observation has to be computed for the `env` environment, possibly
        for each agent independently (agent id `handle`).

        Parameters
        ----------
        handle : int, optional
            Handle of the agent for which to compute the observation vector.

        Returns
        -------
        object
            An observation structure, specific to the corresponding environment.

        Raises
        ------
        NotImplementedError
            Always, in this base class; subclasses must override it.
        """
        raise NotImplementedError()

    def _get_one_hot_for_agent_direction(self, agent):
        """Return the agent's direction as a one-hot encoded vector of length 4."""
        # assumes agent.direction is an integer index in 0..3 (one slot per heading) — TODO confirm
        direction = np.zeros(4)
        direction[agent.direction] = 1
        return direction


class DummyObservationBuilder(ObservationBuilder):
    """
    DummyObservationBuilder class which returns dummy observations.

    This is used in the evaluation service: instead of real observation structures,
    both `get()` and `get_many()` simply return ``True``.
    """

    def __init__(self):
        # No extra state; initialisation is delegated entirely to the base class.
        super().__init__()

    def reset(self):
        """Nothing to pre-compute for dummy observations."""
        return None

    def get(self, handle: int = 0) -> bool:
        """Ignore `handle` and return ``True``."""
        return True

    def get_many(self, handles: Optional[List[int]] = None) -> bool:
        """Ignore `handles` and return ``True`` (not a dictionary)."""
        return True