rail_env.py 17.6 KB
Newer Older
1
2
3
4
5
6
"""
Definition of the RailEnv environment and related level-generation functions.

Generator functions are functions that take width, height and num_resets as arguments and return
a GridTransitionMap object.
"""
hagrid67's avatar
hagrid67 committed
7
# TODO:  _ this is a global method --> utils or remove later
8

9
10
from enum import IntEnum

maljx's avatar
maljx committed
11
import msgpack
12
import numpy as np
13
14

from flatland.core.env import Environment
u214892's avatar
u214892 committed
15
from flatland.core.grid.grid4_utils import get_new_position
16
from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent
17
18
19
from flatland.envs.generators import random_rail_generator
from flatland.envs.observations import TreeObsForRailEnv

20

spiglerg's avatar
spiglerg committed
21
class RailEnvActions(IntEnum):
    """Discrete action space of RailEnv.

    Values are the integers 0..4; ``to_char`` gives a one-letter mnemonic
    for rendering/debugging.
    """

    DO_NOTHING = 0  # implies change of direction in a dead-end!
    MOVE_LEFT = 1
    MOVE_FORWARD = 2
    MOVE_RIGHT = 3
    STOP_MOVING = 4

    @staticmethod
    def to_char(a: int):
        """Return the one-letter mnemonic for action value ``a``.

        0 -> 'B', 1 -> 'L', 2 -> 'F', 3 -> 'R', 4 -> 'S'.
        Raises KeyError for any other value.
        """
        mnemonics = dict(enumerate('BLFRS'))
        return mnemonics[a]

u214892's avatar
u214892 committed
38

39
40
41
42
43
44
45
46
47
48
49
class RailEnv(Environment):
    """
    RailEnv environment class.

    RailEnv is an environment inspired by a (simplified version of) a rail
    network, in which agents (trains) have to navigate to their target
    locations in the shortest time possible, while at the same time cooperating
    to avoid bottlenecks.

    The valid actions in the environment are:
        0: do nothing
        1: turn left and move to the next cell; if the agent was not moving, movement is started
        2: move to the next cell in front of the agent; if the agent was not moving, movement is started
        3: turn right and move to the next cell; if the agent was not moving, movement is started
        4: stop moving

    Moving forward in a dead-end cell makes the agent turn 180 degrees and step
    to the cell it came from.

    The actions of the agents are executed in order of their handle to prevent
    deadlocks and to allow them to learn relative priorities.

    TODO: WRITE ABOUT THE REWARD FUNCTION, and possibly allow for alpha and
    beta to be passed as parameters to __init__().
    """

    def __init__(self,
                 width,
                 height,
                 rail_generator=random_rail_generator(),
                 number_of_agents=1,
                 obs_builder_object=TreeObsForRailEnv(max_depth=2),
                 ):
        """
        Environment init.

        Parameters
        ----------
        rail_generator : function
            The rail_generator function is a function that takes the width,
            height and agents handles of a  rail environment, along with the number of times
            the env has been reset, and returns a GridTransitionMap object and a list of
            starting positions, targets, and initial orientations for agent handle.
            Implemented functions are:
                random_rail_generator : generate a random rail of given size
                rail_from_GridTransitionMap_generator(rail_map) : generate a rail from
                                        a GridTransitionMap object
                rail_from_manual_specifications_generator(rail_spec) : generate a rail from
                                        a rail specifications array
                TODO: generate_rail_from_saved_list or from list of ndarray bitmaps ---
            NOTE(review): the default is evaluated once at class-definition
            time, so all RailEnv instances created without an explicit
            rail_generator share the same generator object — confirm that is
            intended.
        width : int
            The width of the rail map. Potentially in the future,
            a range of widths to sample from.
        height : int
            The height of the rail map. Potentially in the future,
            a range of heights to sample from.
        number_of_agents : int
            Number of agents to spawn on the map. Potentially in the future,
            a range of number of agents to sample from.
        obs_builder_object: ObservationBuilder object
            ObservationBuilder-derived object that takes builds observation
            vectors for each agent.
        """

        self.rail_generator = rail_generator
        self.rail = None  # GridTransitionMap, filled in by reset()
        self.width = width
        self.height = height

        # The observation builder needs a back-reference to this env.
        self.obs_builder = obs_builder_object
        self.obs_builder._set_env(self)

        self.action_space = [1]
        self.observation_space = self.obs_builder.observation_space

        self.rewards = [0] * number_of_agents
        self.done = False

        # Per-agent done flags plus the aggregate "__all__" flag.
        self.dones = dict.fromkeys(list(range(number_of_agents)) + ["__all__"], False)

        self.obs_dict = {}
        self.rewards_dict = {}
        # Scratch space observation builders may use for rendering/debugging.
        self.dev_obs_dict = {}

        self.agents = [None] * number_of_agents  # live agents
        self.agents_static = [None] * number_of_agents  # static agent information

        # reset() increments num_resets, hence the re-zeroing afterwards.
        self.num_resets = 0
        self.reset()
        self.num_resets = 0  # yes, set it to zero again!

        self.valid_positions = None

