agent_utils.py 11 KB
Newer Older
1
2
from flatland.envs.rail_trainrun_data_structures import Waypoint
import numpy as np
3
import warnings
4
5

from typing import Tuple, Optional, NamedTuple, List
6

7
from attr import attr, attrs, attrib, Factory
8

9
from flatland.core.grid.grid4 import Grid4TransitionsEnum
10
from flatland.envs.timetable_utils import Line
11

12
13
from flatland.envs.step_utils.action_saver import ActionSaver
from flatland.envs.step_utils.speed_counter import SpeedCounter
14
from flatland.envs.step_utils.state_machine import TrainStateMachine
15
from flatland.envs.step_utils.states import TrainState
Dipam Chakraborty's avatar
Dipam Chakraborty committed
16
from flatland.envs.step_utils.malfunction_handler import MalfunctionHandler
u214892's avatar
u214892 committed
17

18
19
20
21
22
class Agent(NamedTuple):
    """Plain named-tuple snapshot of an ``EnvAgent``, as produced by ``EnvAgent.to_agent``.

    The field order below is part of the tuple's positional interface
    (construction with ``Agent(*values)`` and unpacking) and must not change.
    """
    initial_position: Tuple[int, int]
    initial_direction: Grid4TransitionsEnum
    direction: Grid4TransitionsEnum
    target: Tuple[int, int]
    moving: bool
    earliest_departure: int
    latest_arrival: int
    malfunction_data: dict
    handle: int
    position: Tuple[int, int]
    arrival_time: int
    old_direction: Grid4TransitionsEnum
    old_position: Tuple[int, int]
    speed_counter: SpeedCounter
    action_saver: ActionSaver
    state_machine: TrainStateMachine
    malfunction_handler: MalfunctionHandler


def load_env_agent(agent_tuple: Agent):
    """Rebuild an ``EnvAgent`` from an ``Agent`` named tuple.

    This mirrors ``EnvAgent.to_agent``: every field of the tuple is passed to
    the ``EnvAgent`` constructor. Object-valued fields (the malfunction dict,
    speed counter, action saver, state machine, malfunction handler) are passed
    by reference, not copied.

    :param agent_tuple: an ``Agent`` tuple, e.g. one returned by ``EnvAgent.to_agent``.
    :return: a new ``EnvAgent`` carrying the tuple's field values.
    """
    return EnvAgent(
        initial_position=agent_tuple.initial_position,
        initial_direction=agent_tuple.initial_direction,
        direction=agent_tuple.direction,
        target=agent_tuple.target,
        moving=agent_tuple.moving,
        earliest_departure=agent_tuple.earliest_departure,
        latest_arrival=agent_tuple.latest_arrival,
        # Previously omitted: the stored malfunction data (incl. 'malfunction_rate')
        # was silently replaced by the EnvAgent default dict.
        malfunction_data=agent_tuple.malfunction_data,
        handle=agent_tuple.handle,
        position=agent_tuple.position,
        arrival_time=agent_tuple.arrival_time,
        old_direction=agent_tuple.old_direction,
        old_position=agent_tuple.old_position,
        speed_counter=agent_tuple.speed_counter,
        action_saver=agent_tuple.action_saver,
        state_machine=agent_tuple.state_machine,
        malfunction_handler=agent_tuple.malfunction_handler,
    )


