"""
Definition of the RailEnv environment and related level-generation functions.

Generator functions are functions that take width, height, number of agents and num_resets as
arguments and return a GridTransitionMap object together with lists of the agents' initial
positions, directions and targets.
"""
# TODO:  _ this is a global method --> utils or remove later

from enum import IntEnum

import msgpack
import numpy as np

from flatland.core.env import Environment
from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent
from flatland.envs.env_utils import get_new_position
from flatland.envs.generators import random_rail_generator
from flatland.envs.observations import TreeObsForRailEnv
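
# A rail generator must follow the calling convention used in RailEnv.reset() below:
#
#     generator(width, height, num_agents, num_resets)
#         -> (GridTransitionMap, agent_positions, agent_directions, agent_targets)
#
# Illustrative sketch (not part of the library; `my_rail_generator` and
# `generator` are hypothetical names): a custom generator that simply delegates
# to random_rail_generator could look like
#
#     def my_rail_generator():
#         base = random_rail_generator()
#         def generator(width, height, num_agents, num_resets):
#             return base(width, height, num_agents, num_resets)
#         return generator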


class RailEnvActions(IntEnum):
    DO_NOTHING = 0
    MOVE_LEFT = 1
    MOVE_FORWARD = 2
    MOVE_RIGHT = 3
    STOP_MOVING = 4
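
# Note: actions are supplied to RailEnv.step() as a dict that maps each agent's
# handle to one of the integer values above, e.g. {0: RailEnvActions.MOVE_FORWARD}.
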

class RailEnv(Environment):
    """
    RailEnv environment class.

    RailEnv is an environment inspired by a (simplified version of) a rail
    network, in which agents (trains) have to navigate to their target
    locations in the shortest time possible, while at the same time cooperating
    to avoid bottlenecks.

    The valid actions in the environment are:
        0: do nothing
        1: turn left and move to the next cell; if the agent was not moving, movement is started
        2: move to the next cell in front of the agent; if the agent was not moving, movement is started
        3: turn right and move to the next cell; if the agent was not moving, movement is started
        4: stop moving

    Moving forward in a dead-end cell makes the agent turn 180 degrees and step
    to the cell it came from.

    The actions of the agents are executed in order of their handle to prevent
    deadlocks and to allow them to learn relative priorities.

    The reward function, as currently implemented in step(), gives each agent a
    step penalty of -1 * alpha for every time step in which it has not yet
    reached its target, and a global reward of 1 * beta to every agent once all
    agents have reached their targets.

    TODO: allow alpha and beta to be passed as parameters to __init__().
    """

    def __init__(self,
                 width,
                 height,
                 rail_generator=random_rail_generator(),
                 number_of_agents=1,
                 obs_builder_object=TreeObsForRailEnv(max_depth=2),
                 ):
        """
        Environment init.

        Parameters
        ----------
        rail_generator : function
            The rail_generator function is a function that takes the width,
            height and number of agents of a rail environment, along with the
            number of times the env has been reset, and returns a
            GridTransitionMap object and lists of starting positions, targets
            and initial orientations, one entry per agent handle.
            Implemented functions are:
                random_rail_generator : generate a random rail of given size
                rail_from_GridTransitionMap_generator(rail_map) : generate a rail from
                                        a GridTransitionMap object
                rail_from_manual_specifications_generator(rail_spec) : generate a rail from
                                        a rail specifications array
                TODO: generate_rail_from_saved_list or from list of ndarray bitmaps ---
        width : int
            The width of the rail map. Potentially in the future,
            a range of widths to sample from.
        height : int
            The height of the rail map. Potentially in the future,
            a range of heights to sample from.
        number_of_agents : int
            Number of agents to spawn on the map. Potentially in the future,
            a range of number of agents to sample from.
        obs_builder_object: ObservationBuilder object
            ObservationBuilder-derived object that builds observation
            vectors for each agent.
        """

        self.rail_generator = rail_generator
        self.rail = None
        self.width = width
        self.height = height

        self.obs_builder = obs_builder_object
        self.obs_builder._set_env(self)

        self.action_space = [1]
        self.observation_space = self.obs_builder.observation_space  # updated on resets?

        self.rewards = [0] * number_of_agents
        self.done = False

        self.dones = dict.fromkeys(list(range(number_of_agents)) + ["__all__"], False)

        self.obs_dict = {}
        self.rewards_dict = {}
        self.dev_obs_dict = {}

        self.agents = [None] * number_of_agents  # live agents
        self.agents_static = [None] * number_of_agents  # static agent information
        self.num_resets = 0
        self.reset()
        self.num_resets = 0  # yes, set it to zero again!

        self.valid_positions = None
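
    # Example construction (illustrative; all argument values are arbitrary):
    #
    #     env = RailEnv(width=20, height=20,
    #                   rail_generator=random_rail_generator(),
    #                   number_of_agents=3,
    #                   obs_builder_object=TreeObsForRailEnv(max_depth=2))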

    # no more agent_handles
    def get_agent_handles(self):
        return range(self.get_num_agents())

    def get_num_agents(self, static=True):
        if static:
            return len(self.agents_static)
        else:
            return len(self.agents)

    def add_agent_static(self, agent_static):
        """ Add static info for a single agent.
            Returns the index of the new agent.
        """
        self.agents_static.append(agent_static)
        return len(self.agents_static) - 1

    def restart_agents(self):
        """ Reset the agents to their starting positions defined in agents_static
        """
        self.agents = EnvAgent.list_from_static(self.agents_static)

    def reset(self, regen_rail=True, replace_agents=True):
        """ if regen_rail then regenerate the rails.
            if replace_agents then regenerate the agents static.
            Relies on the rail_generator returning agent_static lists (pos, dir, target)
        """
        tRailAgents = self.rail_generator(self.width, self.height, self.get_num_agents(), self.num_resets)

        if regen_rail or self.rail is None:
            self.rail = tRailAgents[0]

        if replace_agents:
            self.agents_static = EnvAgentStatic.from_lists(*tRailAgents[1:4])

        self.restart_agents()

        self.num_resets += 1

        # TODO perhaps dones should be part of each agent.
        self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)

        # Reset the state of the observation builder with the new environment
        self.obs_builder.reset()
        self.observation_space = self.obs_builder.observation_space  # <-- change on reset?

        # Return the new observation vectors for each agent
        return self._get_observations()

    def step(self, action_dict_):
        action_dict = action_dict_.copy()

        alpha = 1.0
        beta = 1.0

        invalid_action_penalty = 0  # previously -2; GIACOMO: we decided that invalid actions will carry no penalty
        step_penalty = -1 * alpha
        global_reward = 1 * beta
        stop_penalty = 0  # penalty for stopping a moving agent
        start_penalty = 0  # penalty for starting a stopped agent

        # Reset the step rewards
        self.rewards_dict = dict()
        for iAgent in range(self.get_num_agents()):
            self.rewards_dict[iAgent] = 0

        if self.dones["__all__"]:
            self.rewards_dict = {i: r + global_reward for i, r in self.rewards_dict.items()}
            return self._get_observations(), self.rewards_dict, self.dones, {}

        for iAgent in range(self.get_num_agents()):
            agent = self.agents[iAgent]

            if self.dones[iAgent]:  # this agent has already completed...
                continue

            if np.equal(agent.position, agent.target).all():
                self.dones[iAgent] = True
            else:
                self.rewards_dict[iAgent] += step_penalty

            if iAgent not in action_dict:  # no action has been supplied for this agent
                action_dict[iAgent] = RailEnvActions.DO_NOTHING

            if action_dict[iAgent] < 0 or action_dict[iAgent] >= len(RailEnvActions):
                print('ERROR: illegal action=', action_dict[iAgent],
                      'for agent with index=', iAgent,
                      '"DO NOTHING" will be executed instead')
                action_dict[iAgent] = RailEnvActions.DO_NOTHING

            action = action_dict[iAgent]

            if action == RailEnvActions.DO_NOTHING and agent.moving:
                # Keep moving
                action = RailEnvActions.MOVE_FORWARD

            if action == RailEnvActions.STOP_MOVING and agent.moving:
                agent.moving = False
                self.rewards_dict[iAgent] += stop_penalty

            if not agent.moving and not (action == RailEnvActions.DO_NOTHING or action == RailEnvActions.STOP_MOVING):
                # Any movement action (LEFT, FORWARD, RIGHT) starts a stopped agent moving.
                agent.moving = True
                self.rewards_dict[iAgent] += start_penalty

            if action != RailEnvActions.DO_NOTHING and action != RailEnvActions.STOP_MOVING:
                # Now perform a movement.
                # If the agent has just entered a new cell (agent.speed_data['position_fraction'] < eps),
                #   store the desired action in 'transition_action_on_cellexit' (only if the desired
                #   transition is allowed; otherwise store DO_NOTHING).
                # Then, as long as the agent is moving and the stored action is valid, increment
                #   position_fraction by the agent's speed, regardless of the action taken
                #   (except STOP_MOVING, which sets agent.moving to False).
                # When position_fraction reaches 1.0, reset it to 0 and perform the stored
                #   'transition_action_on_cellexit'.
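                #
                # Worked example (illustrative): an agent with speed 0.5 that keeps moving
                # forward reaches position_fraction 0.5 after one step() and 1.0 after two,
                # so it advances to a new cell every second time step.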

                if agent.speed_data['position_fraction'] < 0.01:
                    # Is the desired transition valid?

                    cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = \
                        self._check_action_on_agent(action, agent)

                    if all([new_cell_isValid, transition_isValid, cell_isFree]):
                        agent.speed_data['transition_action_on_cellexit'] = action

                    else:
                        # But, if the chosen invalid action was LEFT/RIGHT, and the agent is moving,
                        # try to keep moving forward!
                        if (action == RailEnvActions.MOVE_LEFT or action == RailEnvActions.MOVE_RIGHT) and agent.moving:
                            cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = \
                                self._check_action_on_agent(RailEnvActions.MOVE_FORWARD, agent)

                            if all([new_cell_isValid, transition_isValid, cell_isFree]):
                                agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.MOVE_FORWARD

                            else:
                                # TODO: an invalid action was chosen after entering the cell. The agent cannot move.
                                self.rewards_dict[iAgent] += invalid_action_penalty
                                continue
                        else:
                            # TODO: an invalid action was chosen after entering the cell. The agent cannot move.
                            self.rewards_dict[iAgent] += invalid_action_penalty
                            continue

                agent.speed_data['position_fraction'] += agent.speed_data['speed']

                if agent.speed_data['position_fraction'] >= 1.0:
                    agent.speed_data['position_fraction'] = 0.0

                    # Perform stored action to transition to the next cell

                    # The stored action was validated when the agent entered the cell, but the
                    # target cell may have been occupied by another agent in the meantime, so
                    # the transition is re-checked before the move is performed.
                    cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = \
                        self._check_action_on_agent(agent.speed_data['transition_action_on_cellexit'], agent)
                    if all([new_cell_isValid, transition_isValid, cell_isFree]):
                        agent.old_direction = agent.direction
                        agent.old_position = agent.position
                        agent.position = new_position
                        agent.direction = new_direction

        # Check for end of episode + add global reward to all rewards!
        if np.all([np.array_equal(agent2.position, agent2.target) for agent2 in self.agents]):
            self.dones["__all__"] = True
            self.rewards_dict = {i: global_reward for i in self.rewards_dict}

        return self._get_observations(), self.rewards_dict, self.dones, {}

    def _check_action_on_agent(self, action, agent):
        # compute number of possible transitions in the current
        # cell used to check for invalid actions
        new_direction, transition_isValid = self.check_action(agent, action)
        new_position = get_new_position(agent.position, new_direction)

        # Is it a legal move?
        # 1) transition allows the new_direction in the cell,
        # 2) the new cell is not empty (case 0),
        # 3) the cell is free, i.e., no agent is currently in that cell
        new_cell_isValid = (
            np.array_equal(  # Check the new position is still in the grid
                new_position,
                np.clip(new_position, [0, 0], [self.height - 1, self.width - 1]))
            and  # check the new position has some transitions (ie is not an empty cell)
            self.rail.get_transitions(new_position) > 0)

        # If transition validity hasn't been checked yet.
        if transition_isValid is None:
            transition_isValid = self.rail.get_transition(
                (*agent.position, agent.direction),
                new_direction)

        # Check the new position is not the same as any of the existing agent positions
        # (including itself, for simplicity, since it is moving)
        cell_isFree = not np.any(
            np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1))
        return cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid

    def check_action(self, agent, action):
        transition_isValid = None
        possible_transitions = self.rail.get_transitions((*agent.position, agent.direction))
        num_transitions = np.count_nonzero(possible_transitions)

        new_direction = agent.direction
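        # Directions are encoded as integers 0..3, so a left/right turn is a
        # decrement/increment modulo 4 (presumably 0=North, 1=East, 2=South,
        # 3=West, as in Flatland's Grid4 convention).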
        if action == RailEnvActions.MOVE_LEFT:
            new_direction = agent.direction - 1
            if num_transitions <= 1:
                transition_isValid = False

        elif action == RailEnvActions.MOVE_RIGHT:
            new_direction = agent.direction + 1
            if num_transitions <= 1:
                transition_isValid = False

        new_direction %= 4

        if action == RailEnvActions.MOVE_FORWARD:
            if num_transitions == 1:
                # - dead-end, straight line or curved line;
                # new_direction will be the only valid transition
                # - take only available transition
                new_direction = np.argmax(possible_transitions)
                transition_isValid = True
        return new_direction, transition_isValid

    def _get_observations(self):
        self.obs_dict = self.obs_builder.get_many(list(range(self.get_num_agents())))
        return self.obs_dict

    def get_full_state_msg(self):
        grid_data = self.rail.grid.tolist()
        agent_static_data = [agent.to_list() for agent in self.agents_static]
        agent_data = [agent.to_list() for agent in self.agents]

        msg_data = {
            "grid": grid_data,
            "agents_static": agent_static_data,
            "agents": agent_data}
        return msgpack.packb(msg_data, use_bin_type=True)

    def get_agent_state_msg(self):
        agent_data = [agent.to_list() for agent in self.agents]
        msg_data = {
            "agents": agent_data}
        return msgpack.packb(msg_data, use_bin_type=True)

    def set_full_state_msg(self, msg_data):
        data = msgpack.unpackb(msg_data, use_list=False)
        self.rail.grid = np.array(data[b"grid"])
        # agents are always reset as not moving
        self.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data[b"agents_static"]]
        self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4]) for d in data[b"agents"]]
        # setup with loaded data
        self.height, self.width = self.rail.grid.shape
        self.rail.height = self.height
        self.rail.width = self.width
        self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)

    def save(self, filename):
        with open(filename, "wb") as file_out:
            file_out.write(self.get_full_state_msg())

    def load(self, filename):
        with open(filename, "rb") as file_in:
            load_data = file_in.read()
            self.set_full_state_msg(load_data)

    def load_resource(self, package, resource):
        from importlib_resources import read_binary
        load_data = read_binary(package, resource)
        self.set_full_state_msg(load_data)
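

# Minimal usage sketch (illustrative only; not part of the library API).
# It builds an environment with the default generators, runs one episode of
# random actions, and round-trips the state through save()/load(). The width,
# height, agent count, step budget and filename are arbitrary example values.
if __name__ == "__main__":
    env = RailEnv(width=20,
                  height=20,
                  rail_generator=random_rail_generator(),
                  number_of_agents=3)
    obs = env.reset()
    for _ in range(100):
        # Sample a random (valid) action for every agent handle.
        action_dict = {handle: np.random.randint(0, len(RailEnvActions))
                       for handle in env.get_agent_handles()}
        obs, rewards, dones, _ = env.step(action_dict)
        if dones["__all__"]:
            break
    env.save("railenv_state.mpk")  # hypothetical example filename
    env.load("railenv_state.mpk")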