# agent_utils.py
import warnings
from typing import Tuple, Optional, NamedTuple, List

import numpy as np
from attr import attr, attrs, attrib, Factory

from flatland.envs.rail_trainrun_data_structures import Waypoint
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs.timetable_utils import Line

from flatland.envs.step_utils.action_saver import ActionSaver
from flatland.envs.step_utils.speed_counter import SpeedCounter
from flatland.envs.step_utils.state_machine import TrainStateMachine
from flatland.envs.step_utils.states import TrainState
from flatland.envs.step_utils.malfunction_handler import MalfunctionHandler
u214892's avatar
u214892 committed
17

18
19
20
21
22
class Agent(NamedTuple):
    """Immutable snapshot of an ``EnvAgent``.

    The field order is part of the interface: it defines the positional
    layout of the tuple and must not be changed.
    """
    initial_position: Tuple[int, int]
    initial_direction: Grid4TransitionsEnum
    direction: Grid4TransitionsEnum
    target: Tuple[int, int]
    moving: bool
    earliest_departure: int
    latest_arrival: int
    malfunction_data: dict
    handle: int
    position: Tuple[int, int]
    arrival_time: int
    old_direction: Grid4TransitionsEnum
    old_position: Tuple[int, int]
    speed_counter: SpeedCounter
    action_saver: ActionSaver
    state: TrainState
    state_machine: TrainStateMachine
    malfunction_handler: MalfunctionHandler
37
38