@attrs
class EnvAgent:
    """Mutable per-agent (train) state used by the environment.

    Agents can be built with :meth:`from_line` (or :meth:`load_legacy_static_agent`
    for legacy static-agent tuples), returned to their episode-start values with
    :meth:`reset`, and converted to the ``Agent`` named tuple with :meth:`to_agent`.

    NOTE: ``@attrs`` generates ``__init__`` from the attributes below in definition
    order, so that order is part of the positional constructor interface.
    """
    # INIT FROM HERE IN from_line()
    initial_position = attrib(type=Tuple[int, int])
    initial_direction = attrib(type=Grid4TransitionsEnum)
    direction = attrib(type=Grid4TransitionsEnum)
    target = attrib(type=Tuple[int, int])
    moving = attrib(default=False, type=bool)

    # NEW : EnvAgent - Schedule properties
    earliest_departure = attrib(default=None, type=int)  # default None during from_line()
    latest_arrival = attrib(default=None, type=int)  # default None during from_line()

    # Malfunction bookkeeping kept as a plain dict. reset() zeroes 'malfunction'
    # and 'nr_malfunctions' and clears 'moving_before_malfunction'.
    malfunction_data = attrib(
        default=Factory(
            lambda: {'malfunction': 0, 'malfunction_rate': 0, 'next_malfunction': 0, 'nr_malfunctions': 0,
                     'moving_before_malfunction': False}))

    handle = attrib(default=None)
    # INIT TILL HERE IN from_line()

    # Env step facelift: helper objects owning speed, saved-action, state and malfunction logic
    speed_counter = attrib(default=Factory(lambda: SpeedCounter(1.0)), type=SpeedCounter)
    action_saver = attrib(default=Factory(ActionSaver), type=ActionSaver)
    state_machine = attrib(default=Factory(lambda: TrainStateMachine(initial_state=TrainState.WAITING)),
                           type=TrainStateMachine)
    malfunction_handler = attrib(default=Factory(MalfunctionHandler), type=MalfunctionHandler)

    # Current cell; reset() sets it back to None.
    position = attrib(default=None, type=Optional[Tuple[int, int]])

    # NEW : EnvAgent Reward Handling
    arrival_time = attrib(default=None, type=int)

    # used in rendering
    old_direction = attrib(default=None)
    old_position = attrib(default=None)

    def reset(self):
        """
        Resets the agents to their initial values of the episode. Called after ScheduleTime generation.

        Clears position/direction history and the arrival time, zeroes the
        per-episode malfunction counters and resets the action saver, speed
        counter and state machine. ``malfunction_data['malfunction_rate']`` is
        not modified.
        """
        self.position = None
        # TODO: set direction to None: https://gitlab.aicrowd.com/flatland/flatland/issues/280
        self.direction = self.initial_direction
        self.old_position = None
        self.old_direction = None
        self.moving = False
        # The arrival step belongs to the previous episode; without clearing it a
        # reused agent would start the new episode with a stale arrival time.
        self.arrival_time = None

        # Reset agent malfunction values
        self.malfunction_data['malfunction'] = 0
        self.malfunction_data['nr_malfunctions'] = 0
        self.malfunction_data['moving_before_malfunction'] = False

        self.action_saver.clear_saved_action()
        self.speed_counter.reset_counter()
        self.state_machine.reset()

    def to_agent(self) -> Agent:
        """Return an ``Agent`` named tuple holding this agent's current field values.

        Object-valued fields are shared with this agent, not copied.
        """
        return Agent(initial_position=self.initial_position,
                     initial_direction=self.initial_direction,
                     direction=self.direction,
                     target=self.target,
                     moving=self.moving,
                     earliest_departure=self.earliest_departure,
                     latest_arrival=self.latest_arrival,
                     malfunction_data=self.malfunction_data,
                     handle=self.handle,
                     position=self.position,
                     old_direction=self.old_direction,
                     old_position=self.old_position,
                     speed_counter=self.speed_counter,
                     action_saver=self.action_saver,
                     arrival_time=self.arrival_time,
                     state_machine=self.state_machine,
                     malfunction_handler=self.malfunction_handler)

    def get_shortest_path(self, distance_map) -> Optional[List[Waypoint]]:
        """Return this agent's shortest path from ``get_shortest_paths``.

        The result may be ``None``; ``get_travel_time_on_shortest_path`` treats
        that case as distance 0.
        """
        from flatland.envs.rail_env_shortest_paths import get_shortest_paths  # Circular dep fix
        return get_shortest_paths(distance_map=distance_map, agent_handle=self.handle)[self.handle]

    def get_travel_time_on_shortest_path(self, distance_map) -> int:
        """Number of steps to traverse the shortest path at this agent's speed.

        Computed as ``ceil(len(path) / speed)``, with ``len(path)`` taken as 0
        when no path is returned.
        """
        shortest_path = self.get_shortest_path(distance_map)
        distance = len(shortest_path) if shortest_path is not None else 0
        speed = self.speed_counter.speed
        return int(np.ceil(distance / speed))

    def get_time_remaining_until_latest_arrival(self, elapsed_steps: int) -> int:
        """Return ``latest_arrival - elapsed_steps``.

        Requires ``latest_arrival`` to be set (it is ``None`` right after
        ``from_line()``, which would raise ``TypeError`` here).
        """
        return self.latest_arrival - elapsed_steps

    def get_current_delay(self, elapsed_steps: int, distance_map) -> int:
        '''
        +ve if arrival time is projected before latest arrival
        -ve if arrival time is projected after latest arrival
        '''
        return self.get_time_remaining_until_latest_arrival(elapsed_steps) - \
               self.get_travel_time_on_shortest_path(distance_map)

    @classmethod
    def from_line(cls, line: Line):
        """ Create a list of EnvAgent from lists of positions, directions and targets

        :param line: ``Line`` providing ``agent_positions``, ``agent_directions`` and
            ``agent_targets``; ``agent_speeds`` / ``agent_malfunction_rates`` may be
            ``None``, in which case speed 1.0 and malfunction rate 0.0 are used.
        :return: list of agents, where agent ``i`` gets handle ``i``.
        """
        agent_list = []
        for i_agent in range(len(line.agent_positions)):
            speed = line.agent_speeds[i_agent] if line.agent_speeds is not None else 1.0

            if line.agent_malfunction_rates is not None:
                malfunction_rate = line.agent_malfunction_rates[i_agent]
            else:
                malfunction_rate = 0.

            # Same key set as the attrs default for malfunction_data
            malfunction_data = {'malfunction': 0,
                                'malfunction_rate': malfunction_rate,
                                'next_malfunction': 0,
                                'nr_malfunctions': 0,
                                'moving_before_malfunction': False}
            agent_list.append(cls(initial_position=line.agent_positions[i_agent],
                                  initial_direction=line.agent_directions[i_agent],
                                  direction=line.agent_directions[i_agent],
                                  target=line.agent_targets[i_agent],
                                  moving=False,
                                  earliest_departure=None,
                                  latest_arrival=None,
                                  malfunction_data=malfunction_data,
                                  handle=i_agent,
                                  speed_counter=SpeedCounter(speed=speed)))

        return agent_list

    @classmethod
    def load_legacy_static_agent(cls, static_agents_data: Tuple):
        """Create agents from legacy static-agent tuples.

        Each entry is indexed as ``(initial_position, initial_direction, target, ...)``.
        Entries with at least 6 elements also supply ``moving`` (index 3), a mapping
        with a ``'speed'`` key (index 4) and the malfunction dict (index 5); shorter
        entries get ``moving=False``, speed 1.0 and a fresh malfunction dict.
        The entry's index becomes the agent handle.
        """
        agents = []
        for i, static_agent in enumerate(static_agents_data):
            if len(static_agent) >= 6:
                moving = static_agent[3]
                speed = static_agent[4]['speed']
                malfunction_data = static_agent[5]
            else:
                moving = False
                speed = 1.0
                # Same key set as the attrs default for malfunction_data
                malfunction_data = {'malfunction': 0,
                                    'malfunction_rate': 0,
                                    'next_malfunction': 0,
                                    'nr_malfunctions': 0,
                                    'moving_before_malfunction': False}
            agents.append(cls(initial_position=static_agent[0],
                              initial_direction=static_agent[1],
                              direction=static_agent[1],
                              target=static_agent[2],
                              moving=moving,
                              speed_counter=SpeedCounter(speed),
                              malfunction_data=malfunction_data,
                              handle=i))
        return agents

    def __str__(self):
        # Multi-line summary; the backslash-newlines continue the string literal,
        # so each following line's leading spaces are part of the output.
        return f"\n \
                 handle(agent index): {self.handle} \n \
                 initial_position: {self.initial_position}   initial_direction: {self.initial_direction} \n \
                 position: {self.position}  direction: {self.direction}  target: {self.target} \n \
                 old_position: {self.old_position} old_direction {self.old_direction} \n \
                 earliest_departure: {self.earliest_departure}  latest_arrival: {self.latest_arrival} \n \
                 state: {str(self.state)} \n \
                 malfunction_handler: {self.malfunction_handler} \n \
                 action_saver: {self.action_saver} \n \
                 speed_counter: {self.speed_counter}"

    @property
    def state(self):
        """Current state, read from ``state_machine.state``."""
        return self.state_machine.state

    @state.setter
    def state(self, state):
        self._set_state(state)

    def _set_state(self, state):
        """Force the state machine into ``state`` (emits a warning; use sparingly)."""
        warnings.warn("Not recommended to set the state with this function unless completely required")
        self.state_machine.set_state(state)