"""
Definition of the RailEnv environment and related level-generation functions.

Generator functions take width, height, number of agents and num_resets as
arguments and return a GridTransitionMap object together with the agents'
initial positions, orientations and targets.
"""
# TODO: this is a global method --> move to utils or remove later

from enum import IntEnum

import msgpack
import numpy as np

from flatland.core.env import Environment
from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent
from flatland.envs.env_utils import get_new_position
from flatland.envs.generators import random_rail_generator
from flatland.envs.observations import TreeObsForRailEnv


class RailEnvActions(IntEnum):
    DO_NOTHING = 0
    MOVE_LEFT = 1
    MOVE_FORWARD = 2
    MOVE_RIGHT = 3
    STOP_MOVING = 4


class RailEnv(Environment):
    """
    RailEnv environment class.

    RailEnv is an environment inspired by a (simplified version of) a rail
    network, in which agents (trains) have to navigate to their target
    locations in the shortest time possible, while at the same time cooperating
    to avoid bottlenecks.

    The valid actions in the environment are:
        0: do nothing
        1: turn left and move to the next cell; if the agent was not moving, movement is started
        2: move to the next cell in front of the agent; if the agent was not moving, movement is started
        3: turn right and move to the next cell; if the agent was not moving, movement is started
        4: stop moving

    Moving forward in a dead-end cell makes the agent turn 180 degrees and step
    to the cell it came from.

    The actions of the agents are executed in order of their handle to prevent
    deadlocks and to allow them to learn relative priorities.

    In the current implementation of step(), each agent receives step_penalty
    (-1 * alpha) on every time step until it reaches its target, every agent
    receives global_reward (+1 * beta) once all agents are done, and invalid
    actions carry no extra penalty.

    TODO: possibly allow for alpha and beta to be passed as parameters to
    __init__().
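
    A minimal interaction sketch (assuming the default random_rail_generator
    and TreeObsForRailEnv wired up in __init__)::

        env = RailEnv(width=20, height=20, number_of_agents=2)
        obs = env.reset()
        actions = {handle: RailEnvActions.MOVE_FORWARD
                   for handle in env.get_agent_handles()}
        obs, rewards, dones, info = env.step(actions)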
    """

    def __init__(self,
                 width,
                 height,
                 rail_generator=random_rail_generator(),
                 number_of_agents=1,
                 obs_builder_object=TreeObsForRailEnv(max_depth=2),
                 prediction_builder_object=None
                 ):
        """
        Environment init.

        Parameters
        ----------
        rail_generator : function
            The rail_generator function is a function that takes the width,
            height and number of agents of a rail environment, along with the
            number of times the env has been reset, and returns a
            GridTransitionMap object plus lists of starting positions, initial
            orientations and targets, one entry per agent handle.
            Implemented functions are:
                random_rail_generator : generate a random rail of given size
                rail_from_GridTransitionMap_generator(rail_map) : generate a rail from
                                        a GridTransitionMap object
                rail_from_manual_specifications_generator(rail_spec) : generate a rail from
                                        a rail specifications array
                TODO: generate_rail_from_saved_list or from list of ndarray bitmaps ---
        width : int
            The width of the rail map. Potentially in the future,
            a range of widths to sample from.
        height : int
            The height of the rail map. Potentially in the future,
            a range of heights to sample from.
        number_of_agents : int
            Number of agents to spawn on the map. Potentially in the future,
            a range of number of agents to sample from.
        obs_builder_object: ObservationBuilder object
            ObservationBuilder-derived object that builds observation
            vectors for each agent.
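
        Examples
        --------
        A sketch of the rail_generator contract; my_rail_generator below is
        illustrative, not a library function. The callable receives
        (width, height, num_agents, num_resets) and returns a
        GridTransitionMap followed by the per-agent lists that
        EnvAgentStatic.from_lists expects::

            def my_rail_generator(width, height, num_agents, num_resets):
                # Delegate to the built-in generator; a real implementation
                # would construct its own GridTransitionMap and agent lists.
                return random_rail_generator()(width, height, num_agents, num_resets)

            env = RailEnv(width=10, height=10,
                          rail_generator=my_rail_generator,
                          number_of_agents=2)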
        """

        self.rail_generator = rail_generator
        self.rail = None
        self.width = width
        self.height = height

        self.obs_builder = obs_builder_object
        self.obs_builder._set_env(self)

        self.prediction_builder = prediction_builder_object
        if self.prediction_builder:
            self.prediction_builder._set_env(self)

        self.action_space = [1]
        self.observation_space = self.obs_builder.observation_space  # updated on resets?

        self.actions = [0] * number_of_agents
        self.rewards = [0] * number_of_agents
        self.done = False

        self.dones = dict.fromkeys(list(range(number_of_agents)) + ["__all__"], False)

        self.obs_dict = {}
        self.rewards_dict = {}
        self.dev_obs_dict = {}

        self.agents = [None] * number_of_agents  # live agents
        self.agents_static = [None] * number_of_agents  # static agent information
        self.num_resets = 0
        self.reset()
        self.num_resets = 0  # reset() above already incremented num_resets; start counting from zero

        self.valid_positions = None

    # no more agent_handles
    def get_agent_handles(self):
        return range(self.get_num_agents())

    def get_num_agents(self, static=True):
        if static:
            return len(self.agents_static)
        else:
            return len(self.agents)

    def add_agent_static(self, agent_static):
        """ Add static info for a single agent.
            Returns the index of the new agent.
        """
        self.agents_static.append(agent_static)
        return len(self.agents_static) - 1

    def restart_agents(self):
        """ Reset the agents to their starting positions defined in agents_static
        """
        self.agents = EnvAgent.list_from_static(self.agents_static)

    def reset(self, regen_rail=True, replace_agents=True):
        """ if regen_rail then regenerate the rails.
            if replace_agents then regenerate the agents static.
            Relies on the rail_generator returning agent_static lists (pos, dir, target)
hagrid67's avatar
hagrid67 committed
153
        """
        tRailAgents = self.rail_generator(self.width, self.height, self.get_num_agents(), self.num_resets)

        if regen_rail or self.rail is None:
            self.rail = tRailAgents[0]

        if replace_agents:
            self.agents_static = EnvAgentStatic.from_lists(*tRailAgents[1:4])

        self.restart_agents()

        self.num_resets += 1

        # TODO perhaps dones should be part of each agent.
        self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)

        # Reset the state of the observation builder with the new environment
        self.obs_builder.reset()
        self.observation_space = self.obs_builder.observation_space  # <-- change on reset?

        # Return the new observation vectors for each agent
        return self._get_observations()

    def step(self, action_dict):
        alpha = 1.0
        beta = 1.0

        invalid_action_penalty = 0  # previously -2; GIACOMO: we decided that invalid actions will carry no penalty
        step_penalty = -1 * alpha
        global_reward = 1 * beta
        stop_penalty = 0  # penalty for stopping a moving agent
        start_penalty = 0  # penalty for starting a stopped agent

        # Reset the step rewards
        self.rewards_dict = dict()
        for iAgent in range(self.get_num_agents()):
            self.rewards_dict[iAgent] = 0

        if self.dones["__all__"]:
            self.rewards_dict = {i: r + global_reward for i, r in self.rewards_dict.items()}
            return self._get_observations(), self.rewards_dict, self.dones, {}

        for iAgent in range(self.get_num_agents()):
            agent = self.agents[iAgent]

            if iAgent not in action_dict:  # no action has been supplied for this agent
                # DO_NOTHING is translated to MOVE_FORWARD below for moving agents,
                # so moving agents keep moving and stopped agents stay stopped.
                action_dict[iAgent] = RailEnvActions.DO_NOTHING

            if self.dones[iAgent]:  # this agent has already completed...
                continue
            action = action_dict[iAgent]

            if action < 0 or action >= len(RailEnvActions):
                print('ERROR: illegal action=', action,
                      'for agent with index=', iAgent)
                return

            if action == RailEnvActions.DO_NOTHING and agent.moving:
                # Keep moving
                action = RailEnvActions.MOVE_FORWARD

            if action == RailEnvActions.STOP_MOVING and agent.moving:
                agent.moving = False
                self.rewards_dict[iAgent] += stop_penalty

            if not agent.moving and action == RailEnvActions.MOVE_FORWARD:
                # Only allow agent to start moving by pressing forward.
                agent.moving = True
                self.rewards_dict[iAgent] += start_penalty

            if action != RailEnvActions.DO_NOTHING and action != RailEnvActions.STOP_MOVING:
                cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = \
                    self._check_action_on_agent(action, agent)
                if all([new_cell_isValid, transition_isValid, cell_isFree]):
                    agent.old_direction = agent.direction
                    agent.old_position = agent.position
                    agent.position = new_position
                    agent.direction = new_direction
                else:
                    # Logic: if the chosen action is invalid,
                    # and it was LEFT or RIGHT, and the agent was moving, then keep moving FORWARD.
                    if (action == RailEnvActions.MOVE_LEFT or action == RailEnvActions.MOVE_RIGHT) and agent.moving:
                        cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = \
                            self._check_action_on_agent(RailEnvActions.MOVE_FORWARD, agent)

                        if all([new_cell_isValid, transition_isValid, cell_isFree]):
                            agent.old_direction = agent.direction
                            agent.old_position = agent.position
                            agent.position = new_position
                            agent.direction = new_direction
                        else:
                            # the action was not valid, add penalty
                            self.rewards_dict[iAgent] += invalid_action_penalty

                    else:
                        # the action was not valid, add penalty
                        self.rewards_dict[iAgent] += invalid_action_penalty

            if np.equal(agent.position, agent.target).all():
                self.dones[iAgent] = True
            else:
                self.rewards_dict[iAgent] += step_penalty

        # Check for end of episode + add global reward to all rewards!
        if np.all([np.array_equal(agent2.position, agent2.target) for agent2 in self.agents]):
            self.dones["__all__"] = True
            self.rewards_dict = {i: global_reward for i in self.rewards_dict}

        # Reset the step actions (in case some agent doesn't 'register_action'
        # on the next step)
        self.actions = [0] * self.get_num_agents()
        return self._get_observations(), self.rewards_dict, self.dones, {}

    def _check_action_on_agent(self, action, agent):
        # compute number of possible transitions in the current
        # cell used to check for invalid actions
        new_direction, transition_isValid = self.check_action(agent, action)
        new_position = get_new_position(agent.position, new_direction)
        # Is it a legal move?
        # 1) transition allows the new_direction in the cell,
        # 2) the new cell is not empty (case 0),
        # 3) the cell is free, i.e., no agent is currently in that cell
        new_cell_isValid = (
            np.array_equal(  # Check the new position is still in the grid
                new_position,
                np.clip(new_position, [0, 0], [self.height - 1, self.width - 1]))
            and  # check the new position has some transitions (ie is not an empty cell)
            self.rail.get_transitions(new_position) > 0)
        # If transition validity hasn't been checked yet.
        if transition_isValid is None:
            transition_isValid = self.rail.get_transition(
                (*agent.position, agent.direction),
                new_direction)
        # Check the new position is not the same as any of the existing agent positions
        # (including itself, for simplicity, since it is moving)
        cell_isFree = not np.any(
            np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1))
        return cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid

    def predict(self):
        if not self.prediction_builder:
            return {}
        return self.prediction_builder.get()

    def check_action(self, agent, action):
        transition_isValid = None
        possible_transitions = self.rail.get_transitions((*agent.position, agent.direction))
        num_transitions = np.count_nonzero(possible_transitions)

        new_direction = agent.direction
        if action == RailEnvActions.MOVE_LEFT:
            new_direction = agent.direction - 1
            if num_transitions <= 1:
                transition_isValid = False

        elif action == RailEnvActions.MOVE_RIGHT:
            new_direction = agent.direction + 1
            if num_transitions <= 1:
                transition_isValid = False

        new_direction %= 4

        if action == RailEnvActions.MOVE_FORWARD:
            if num_transitions == 1:
                # - dead-end, straight line or curved line;
                # new_direction will be the only valid transition
                # - take only available transition
                new_direction = np.argmax(possible_transitions)
                transition_isValid = True
        return new_direction, transition_isValid

    def _get_observations(self):
        self.obs_dict = {}
        self.dev_obs_dict = {}
        for iAgent in range(self.get_num_agents()):
            self.obs_dict[iAgent] = self.obs_builder.get(iAgent)
        return self.obs_dict

    def _get_predictions(self):
        if not self.prediction_builder:
            return {}
        return self.prediction_builder.get()

    def render(self):
        # TODO:
        pass

    def get_full_state_msg(self):
        grid_data = self.rail.grid.tolist()
        agent_static_data = [agent.to_list() for agent in self.agents_static]
        agent_data = [agent.to_list() for agent in self.agents]

        msg_data = {
            "grid": grid_data,
            "agents_static": agent_static_data,
            "agents": agent_data}
        return msgpack.packb(msg_data, use_bin_type=True)

    def get_agent_state_msg(self):
        agent_data = [agent.to_list() for agent in self.agents]
        msg_data = {
            "agents": agent_data}
        return msgpack.packb(msg_data, use_bin_type=True)

    def set_full_state_msg(self, msg_data):
        data = msgpack.unpackb(msg_data, use_list=False)
        self.rail.grid = np.array(data[b"grid"])
        # agents are always reset as not moving
        self.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data[b"agents_static"]]
        self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4]) for d in data[b"agents"]]
        # setup with loaded data
        self.height, self.width = self.rail.grid.shape
        self.rail.height = self.height
        self.rail.width = self.width
        self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)

    def save(self, filename):
        with open(filename, "wb") as file_out:
            file_out.write(self.get_full_state_msg())

    def load(self, filename):
        with open(filename, "rb") as file_in:
            load_data = file_in.read()
            self.set_full_state_msg(load_data)
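

# Illustrative usage sketch exercising the msgpack save/load round-trip above.
# The file name is arbitrary and the default generators from __init__ are
# assumed to be available.
if __name__ == "__main__":
    env = RailEnv(width=14, height=14, number_of_agents=2)
    env.reset()
    env.save("rail_env_state.mpk")  # serialises grid + agents via msgpack

    env2 = RailEnv(width=14, height=14, number_of_agents=2)
    env2.load("rail_env_state.mpk")  # restores grid, agents and done flags
    assert np.array_equal(env.rail.grid, env2.rail.grid)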