131
    # no more agent_handles
132
    def get_agent_handles(self):
133
134
135
136
137
138
139
        return range(self.get_num_agents())

    def get_num_agents(self, static=True):
        if static:
            return len(self.agents_static)
        else:
            return len(self.agents)
140

hagrid67's avatar
hagrid67 committed
141
142
143
144
145
146
147
    def add_agent_static(self, agent_static):
        """ Add static info for a single agent.
            Returns the index of the new agent.
        """
        self.agents_static.append(agent_static)
        return len(self.agents_static) - 1

148
149
    def restart_agents(self):
        """Recreate the live agent list from the static agent information,
        putting every agent back at its starting position.
        """
        static_specs = self.agents_static
        self.agents = EnvAgent.list_from_static(static_specs)

    def reset(self, regen_rail=True, replace_agents=True):
        """ if regen_rail then regenerate the rails.
            if replace_agents then regenerate the agents static.
            Relies on the rail_generator returning agent_static lists (pos, dir, target)
        """
        # The generator returns one tuple: (rail, positions, directions, targets, ...).
        generated = self.rail_generator(self.width, self.height, self.get_num_agents(), self.num_resets)

        if regen_rail or self.rail is None:
            self.rail = generated[0]

        if replace_agents:
            self.agents_static = EnvAgentStatic.from_lists(*generated[1:5])

        self.restart_agents()

        # Every agent starts a fresh episode at the beginning of its cell.
        for agent in self.agents:
            agent.speed_data['position_fraction'] = 0.0

        self.num_resets += 1

        # TODO perhaps dones should be part of each agent.
        self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)

        # Reset the state of the observation builder with the new environment
        self.obs_builder.reset()
        self.observation_space = self.obs_builder.observation_space  # <-- change on reset?

        # Return the new observation vectors for each agent
        return self._get_observations()

