agent_utils.py 9.03 KB
Newer Older
1
from flatland.envs.rail_trainrun_data_structures import Waypoint
2
3
import numpy as np

u214892's avatar
u214892 committed
4
from enum import IntEnum
5
from flatland.envs.step_utils.states import TrainState
6
from itertools import starmap
7
from typing import Tuple, Optional, NamedTuple, List
8

9
from attr import attr, attrs, attrib, Factory
10

11
from flatland.core.grid.grid4 import Grid4TransitionsEnum
12
from flatland.envs.schedule_utils import Schedule
13

14
15
from flatland.envs.step_utils.action_saver import ActionSaver
from flatland.envs.step_utils.speed_counter import SpeedCounter
16
from flatland.envs.step_utils.state_machine import TrainStateMachine
Dipam Chakraborty's avatar
Dipam Chakraborty committed
17
from flatland.envs.step_utils.malfunction_handler import MalfunctionHandler
u214892's avatar
u214892 committed
18

19
20
21
22
23
class Agent(NamedTuple):
    """Immutable snapshot of one agent's fields.

    The field order is part of the interface: instances can be built
    positionally, so fields must not be reordered.
    """

    initial_position: Tuple[int, int]
    initial_direction: Grid4TransitionsEnum
    direction: Grid4TransitionsEnum
    target: Tuple[int, int]
    moving: bool
    # Schedule properties
    earliest_departure: int
    latest_arrival: int
    # Plain-dict bookkeeping for speed and malfunctions
    speed_data: dict
    malfunction_data: dict
    handle: int
    position: Tuple[int, int]
    arrival_time: int
    # Previous direction / position
    old_direction: Grid4TransitionsEnum
    old_position: Tuple[int, int]
    # Helper objects driving the step logic
    speed_counter: SpeedCounter
    action_saver: ActionSaver
    state: TrainState
    state_machine: TrainStateMachine
    malfunction_handler: MalfunctionHandler
39
40


