agent_utils.py 4.9 KB
Newer Older
u214892's avatar
u214892 committed
1
from enum import IntEnum
2
from itertools import starmap
3
from typing import Tuple
4

5
import numpy as np
6
from attr import attrs, attrib, Factory
7

8
9
from flatland.core.grid.grid4 import Grid4TransitionsEnum

10

u214892's avatar
u214892 committed
11
12
13
14
15
16
class RailAgentStatus(IntEnum):
    """Integer-valued status codes of a rail agent.

    Being an ``IntEnum``, members compare equal to their integer values.

    NOTE(review): the precise meaning of each member is only implied by its
    name here; confirm against the environment code that sets ``status``.
    """
    # Default value of ``EnvAgentStatic.status``.
    READY_TO_DEPART = 0
    ACTIVE = 1
    DONE = 2


hagrid67's avatar
hagrid67 committed
17
18
@attrs
class EnvAgentStatic(object):
    """ EnvAgentStatic - Stores initial position, direction and target.

        This is like static data for the environment - it's where an agent starts,
        rather than where it is at the moment.
        The target should also be stored here.

        The attribute order below is also the positional order of the generated
        constructor: position, direction, target, moving, speed_data,
        malfunction_data, status.
    """
    position = attrib(type=Tuple[int, int])
    direction = attrib(type=Grid4TransitionsEnum)
    target = attrib(type=Tuple[int, int])
    moving = attrib(default=False, type=bool)
    # position = attrib(default=None,type=Optional[Tuple[int, int]])

    # speed_data: speed is added to position_fraction on each moving step, until position_fraction>=1.0,
    # after which 'transition_action_on_cellexit' is executed (equivalent to executing that action in the previous
    # cell if speed=1, as default)
    # N.B. we need to use factory since default arguments are not recreated on each call!
    speed_data = attrib(
        default=Factory(lambda: dict({'position_fraction': 0.0, 'speed': 1.0, 'transition_action_on_cellexit': 0})))

    # if broken>0, the agent's actions are ignored for 'broken' steps
    # number of time the agent had to stop, since the last time it broke down
    # NOTE(review): 'broken' above presumably refers to the 'malfunction' entry - confirm.
    # A factory is used here for the same reason as for speed_data: every instance gets its own dict.
    malfunction_data = attrib(
        default=Factory(
            lambda: dict({'malfunction': 0, 'malfunction_rate': 0, 'next_malfunction': 0, 'nr_malfunctions': 0,
                          'moving_before_malfunction': False})))

    status = attrib(default=RailAgentStatus.READY_TO_DEPART, type=RailAgentStatus)

    @classmethod
    def from_lists(cls, positions, directions, targets, speeds=None, malfunction_rates=None):
        """ Create a list of EnvAgentStatics from lists of positions, directions and targets

        Parameters
        ----------
        positions, directions, targets
            Parallel sequences with one entry per agent.
        speeds
            Optional sequence of per-agent speeds, indexed like ``positions``;
            every agent gets speed 1.0 when omitted.
        malfunction_rates
            Optional sequence of per-agent malfunction rates, indexed like
            ``positions``; every agent gets rate 0. when omitted.

        Returns
        -------
        list of EnvAgentStatic
            One agent per zipped (position, direction, target) triple, each with
            ``moving=False``, its own freshly created ``speed_data`` and
            ``malfunction_data`` dicts, and the default ``status``.

        Raises
        ------
        IndexError
            If ``speeds`` or ``malfunction_rates`` is given but shorter than
            ``positions``.
        """
        n_agents = len(positions)

        speed_datas = [{'position_fraction': 0.0,
                        'speed': speeds[i] if speeds is not None else 1.0,
                        'transition_action_on_cellexit': 0}
                       for i in range(n_agents)]

        # TODO: on initialization, all agents are re-set as non-broken. Perhaps it may be desirable to set
        # some as broken?

        # Carry the same keys as the default factory of `malfunction_data` (including
        # 'moving_before_malfunction'), so agents built here do not lack a key that
        # default-constructed agents have.
        malfunction_datas = [{'malfunction': 0,
                              'malfunction_rate': malfunction_rates[i] if malfunction_rates is not None else 0.,
                              'next_malfunction': 0,
                              'nr_malfunctions': 0,
                              'moving_before_malfunction': False}
                             for i in range(n_agents)]

        return [EnvAgentStatic(position, direction, target, False, speed_data, malfunction_data)
                for position, direction, target, speed_data, malfunction_data
                in zip(positions, directions, targets, speed_datas, malfunction_datas)]

    def to_list(self):
        """ Return the agent's fields as a plain list.

        The result is ``[position, int(direction), target, int(moving),
        speed_data, malfunction_data]``. ``status`` is not included, and the
        two dicts are the agent's own objects, not copies.
        """
        # I can't find an expression which works on both tuples, lists and ndarrays
        # which converts them all to a list of native python ints.
        # Only ndarrays (including subclasses) are converted; tuples and lists pass through unchanged.
        lPos = self.position
        if isinstance(lPos, np.ndarray):
            lPos = lPos.tolist()

        lTarget = self.target
        if isinstance(lTarget, np.ndarray):
            lTarget = lTarget.tolist()

        return [lPos, int(self.direction), lTarget, int(self.moving), self.speed_data, self.malfunction_data]
maljx's avatar
maljx committed
87

hagrid67's avatar
hagrid67 committed
88

89
@attrs
class EnvAgent(EnvAgentStatic):
    """ EnvAgent - replace separate agent_* lists with a single list
        of agent objects.  The EnvAgent represents the environment's view
        of the dynamic agent state.
        We are duplicating target in the EnvAgent, which seems simpler than
        forcing the env to refer to it in the EnvAgentStatic
    """
    handle = attrib(default=None)
    old_direction = attrib(default=None)
    old_position = attrib(default=None)

    def to_list(self):
        """ Return the agent's fields as a plain list.

        The result is ``[position, direction, target, handle, old_direction,
        old_position, moving, speed_data, malfunction_data]``. Values are
        returned as stored, with no type conversion, and ``status`` is not
        included.
        """
        return [
            self.position, self.direction, self.target, self.handle,
            self.old_direction, self.old_position, self.moving, self.speed_data, self.malfunction_data]

    @classmethod
    def from_static(cls, oStatic):
        """ Create an EnvAgent from the EnvAgentStatic,
        copying all the fields, and adding handle with the default 0.

        The fields are passed by keyword from ``oStatic.__dict__``, so the new
        agent references the same ``speed_data`` / ``malfunction_data`` dict
        objects as ``oStatic`` (shallow copy).
        """
        # `**` is required: a single `*` would unpack the dict's keys and feed the
        # attribute *names* to the constructor as positional values.
        return EnvAgent(**oStatic.__dict__, handle=0)

    @classmethod
    def list_from_static(cls, lEnvAgentStatic, handles=None):
        """ Create a list of EnvAgents from a list of EnvAgentStatics,
        copying all the fields and assigning a handle to each.

        If ``handles`` is None, the handles are the indices
        ``0 .. len(lEnvAgentStatic) - 1``. Otherwise the handles are paired with
        the statics via ``zip``, so the result is as long as the shorter of the
        two. As in ``from_static``, the copies are shallow.
        """
        if handles is None:
            handles = range(len(lEnvAgentStatic))

        return [EnvAgent(**oEAS.__dict__, handle=handle)
                for handle, oEAS in zip(handles, lEnvAgentStatic)]