spiglerg's avatar
spiglerg committed
184
185
186
    def step(self, action_dict_):
        """ Advance the environment by one time step.

        Parameters
        -------
        action_dict_ : dict
            Maps agent handle -> RailEnvActions value. Agents without an
            entry, and agents with an out-of-range entry, execute DO_NOTHING.

        Returns
        -------
        (observations, rewards_dict, dones, info_dict) for all agents.
        """
        # Work on a copy so the caller's dict is never mutated.
        action_dict = action_dict_.copy()

        alpha = 1.0
        beta = 1.0

        invalid_action_penalty = 0  # previously -2; GIACOMO: we decided that invalid actions will carry no penalty
        step_penalty = -1 * alpha
        global_reward = 1 * beta
        stop_penalty = 0  # penalty for stopping a moving agent
        start_penalty = 0  # penalty for starting a stopped agent

        # Reset the step rewards
        self.rewards_dict = dict()
        for i_agent in range(self.get_num_agents()):
            self.rewards_dict[i_agent] = 0

        # Episode already over: keep handing out the global reward.
        if self.dones["__all__"]:
            self.rewards_dict = {i: r + global_reward for i, r in self.rewards_dict.items()}
            return self._get_observations(), self.rewards_dict, self.dones, {}

        for i_agent, agent in enumerate(self.agents):
            # Remember the previous pose (used e.g. by renderers/observers).
            agent.old_direction = agent.direction
            agent.old_position = agent.position
            if self.dones[i_agent]:  # this agent has already completed...
                continue

            if i_agent not in action_dict:  # no action has been supplied for this agent
                action_dict[i_agent] = RailEnvActions.DO_NOTHING

            # BUGFIX: valid actions are 0 .. len(RailEnvActions) - 1, so the
            # upper bound check must be >= (the old `>` let the value
            # len(RailEnvActions) slip through as "legal").
            if action_dict[i_agent] < 0 or action_dict[i_agent] >= len(RailEnvActions):
                print('ERROR: illegal action=', action_dict[i_agent],
                      'for agent with index=', i_agent,
                      '"DO NOTHING" will be executed instead')
                action_dict[i_agent] = RailEnvActions.DO_NOTHING

            action = action_dict[i_agent]

            if action == RailEnvActions.DO_NOTHING and agent.moving:
                # Keep moving
                action = RailEnvActions.MOVE_FORWARD

            if action == RailEnvActions.STOP_MOVING and agent.moving and agent.speed_data['position_fraction'] == 0.:
                # Only allow halting an agent on entering new cells.
                agent.moving = False
                self.rewards_dict[i_agent] += stop_penalty

            if not agent.moving and not (action == RailEnvActions.DO_NOTHING or action == RailEnvActions.STOP_MOVING):
                # Only allow agent to start moving by pressing forward.
                agent.moving = True
                self.rewards_dict[i_agent] += start_penalty

            # Now perform a movement.
            # If the agent is in an initial position within a new cell (agent.speed_data['position_fraction']<eps)
            #   store the desired action in `transition_action_on_cellexit' (only if the desired transition is
            #   allowed! otherwise DO_NOTHING!)
            # Then in any case (if agent.moving) and the `transition_action_on_cellexit' is valid, increment the
            #   position_fraction by the speed of the agent   (regardless of action taken, as long as no
            #   STOP_MOVING, but that makes agent.moving=False)
            # If the new position fraction is >= 1, reset to 0, and perform the stored
            #   transition_action_on_cellexit

            # If the agent can make an action
            action_selected = False
            if agent.speed_data['position_fraction'] == 0.:
                if action != RailEnvActions.DO_NOTHING and action != RailEnvActions.STOP_MOVING:
                    cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = \
                        self._check_action_on_agent(action, agent)

                    if all([new_cell_isValid, transition_isValid]):
                        agent.speed_data['transition_action_on_cellexit'] = action
                        action_selected = True

                    else:
                        # But, if the chosen invalid action was LEFT/RIGHT, and the agent is moving,
                        # try to keep moving forward!
                        if (action == RailEnvActions.MOVE_LEFT or action == RailEnvActions.MOVE_RIGHT) and agent.moving:
                            cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = \
                                self._check_action_on_agent(RailEnvActions.MOVE_FORWARD, agent)

                            if all([new_cell_isValid, transition_isValid]):
                                agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.MOVE_FORWARD
                                action_selected = True

                            else:
                                # TODO: an invalid action was chosen after entering the cell. The agent cannot move.
                                self.rewards_dict[i_agent] += invalid_action_penalty
                                agent.moving = False
                                self.rewards_dict[i_agent] += stop_penalty

                                continue
                        else:
                            # TODO: an invalid action was chosen after entering the cell. The agent cannot move.
                            self.rewards_dict[i_agent] += invalid_action_penalty
                            agent.moving = False
                            self.rewards_dict[i_agent] += stop_penalty

                            continue

            if agent.moving and (action_selected or agent.speed_data['position_fraction'] > 0.0):
                agent.speed_data['position_fraction'] += agent.speed_data['speed']

            if agent.speed_data['position_fraction'] >= 1.0:
                # Perform stored action to transition to the next cell

                # Now 'transition_action_on_cellexit' will be guaranteed to be valid; it was checked on entering
                # the cell
                cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = \
                    self._check_action_on_agent(agent.speed_data['transition_action_on_cellexit'], agent)

                if all([new_cell_isValid, transition_isValid, cell_isFree]):
                    agent.position = new_position
                    agent.direction = new_direction
                    agent.speed_data['position_fraction'] = 0.0

            if np.equal(agent.position, agent.target).all():
                self.dones[i_agent] = True
            else:
                self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed']

        # Check for end of episode + add global reward to all rewards!
        if np.all([np.array_equal(agent2.position, agent2.target) for agent2 in self.agents]):
            self.dones["__all__"] = True
            self.rewards_dict = {i: 0 * r + global_reward for i, r in self.rewards_dict.items()}

        return self._get_observations(), self.rewards_dict, self.dones, {}

u214892's avatar
u214892 committed
312
313
314
315
316
    def _check_action_on_agent(self, action, agent):
        """Evaluate the consequences of *agent* executing *action*, without
        actually moving the agent.

        Returns the 5-tuple
        (cell_isFree, new_cell_isValid, new_direction, new_position,
        transition_isValid) describing the candidate move.
        """
        # compute number of possible transitions in the current
        # cell used to check for invalid actions
        new_direction, transition_isValid = self.check_action(agent, action)
        new_position = get_new_position(agent.position, new_direction)

        # Is it a legal move?
        # 1) transition allows the new_direction in the cell,
        # 2) the new cell is not empty (case 0),
        # 3) the cell is free, i.e., no agent is currently in that cell
        new_cell_isValid = (
            np.array_equal(  # Check the new position is still in the grid
                new_position,
                np.clip(new_position, [0, 0], [self.height - 1, self.width - 1]))
            and  # check the new position has some transitions (ie is not an empty cell)
            self.rail.get_transitions(new_position) > 0)

        # If transition validity hasn't been checked yet.
        # check_action() returns None when it could not decide locally;
        # in that case query the rail's transition table directly.
        if transition_isValid is None:
            transition_isValid = self.rail.get_transition(
                (*agent.position, agent.direction),
                new_direction)

        # Check the new position is not the same as any of the existing agent positions
        # (including itself, for simplicity, since it is moving)
        # NOTE(review): relies on np broadcasting over the (n_agents, 2)
        # position array; assumes every agent.position is a 2-tuple — confirm.
        cell_isFree = not np.any(
            np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1))
        return cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid

