env_observation_builder.py 2.76 KB
Newer Older
1
2
3
4
"""
ObservationBuilder objects can be passed to environments that are designed for customizability.
Custom classes derived from ObservationBuilder implement two methods: `reset()` and `get()` (or `get(handle)`).

+ `reset()` is called after each environment reset, to allow for pre-computing relevant data.

+ `get()` is called whenever an observation has to be computed, potentially for each agent independently in case of \
multi-agent environments.

+ `get_many(handles)` calls `get(handle)` for every handle in `handles` and returns the results in a dictionary \
keyed by handle.

"""
11
12
from typing import Optional, List

u214892's avatar
u214892 committed
13
import numpy as np
14

15
16
from flatland.core.env import Environment

17
18

class ObservationBuilder:
    """
    ObservationBuilder base class.

    Subclasses implement `reset()` and `get(handle)`; `get_many(handles)` is built on top of `get`.
    The environment is attached via `set_env()`; until then `self.env` is None.
    """

    def __init__(self):
        # No environment is attached until set_env() is called.
        self.env = None

    def set_env(self, env: "Environment"):
        """
        Attach the environment this builder computes observations for.

        Parameters
        ----------
        env : Environment
            The environment instance; it is stored as ``self.env``.
        """
        # String annotation: a forward reference, so the name is not evaluated at definition time.
        self.env: "Environment" = env

    def reset(self):
        """
        Called after each environment reset.

        Subclasses override this to (re)compute any data they need; the base
        implementation raises NotImplementedError.
        """
        raise NotImplementedError()

    def get_many(self, handles: Optional[List[int]] = None):
        """
        Called whenever an observation has to be computed for the `env` environment, for each agent with handle
        in the `handles` list.

        Parameters
        ----------
        handles : list of handles, optional
            List with the handles of the agents for which to compute the observation vector.
            If None, no observations are computed and an empty dictionary is returned.

        Returns
        -------
        dict
            A dictionary of observation structures, specific to the corresponding environment, with handles from
            `handles` as keys and the result of `self.get(handle)` as values.
        """
        if handles is None:
            # Historical behaviour: no handles means no observations (not "all agents").
            return {}
        # `get` is called once per handle, in iteration order; a repeated handle keeps its last result.
        return {handle: self.get(handle) for handle in handles}

    def get(self, handle: int = 0):
        """
        Called whenever an observation has to be computed for the `env` environment, possibly
        for each agent independently (agent id `handle`).

        Parameters
        ----------
        handle : int, optional
            Handle of the agent for which to compute the observation vector.

        Returns
        -------
        object
            An observation structure, specific to the corresponding environment.
            The base implementation raises NotImplementedError.
        """
        raise NotImplementedError()

    def _get_one_hot_for_agent_direction(self, agent):
        """
        Return the agent's direction as a one-hot encoded numpy vector of length 4.

        `agent.direction` is used directly as an index into that vector, so it is
        expected to be an integer in [0, 3] (TODO confirm against the agent type).
        """
        direction = np.zeros(4)
        direction[agent.direction] = 1
        return direction
80

81

82
83
84
85
86
87
88
class DummyObservationBuilder(ObservationBuilder):
    """
    Placeholder observation builder whose observations are always the constant ``True``.

    This is used in the evaluation service, where real observations are not needed.
    """

    def __init__(self):
        # No extra state: only the base-class initialisation is performed.
        super(DummyObservationBuilder, self).__init__()

    def reset(self):
        """No-op: the dummy builder keeps no per-reset state."""
        return None

    def get_many(self, handles: Optional[List[int]] = None) -> bool:
        """Ignore ``handles`` and return the constant dummy observation ``True``."""
        return True

    def get(self, handle: int = 0) -> bool:
        """Ignore ``handle`` and return the constant dummy observation ``True``."""
        return True