agent_utils.py 3.18 KB
Newer Older
1
from itertools import starmap
2

3
import numpy as np
4
5
6
from attr import attrs, attrib


hagrid67's avatar
hagrid67 committed
7
8
@attrs
class EnvAgentStatic(object):
    """ EnvAgentStatic - Stores initial position, direction and target.

        This is like static data for the environment - it's where an agent starts,
        rather than where it is at the moment.
        The target should also be stored here.

        Attributes:
            position: initial position of the agent.
            direction: initial direction of the agent.
            target: target of the agent.
            moving: flag stored with the agent, False by default.
            speed_data: dict with the keys 'position_fraction', 'speed' and
                'transition_action_on_cellexit' (see the comment on the attribute).
    """
    position = attrib()
    direction = attrib()
    target = attrib()
    moving = attrib(default=False)
    # speed_data: speed is added to position_fraction on each moving step, until position_fraction>=1.0,
    # after which 'transition_action_on_cellexit' is executed (equivalent to executing that action in the previous
    # cell if speed=1, as default)
    # The default is produced by a factory so that every instance gets its own dict; a plain
    # `default=dict(...)` would be created once and shared (and mutated) by all instances.
    speed_data = attrib(
        factory=lambda: {'position_fraction': 0.0, 'speed': 1.0, 'transition_action_on_cellexit': 0})

    @classmethod
    def from_lists(cls, positions, directions, targets, speeds=None):
        """ Create a list of EnvAgentStatics from lists of positions, directions and targets

        Args:
            positions: sequence of initial positions, one per agent.
            directions: sequence of initial directions, one per agent.
            targets: sequence of targets, one per agent.
            speeds: optional sequence of speeds indexed like `positions`;
                when None every agent gets speed 1.0.

        Returns:
            A list of EnvAgentStatic objects (always EnvAgentStatic, even when
            called through a subclass), each with moving=False and a fresh
            speed_data dict.

        Raises:
            IndexError: if `speeds` is given but shorter than `positions`.
        """
        n_agents = len(positions)
        # Indexing speeds by agent number (rather than zipping) keeps the IndexError
        # when fewer speeds than positions are supplied.
        speed_datas = [{'position_fraction': 0.0,
                        'speed': speeds[i] if speeds is not None else 1.0,
                        'transition_action_on_cellexit': 0}
                       for i in range(n_agents)]
        return list(starmap(EnvAgentStatic,
                            zip(positions, directions, targets, [False] * n_agents, speed_datas)))

    def to_list(self):
        """ Return the agent's fields as a plain list.

        Returns:
            [position, direction, target, moving, speed_data] where numpy-array
            positions/targets are converted with tolist(), direction and moving
            are cast to int, and speed_data is the agent's own dict (not a copy).
        """
        # I can't find an expression which works on both tuples, lists and ndarrays
        # which converts them all to a list of native python ints.
        lPos = self.position
        if isinstance(lPos, np.ndarray):
            lPos = lPos.tolist()

        lTarget = self.target
        if isinstance(lTarget, np.ndarray):
            lTarget = lTarget.tolist()

        return [lPos, int(self.direction), lTarget, int(self.moving), self.speed_data]
maljx's avatar
maljx committed
47

hagrid67's avatar
hagrid67 committed
48

49
@attrs
class EnvAgent(EnvAgentStatic):
    """ EnvAgent - replace separate agent_* lists with a single list
        of agent objects.  The EnvAgent represents the environment's view
        of the dynamic agent state.
        We are duplicating target in the EnvAgent, which seems simpler than
        forcing the env to refer to it in the EnvAgentStatic

        Attributes (in addition to those of EnvAgentStatic):
            handle: handle of the agent, None by default.
            old_direction: None by default.
            old_position: None by default.
    """
    handle = attrib(default=None)
    old_direction = attrib(default=None)
    old_position = attrib(default=None)

    def to_list(self):
        """ Return the agent's fields as a plain list.

        Returns:
            [position, direction, target, handle, old_direction, old_position,
            moving, speed_data], with the values passed through unconverted.
        """
        return [
            self.position, self.direction, self.target, self.handle,
            self.old_direction, self.old_position, self.moving, self.speed_data]

    @classmethod
    def from_static(cls, oStatic):
        """ Create an EnvAgent from the EnvAgentStatic,
        copying all the fields, and adding handle with the default 0.

        The field values are passed on by reference (no deep copy), so mutable
        values such as the speed_data dict are shared with `oStatic`.

        Args:
            oStatic: the EnvAgentStatic whose fields are used.

        Returns:
            A new EnvAgent with handle=0.
        """
        # `**` passes the fields as keyword arguments; a single `*` would unpack the
        # dict's keys and hand the attribute *names* to the constructor as values.
        return EnvAgent(**oStatic.__dict__, handle=0)

    @classmethod
    def list_from_static(cls, lEnvAgentStatic, handles=None):
        """ Create a list of EnvAgents from a list of EnvAgentStatics,
        copying all the fields, and adding a handle to each.

        Args:
            lEnvAgentStatic: list of EnvAgentStatic objects.
            handles: optional iterable of handles, paired with the static agents
                in order; when None the handles are 0..len(lEnvAgentStatic)-1.

        Returns:
            A list of EnvAgent objects, one per (handle, static agent) pair.
        """
        if handles is None:
            handles = range(len(lEnvAgentStatic))

        return [EnvAgent(**oEAS.__dict__, handle=handle)
                for handle, oEAS in zip(handles, lEnvAgentStatic)]