hagrid67's avatar
hagrid67 committed
341
342
343
344
345
346
    def check_action(self, agent, action):
        """Resolve *action* into a tentative new heading for *agent*.

        Returns (new_direction, transition_isValid), where
        transition_isValid is True/False when decidable from the cell's
        transition count alone, and None when the caller still has to
        consult the rail's transition table.
        """
        possible_transitions = self.rail.get_transitions((*agent.position, agent.direction))
        num_transitions = np.count_nonzero(possible_transitions)

        transition_isValid = None
        new_direction = agent.direction

        if action == RailEnvActions.MOVE_LEFT:
            new_direction -= 1
            if num_transitions <= 1:
                # A cell with at most one exit cannot honour a turn.
                transition_isValid = False
        elif action == RailEnvActions.MOVE_RIGHT:
            new_direction += 1
            if num_transitions <= 1:
                transition_isValid = False

        new_direction %= 4

        if action == RailEnvActions.MOVE_FORWARD and num_transitions == 1:
            # - dead-end, straight line or curved line;
            # new_direction will be the only valid transition
            # - take only available transition
            new_direction = np.argmax(possible_transitions)
            transition_isValid = True

        return new_direction, transition_isValid

368
    def _get_observations(self):
369
        self.obs_dict = self.obs_builder.get_many(list(range(self.get_num_agents())))
370
        return self.obs_dict
371

maljx's avatar
maljx committed
372
373
374
375
    def get_full_state_msg(self):
        """Serialize the rail grid plus static and live agent state.

        Returns
        -------
        bytes
            A msgpack blob with keys "grid", "agents_static" and "agents",
            suitable for set_full_state_msg().
        """
        grid_data = self.rail.grid.tolist()
        agent_static_data = [agent.to_list() for agent in self.agents_static]
        agent_data = [agent.to_list() for agent in self.agents]

        # BUGFIX/cleanup: removed three msgpack.packb(...) calls whose return
        # values were discarded — they only wasted time re-encoding the same
        # data that is packed once below.
        msg_data = {
            "grid": grid_data,
            "agents_static": agent_static_data,
            "agents": agent_data}
        return msgpack.packb(msg_data, use_bin_type=True)

    def get_agent_state_msg(self):
        """Serialize only the live agents' state as a msgpack blob."""
        agent_data = [agent.to_list() for agent in self.agents]
        return msgpack.packb({"agents": agent_data}, use_bin_type=True)

    def set_full_state_msg(self, msg_data):
        """Restore the environment from a msgpack blob produced by
        get_full_state_msg().

        Rebuilds the rail grid, both agent lists, the grid dimensions and
        the dones bookkeeping.
        """
        # use_list=False keeps sequences as tuples; keys come back as bytes,
        # hence the b"..." lookups below.
        data = msgpack.unpackb(msg_data, use_list=False)
        self.rail.grid = np.array(data[b"grid"])
        # agents are always reset as not moving
        self.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data[b"agents_static"]]
        # NOTE(review): assumes each live-agent record has exactly the five
        # fields EnvAgent expects — confirm against EnvAgent.to_list().
        self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4]) for d in data[b"agents"]]
        # setup with loaded data
        self.height, self.width = self.rail.grid.shape
        self.rail.height = self.height
        self.rail.width = self.width
        self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)

    def save(self, filename):
        """Write the full serialized environment state to *filename*."""
        with open(filename, "wb") as out_file:
            out_file.write(self.get_full_state_msg())

    def load(self, filename):
        """Restore the environment state from a file written by save()."""
        with open(filename, "rb") as in_file:
            self.set_full_state_msg(in_file.read())
u214892's avatar
u214892 committed
413
414
415
416
417

    def load_resource(self, package, resource):
        """Restore state from a msgpack resource bundled inside *package*."""
        from importlib_resources import read_binary
        self.set_full_state_msg(read_binary(package, resource))