env_prediction_builder.py 1.34 KB
Newer Older
u214892's avatar
u214892 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
"""
PredictionBuilder objects are objects that can be passed to environments designed for customizability.
The PredictionBuilder-derived custom classes implement 2 functions, reset() and get([handle]).
If predictions are not required in every step or for all agents, get() only needs to be called
when and for whom they are required.

+ reset() is called after each environment reset, to allow for pre-computing relevant data.

+ get() is called whenever a prediction has to be computed, potentially for each agent independently in
case of multi-agent environments.
"""


class PredictionBuilder:
    """
    PredictionBuilder base class.

    Subclasses override reset() and get() to provide predictions.
    The environment is stored on ``self.env`` by _set_env().
    """

    def __init__(self, max_depth: int = 20):
        """
        Parameters
        ----------
        max_depth : int
            Stored as ``self.max_depth``; presumably the prediction horizon
            used by subclasses (TODO confirm against implementations).
        """
        self.max_depth = max_depth
        # Defined up front so that reading ``self.env`` before _set_env() has
        # been called yields None instead of raising AttributeError.
        self.env = None

    def _set_env(self, env):
        """Store ``env`` as the environment this builder predicts for."""
        self.env = env

    def reset(self):
        """
        Called after each environment reset.

        The base implementation does nothing; subclasses may override it to
        pre-compute data needed by get().
        """
        pass

    def get(self, custom_args=None, handle=0):
        """
        Called whenever get_many in the observation builder is called.

        Parameters
        ----------
        custom_args : dict
            Implementation-dependent custom arguments, see the sub-classes.

        handle : int (optional)
            Handle of the agent for which to compute the prediction.

        Returns
        -------
        object
            A prediction structure, specific to the corresponding environment.

        Raises
        ------
        NotImplementedError
            Always, in the base class; subclasses must override this method.
        """
        raise NotImplementedError()