hagrid67's avatar
hagrid67 committed
41
@attrs
class EnvAgent:
    """Mutable state of a single agent (train) in the environment.

    The attributes between the two ``INIT ... IN _from_line()`` markers are
    filled positionally by :meth:`from_line`, so their declaration order must
    stay in sync with the ``zip(...)`` in that method.
    """

    # INIT FROM HERE IN _from_line()
    initial_position = attrib(type=Tuple[int, int])
    initial_direction = attrib(type=Grid4TransitionsEnum)
    direction = attrib(type=Grid4TransitionsEnum)
    target = attrib(type=Tuple[int, int])
    moving = attrib(default=False, type=bool)

    # NEW : EnvAgent - Schedule properties
    earliest_departure = attrib(default=None, type=int)  # default None during _from_line()
    latest_arrival = attrib(default=None, type=int)  # default None during _from_line()

    # speed_data: speed is added to position_fraction on each moving step, until position_fraction>=1.0,
    # after which 'transition_action_on_cellexit' is executed (equivalent to executing that action in the previous
    # cell if speed=1, as default)
    # N.B. we need to use factory since default arguments are not recreated on each call!
    speed_data = attrib(
        default=Factory(lambda: dict({'position_fraction': 0.0, 'speed': 1.0, 'transition_action_on_cellexit': 0})))

    # if broken>0, the agent's actions are ignored for 'broken' steps
    # number of time the agent had to stop, since the last time it broke down
    malfunction_data = attrib(
        default=Factory(
            lambda: dict({'malfunction': 0, 'malfunction_rate': 0, 'next_malfunction': 0, 'nr_malfunctions': 0,
                          'moving_before_malfunction': False})))

    handle = attrib(default=None)
    # INIT TILL HERE IN _from_line()

    # Env step facelift
    # speed_counter is also filled positionally by from_line() (it is the
    # attribute declared right after ``handle``); it stays None otherwise.
    speed_counter = attrib(default=None, type=SpeedCounter)
    action_saver = attrib(default=Factory(lambda: ActionSaver()), type=ActionSaver)
    state_machine = attrib(default=Factory(lambda: TrainStateMachine(initial_state=TrainState.WAITING)),
                           type=TrainStateMachine)
    malfunction_handler = attrib(default=Factory(lambda: MalfunctionHandler()), type=MalfunctionHandler)

    state = attrib(default=TrainState.WAITING, type=TrainState)

    position = attrib(default=None, type=Optional[Tuple[int, int]])

    # NEW : EnvAgent Reward Handling
    arrival_time = attrib(default=None, type=int)

    # used in rendering
    old_direction = attrib(default=None)
    old_position = attrib(default=None)

    def reset(self):
        """
        Resets the agents to their initial values of the episode. Called after ScheduleTime generation.
        """
        self.position = None
        # TODO: set direction to None: https://gitlab.aicrowd.com/flatland/flatland/issues/280
        self.direction = self.initial_direction
        self.old_position = None
        self.old_direction = None
        self.moving = False
        # Back to the declared default so a value from a previous episode
        # does not leak into the next one.
        self.arrival_time = None

        # Reset agent values for speed
        self.speed_data['position_fraction'] = 0.
        self.speed_data['transition_action_on_cellexit'] = 0.

        # Reset agent malfunction values
        self.malfunction_data['malfunction'] = 0
        self.malfunction_data['nr_malfunctions'] = 0
        self.malfunction_data['moving_before_malfunction'] = False

        self.action_saver.clear_saved_action()
        # speed_counter defaults to None (e.g. agents built by
        # load_legacy_static_agent), so only reset it when present.
        if self.speed_counter is not None:
            self.speed_counter.reset_counter()
        self.state_machine.reset()

    def to_agent(self) -> Agent:
        """Return an ``Agent`` named tuple holding this agent's current field values.

        The dicts and helper objects are passed by reference, not copied.
        """
        return Agent(initial_position=self.initial_position,
                     initial_direction=self.initial_direction,
                     direction=self.direction,
                     target=self.target,
                     moving=self.moving,
                     earliest_departure=self.earliest_departure,
                     latest_arrival=self.latest_arrival,
                     speed_data=self.speed_data,
                     malfunction_data=self.malfunction_data,
                     handle=self.handle,
                     state=self.state,
                     position=self.position,
                     # Agent declares arrival_time without a default, so it
                     # must be supplied or construction raises TypeError.
                     arrival_time=self.arrival_time,
                     old_direction=self.old_direction,
                     old_position=self.old_position,
                     speed_counter=self.speed_counter,
                     action_saver=self.action_saver,
                     state_machine=self.state_machine,
                     malfunction_handler=self.malfunction_handler)

    @classmethod
    def from_line(cls, line: "Line"):
        """ Create a list of EnvAgent from lists of positions, directions and targets

        :param line: object exposing ``agent_positions``, ``agent_directions``,
            ``agent_targets``, ``agent_speeds`` (may be None -> speed 1.0) and
            ``agent_malfunction_rates`` (may be None -> rate 0.).
        :return: one agent per entry of ``line.agent_positions``; the handle
            is the index of that entry.
        """
        # NOTE(review): the annotation is a string because ``Line`` is not
        # imported in this module - confirm where it lives and import it.
        num_agents = len(line.agent_positions)

        if line.agent_speeds is not None:
            speeds = [line.agent_speeds[i] for i in range(num_agents)]
        else:
            speeds = [1.0] * num_agents

        speed_datas = [{'position_fraction': 0.0,
                        'speed': speed,
                        'transition_action_on_cellexit': 0}
                       for speed in speeds]
        speed_counters = [SpeedCounter(speed=speed) for speed in speeds]

        malfunction_datas = [{'malfunction': 0,
                              'malfunction_rate': line.agent_malfunction_rates[
                                  i] if line.agent_malfunction_rates is not None else 0.,
                              'next_malfunction': 0,
                              'nr_malfunctions': 0}
                             for i in range(num_agents)]

        # The columns below are matched positionally against the attribute
        # declaration order of the class - keep both in sync.
        # TODO : Dipam - Really want to change this way of loading agents
        return list(starmap(cls, zip(line.agent_positions,  # initial_position
                                     line.agent_directions,  # initial_direction
                                     line.agent_directions,  # direction
                                     line.agent_targets,  # target
                                     [False] * num_agents,  # moving
                                     [None] * num_agents,  # earliest_departure
                                     [None] * num_agents,  # latest_arrival
                                     speed_datas,
                                     malfunction_datas,
                                     range(num_agents),  # handle
                                     speed_counters,
                                     )))

    @classmethod
    def load_legacy_static_agent(cls, static_agents_data: Tuple):
        """Build agents from legacy tuple records.

        Each record is indexed as ``(initial_position, direction, target,
        moving, speed_data, malfunction_data)``. Records with fewer than six
        entries only use the first three; the rest gets default values.

        :param static_agents_data: iterable of such records.
        :return: list of agents; the handle is the record's index.
        """
        agents = []
        for i, static_agent in enumerate(static_agents_data):
            if len(static_agent) >= 6:
                agent = cls(initial_position=static_agent[0], initial_direction=static_agent[1],
                            direction=static_agent[1], target=static_agent[2], moving=static_agent[3],
                            speed_data=static_agent[4], malfunction_data=static_agent[5], handle=i)
            else:
                agent = cls(initial_position=static_agent[0], initial_direction=static_agent[1],
                            direction=static_agent[1], target=static_agent[2],
                            moving=False,
                            # Same key spelling as the attribute default and reset().
                            speed_data={"speed": 1., "position_fraction": 0., "transition_action_on_cellexit": 0.},
                            malfunction_data={
                                'malfunction': 0,
                                'nr_malfunctions': 0,
                                'moving_before_malfunction': False
                            },
                            handle=i)
            agents.append(agent)
        return agents