env_observation_builder.py 2.3 KB
Newer Older
1
2
3
4
5
6
7
8
9
"""
ObservationBuilder objects are objects that can be passed to environments designed for customizability.
The ObservationBuilder-derived custom classes implement 2 functions, reset() and get() or get(handle).

+ `reset()` is called after each environment reset, to allow for pre-computing relevant data.

+ `get()` is called whenever an observation has to be computed, potentially for each agent independently in
case of multi-agent environments.
"""
u214892's avatar
u214892 committed
10
import numpy as np
11

12
13

class ObservationBuilder:
    """
    ObservationBuilder base class.

    Derived objects must implement an `observation_space' attribute as a tuple with the dimensions of the returned
    observations, and override `reset()' and `get()', which raise NotImplementedError here.
    """

    def __init__(self):
        """Initialise the builder with an empty `observation_space' tuple."""
        self.observation_space = ()

    def _set_env(self, env):
        """
        Store a reference to the environment on this builder.

        Note that `self.env' does not exist until this method has been called.

        Parameters
        ----------
        env : object
            The environment to keep as `self.env'.
        """
        self.env = env

    def reset(self):
        """
        Called after each environment reset.

        Raises
        ------
        NotImplementedError
            Always, in this base class; derived classes must override it.
        """
        raise NotImplementedError()

    def get_many(self, handles=None):
        """
        Called whenever an observation has to be computed for the `env' environment, for each agent with handle
        in the `handles' list.

        Parameters
        ----------
        handles : list of handles (optional)
            List with the handles of the agents for which to compute the observation vector.
            When omitted or None, no observation is computed and an empty dictionary is returned.

        Returns
        -------
        dict
            A dictionary of observation structures, specific to the corresponding environment, with handles from
            `handles' as keys. Each value is the result of `self.get(handle)'.
        """
        # A None sentinel replaces the former `[]' default so that no list object is shared between calls.
        if handles is None:
            handles = []
        return {handle: self.get(handle) for handle in handles}

    def get(self, handle=0):
        """
        Called whenever an observation has to be computed for the `env' environment, possibly
        for each agent independently (agent id `handle').

        Parameters
        ----------
        handle : int (optional)
            Handle of the agent for which to compute the observation vector.

        Returns
        -------
        object
            An observation structure, specific to the corresponding environment.

        Raises
        ------
        NotImplementedError
            Always, in this base class; derived classes must override it.
        """
        raise NotImplementedError()

    def _get_one_hot_for_agent_direction(self, agent):
        """
        Returns the agent's direction as a one-hot encoding.

        Parameters
        ----------
        agent : object
            Object whose `direction' attribute is used as an index into a length-4 vector.

        Returns
        -------
        numpy.ndarray
            Array of 4 zeros with a 1 at index `agent.direction'.
        """
        direction = np.zeros(4)
        direction[agent.direction] = 1
        return direction