hagrid67's avatar
hagrid67 committed
39
@attrs
class EnvAgent:
    """Mutable per-agent record used by the environment.

    The attribute order below is significant: ``attrs`` derives the
    positional ``__init__`` signature from it, so attributes must not be
    reordered.
    """

    # INIT FROM HERE IN _from_line()
    initial_position = attrib(type=Tuple[int, int])
    initial_direction = attrib(type=Grid4TransitionsEnum)
    direction = attrib(type=Grid4TransitionsEnum)
    target = attrib(type=Tuple[int, int])
    moving = attrib(default=False, type=bool)

    # NEW : EnvAgent - Schedule properties
    earliest_departure = attrib(default=None, type=Optional[int])  # default None during _from_line()
    latest_arrival = attrib(default=None, type=Optional[int])  # default None during _from_line()

    # if broken>0, the agent's actions are ignored for 'broken' steps
    # number of time the agent had to stop, since the last time it broke down
    # A Factory is used so that every agent gets its own dict instance.
    malfunction_data = attrib(
        default=Factory(
            lambda: {'malfunction': 0, 'malfunction_rate': 0, 'next_malfunction': 0, 'nr_malfunctions': 0,
                     'moving_before_malfunction': False}))

    handle = attrib(default=None)
    # INIT TILL HERE IN _from_line()

    # Env step facelift
    speed_counter = attrib(default=Factory(lambda: SpeedCounter(1.0)), type=SpeedCounter)
    action_saver = attrib(default=Factory(ActionSaver), type=ActionSaver)
    state_machine = attrib(default=Factory(lambda: TrainStateMachine(initial_state=TrainState.WAITING)),
                           type=TrainStateMachine)
    malfunction_handler = attrib(default=Factory(MalfunctionHandler), type=MalfunctionHandler)

    position = attrib(default=None, type=Optional[Tuple[int, int]])

    # NEW : EnvAgent Reward Handling
    arrival_time = attrib(default=None, type=Optional[int])

    # used in rendering
    old_direction = attrib(default=None)
    old_position = attrib(default=None)

    def reset(self):
        """
        Resets the agents to their initial values of the episode. Called after ScheduleTime generation.
        """
        self.position = None
        # TODO: set direction to None: https://gitlab.aicrowd.com/flatland/flatland/issues/280
        self.direction = self.initial_direction
        self.old_position = None
        self.old_direction = None
        self.moving = False

        # Reset agent malfunction values
        self.malfunction_data['malfunction'] = 0
        self.malfunction_data['nr_malfunctions'] = 0
        self.malfunction_data['moving_before_malfunction'] = False

        self.action_saver.clear_saved_action()
        self.speed_counter.reset_counter()
        self.state_machine.reset()

    def to_agent(self) -> Agent:
        """Return an ``Agent`` named tuple holding the current attribute values.

        The mutable members (``malfunction_data``, ``speed_counter``,
        ``action_saver``, ``state_machine``, ``malfunction_handler``) are
        passed by reference, not copied.
        """
        return Agent(initial_position=self.initial_position,
                     initial_direction=self.initial_direction,
                     direction=self.direction,
                     target=self.target,
                     moving=self.moving,
                     earliest_departure=self.earliest_departure,
                     latest_arrival=self.latest_arrival,
                     malfunction_data=self.malfunction_data,
                     handle=self.handle,
                     state=self.state,
                     position=self.position,
                     # ``arrival_time`` is a required field of ``Agent``; leaving it
                     # out makes the constructor raise TypeError.
                     arrival_time=self.arrival_time,
                     old_direction=self.old_direction,
                     old_position=self.old_position,
                     speed_counter=self.speed_counter,
                     action_saver=self.action_saver,
                     state_machine=self.state_machine,
                     malfunction_handler=self.malfunction_handler)

    def get_shortest_path(self, distance_map) -> Optional[List[Waypoint]]:
        """Return the shortest-path entry for this agent's handle.

        The result is whatever ``get_shortest_paths`` stores under
        ``self.handle``; callers in this class treat ``None`` as "no path".
        """
        from flatland.envs.rail_env_shortest_paths import get_shortest_paths  # Circular dep fix
        return get_shortest_paths(distance_map=distance_map, agent_handle=self.handle)[self.handle]

    def get_travel_time_on_shortest_path(self, distance_map) -> int:
        """Return the shortest-path length divided by the agent's speed, rounded up.

        Returns 0 when no shortest path is available.
        """
        shortest_path = self.get_shortest_path(distance_map)
        distance = len(shortest_path) if shortest_path is not None else 0
        speed = self.speed_counter.speed
        return int(np.ceil(distance / speed))

    def get_time_remaining_until_latest_arrival(self, elapsed_steps: int) -> int:
        """Return ``latest_arrival - elapsed_steps`` (negative once the deadline has passed)."""
        return self.latest_arrival - elapsed_steps

    def get_current_delay(self, elapsed_steps: int, distance_map) -> int:
        """
        +ve if arrival time is projected before latest arrival
        -ve if arrival time is projected after latest arrival
        """
        return self.get_time_remaining_until_latest_arrival(elapsed_steps) - \
               self.get_travel_time_on_shortest_path(distance_map)

    @classmethod
    def from_line(cls, line: Line) -> List["EnvAgent"]:
        """ Create a list of EnvAgent from lists of positions, directions and targets

        One agent is created per entry of ``line.agent_positions``; its handle
        is its index in that list. Speed falls back to 1.0 and the malfunction
        rate to 0. when the corresponding list on ``line`` is ``None``.
        """
        agent_list = []
        for i_agent, initial_position in enumerate(line.agent_positions):
            speed = line.agent_speeds[i_agent] if line.agent_speeds is not None else 1.0

            if line.agent_malfunction_rates is not None:
                malfunction_rate = line.agent_malfunction_rates[i_agent]
            else:
                malfunction_rate = 0.

            malfunction_data = {'malfunction': 0,
                                'malfunction_rate': malfunction_rate,
                                'next_malfunction': 0,
                                'nr_malfunctions': 0
                                }
            # ``cls`` rather than a hard-coded class name so subclasses get
            # instances of their own type.
            agent = cls(initial_position=initial_position,
                        initial_direction=line.agent_directions[i_agent],
                        direction=line.agent_directions[i_agent],
                        target=line.agent_targets[i_agent],
                        moving=False,
                        earliest_departure=None,
                        latest_arrival=None,
                        malfunction_data=malfunction_data,
                        handle=i_agent,
                        speed_counter=SpeedCounter(speed=speed))
            agent_list.append(agent)

        return agent_list

    @classmethod
    def load_legacy_static_agent(cls, static_agents_data: Tuple):
        """Legacy loader, kept only so existing callers get an explicit error.

        Raises:
            NotImplementedError: always.
        """
        # The former implementation after this raise was unreachable and has
        # been removed.
        raise NotImplementedError("Not implemented for Flatland 3")

    def _set_state(self, state):
        """Force the state machine into ``state``; emits a warning on every call."""
        warnings.warn("Not recommended to set the state with this function unless completely required")
        self.state_machine.set_state(state)

    def __str__(self):
        # ``str(self.state)`` is kept explicit: formatting an int-based enum
        # directly in an f-string can yield its numeric value instead of its name.
        return f"\n \
                 handle(agent index): {self.handle} \n \
                 initial_position: {self.initial_position}   initial_direction: {self.initial_direction} \n \
                 position: {self.position}  direction: {self.direction}  target: {self.target} \n \
                 earliest_departure: {self.earliest_departure}  latest_arrival: {self.latest_arrival} \n \
                 state: {str(self.state)} \n \
                 malfunction_data: {self.malfunction_data} \n \
                 action_saver: {self.action_saver} \n \
                 speed_counter: {self.speed_counter}"

    @property
    def state(self):
        """Current state, read from ``state_machine``."""
        return self.state_machine.state
218
219
220
221