agent_utils.py 4.18 KB
Newer Older
1
from itertools import starmap
2

3
import numpy as np
4
5
6
from attr import attrs, attrib


7
8
@attrs
class EnvDescription:
    """ EnvDescription - a description of a randomly generated env.

        It is built around the rail_generator together with statistics such
        as the grid size (height, width) and the number of agents, mirroring
        the parameters given to the RailEnv constructor.
        Not currently used.
    """
    n_agents = attrib()
    height = attrib()
    width = attrib()
    rail_generator = attrib()
    # Open question: not sure if obs_builder should be closer to the agent
    # than to the env.
    obs_builder = attrib()
19

hagrid67's avatar
hagrid67 committed
20
21
22

@attrs
class EnvAgentStatic(object):
    """ EnvAgentStatic - Stores initial position, direction and target.

        This is like static data for the environment - it's where an agent
        starts, rather than where it is at the moment.
        The target is also stored here.

        Attributes:
            position: initial position of the agent.
            direction: initial direction of the agent.
            target: target of the agent.
            moving: moving flag, False by default.
            speed_data: dict with the keys 'position_fraction', 'speed' and
                'transition_action_on_cellexit' (see the comment on the
                attribute below).  Every instance gets its own dict.
    """
    position = attrib()
    direction = attrib()
    target = attrib()
    moving = attrib(default=False)
    # speed_data: speed is added to position_fraction on each moving step, until position_fraction>=1.0,
    # after which 'transition_action_on_cellexit' is executed (equivalent to executing that action in the previous
    # cell if speed=1, as default)
    speed_data = attrib()

    @speed_data.default
    def _default_speed_data(self):
        """ Build a fresh default speed_data dict.

            A new dict is created on every call so that instances never share
            (and accidentally mutate) one common default object.
        """
        return {'position_fraction': 0.0, 'speed': 1.0, 'transition_action_on_cellexit': 0}

    # NOTE(review): `@attrs` without `init=False` generates its own __init__
    # from the attributes above, which appears to take precedence over this
    # hand-written one - confirm against the attrs version in use.  Both are
    # kept consistent: same parameter names and order, and a fresh speed_data
    # dict per instance when none is given.
    def __init__(self,
                 position,
                 direction,
                 target,
                 moving=False,
                 speed_data=None):
        """ Store the initial agent state.

            speed_data defaults to None (instead of a dict literal) so that
            the default dict is not shared between calls.
        """
        self.position = position
        self.direction = direction
        self.target = target
        self.moving = moving
        self.speed_data = self._default_speed_data() if speed_data is None else speed_data

    @classmethod
    def from_lists(cls, positions, directions, targets):
        """ Create a list of EnvAgentStatics from lists of positions, directions and targets.

            The three inputs are iterated in lockstep (stopping at the
            shortest).  Every agent is created with moving=False and its own
            default speed_data dict.
        """
        return [
            EnvAgentStatic(position, direction, target,
                           moving=False,
                           speed_data={'position_fraction': 0.0,
                                       'speed': 1.0,
                                       'transition_action_on_cellexit': 0})
            for position, direction, target in zip(positions, directions, targets)
        ]

    def to_list(self):
        """ Return [position, direction, target, moving, speed_data] as a list.

            position and target are converted with tolist() when they are
            numpy arrays; tuples and lists are passed through unchanged.
            direction and moving are converted to int.  speed_data is the
            agent's own dict, not a copy.
        """
        # There is no single expression which works on tuples, lists and
        # ndarrays alike and converts them all to lists of native python
        # ints, so only ndarrays (including subclasses) are converted.
        position = self.position
        if isinstance(position, np.ndarray):
            position = position.tolist()

        target = self.target
        if isinstance(target, np.ndarray):
            target = target.tolist()

        return [position, int(self.direction), target, int(self.moving), self.speed_data]
maljx's avatar
maljx committed
71

hagrid67's avatar
hagrid67 committed
72

73
@attrs
class EnvAgent(EnvAgentStatic):
    """ EnvAgent - replace separate agent_* lists with a single list
        of agent objects.  The EnvAgent represents the environment's view
        of the dynamic agent state.
        We are duplicating target in the EnvAgent, which seems simpler than
        forcing the env to refer to it in the EnvAgentStatic.

        Attributes (in addition to those of EnvAgentStatic):
            handle: handle of the agent, None by default.
            old_direction: None by default.
            old_position: None by default.
    """
    handle = attrib(default=None)
    old_direction = attrib(default=None)
    old_position = attrib(default=None)

    # NOTE(review): `@attrs` without `init=False` generates its own __init__
    # from the attributes, which appears to take precedence over this
    # hand-written one - confirm against the attrs version in use.  The
    # parameters below are a backward-compatible superset of the original
    # signature, so the keyword arguments used by from_static and
    # list_from_static are accepted either way.
    def __init__(self, position, direction, target, handle=None, old_direction=None,
                 old_position=None, moving=False, speed_data=None):
        """ Store the static agent data plus handle, old_direction and old_position. """
        # Only forward speed_data when one was given, so that the parent
        # class applies its own default otherwise.
        extra = {} if speed_data is None else {'speed_data': speed_data}
        super(EnvAgent, self).__init__(position, direction, target, moving, **extra)
        self.handle = handle
        self.old_direction = old_direction
        self.old_position = old_position

    def to_list(self):
        """ Return the agent's fields as a list, in the order
            [position, direction, target, handle, old_direction,
            old_position, moving, speed_data].  No type conversion is done.
        """
        return [
            self.position, self.direction, self.target, self.handle,
            self.old_direction, self.old_position, self.moving, self.speed_data]

    @classmethod
    def from_static(cls, oStatic):
        """ Create an EnvAgent from the EnvAgentStatic,
            copying all the fields, and adding handle with the default 0.
        """
        # The instance dict must be unpacked with ** (field name -> value).
        # A single * would pass the field *names* as positional values.
        # NOTE(review): the field values are passed on as-is, so the new
        # agent refers to the same speed_data dict as oStatic - confirm
        # whether a copy is wanted.
        return cls(**oStatic.__dict__, handle=0)

    @classmethod
    def list_from_static(cls, lEnvAgentStatic, handles=None):
        """ Create a list of EnvAgents from a list of EnvAgentStatics,
            copying all the fields and assigning one handle to each agent.

            If handles is None, the handles are 0 .. len(lEnvAgentStatic)-1.
            Otherwise handles and lEnvAgentStatic are paired up in order
            (stopping at the shorter of the two).
        """
        if handles is None:
            handles = range(len(lEnvAgentStatic))

        return [cls(**oEAS.__dict__, handle=handle)
                for handle, oEAS in zip(handles, lEnvAgentStatic)]