"""
Definition of the RailEnv environment and related level-generation functions.

Generator functions are functions that take width, height and num_resets as arguments and return
a GridTransitionMap object.
"""
# TODO:  _ this is a global method --> utils or remove later
# from inspect import currentframe

import msgpack
import numpy as np
from enum import IntEnum

from flatland.core.env import Environment
from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent
from flatland.envs.env_utils import get_new_position
from flatland.envs.generators import random_rail_generator
from flatland.envs.observations import TreeObsForRailEnv

# from flatland.core.transitions import Grid8Transitions, RailEnvTransitions
# from flatland.core.transition_map import GridTransitionMap


class RailEnvActions(IntEnum):
    DO_NOTHING = 0
    MOVE_LEFT = 1
    MOVE_FORWARD = 2
    MOVE_RIGHT = 3
    STOP_MOVING = 4
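
# Note: since RailEnvActions is an IntEnum, plain ints are interchangeable with
# the named members, e.g. the action_dicts {0: 2} and
# {0: RailEnvActions.MOVE_FORWARD} are equivalent when passed to step().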


class RailEnv(Environment):
    """
    RailEnv environment class.

    RailEnv is an environment inspired by a (simplified version of) a rail
    network, in which agents (trains) have to navigate to their target
    locations in the shortest time possible, while at the same time cooperating
    to avoid bottlenecks.

    The valid actions in the environment are:
        0: do nothing
        1: turn left and move to the next cell; if the agent was not moving, movement is started
        2: move to the next cell in front of the agent; if the agent was not moving, movement is started
        3: turn right and move to the next cell; if the agent was not moving, movement is started
        4: stop moving

    Moving forward in a dead-end cell makes the agent turn 180 degrees and step
    to the cell it came from.
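    For example, an agent facing north in a dead-end cell that receives
    MOVE_FORWARD ends up facing south, in the cell it previously occupied.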

    The actions of the agents are executed in order of their handle to prevent
    deadlocks and to allow them to learn relative priorities.

    TODO: WRITE ABOUT THE REWARD FUNCTION, and possibly allow for alpha and
    beta to be passed as parameters to __init__().
    """

    def __init__(self,
                 width,
                 height,
                 rail_generator=random_rail_generator(),
                 number_of_agents=1,
                 obs_builder_object=TreeObsForRailEnv(max_depth=2)):
        """
        Environment init.

        Parameters
        ----------
        rail_generator : function
            The rail_generator function is a function that takes the width,
            height and number of agents of a rail environment, along with the
            number of times the env has been reset, and returns a
            GridTransitionMap object and lists of starting positions, targets,
            and initial orientations, one entry per agent handle.
            Implemented functions are:
                random_rail_generator : generate a random rail of given size
                rail_from_GridTransitionMap_generator(rail_map) : generate a rail from
                                        a GridTransitionMap object
                rail_from_manual_specifications_generator(rail_spec) : generate a rail from
                                        a rail specifications array
                TODO: generate_rail_from_saved_list or from list of ndarray bitmaps ---
        width : int
            The width of the rail map. Potentially in the future,
            a range of widths to sample from.
        height : int
            The height of the rail map. Potentially in the future,
            a range of heights to sample from.
        number_of_agents : int
            Number of agents to spawn on the map. Potentially in the future,
            a range of number of agents to sample from.
        obs_builder_object: ObservationBuilder object
            ObservationBuilder-derived object that builds observation
            vectors for each agent.
        """

        self.rail_generator = rail_generator
        self.rail = None
        self.width = width
        self.height = height

        # use get_num_agents() instead
        # self.number_of_agents = number_of_agents

        self.obs_builder = obs_builder_object
        self.obs_builder._set_env(self)

        self.action_space = [1]
        self.observation_space = self.obs_builder.observation_space  # updated on resets?

        self.actions = [0] * number_of_agents
        self.rewards = [0] * number_of_agents
        self.done = False

        self.dones = dict.fromkeys(list(range(number_of_agents)) + ["__all__"], False)

        self.obs_dict = {}
        self.rewards_dict = {}
        self.dev_obs_dict = {}
        # self.agents_handles = list(range(self.number_of_agents))

        # self.agents_position = []
        # self.agents_target = []
        # self.agents_direction = []
        self.agents = [None] * number_of_agents  # live agents
        self.agents_static = [None] * number_of_agents  # static agent information
        self.num_resets = 0
        self.reset()
        self.num_resets = 0  # yes, set it to zero again!

        self.valid_positions = None

    # no more agent_handles
    def get_agent_handles(self):
        return range(self.get_num_agents())

    def get_num_agents(self, static=True):
        if static:
            return len(self.agents_static)
        else:
            return len(self.agents)

    def add_agent_static(self, agent_static):
        """ Add static info for a single agent.
            Returns the index of the new agent.
        """
        self.agents_static.append(agent_static)
        return len(self.agents_static) - 1

    def restart_agents(self):
        """ Reset the agents to their starting positions defined in agents_static
        """
        self.agents = EnvAgent.list_from_static(self.agents_static)

    def reset(self, regen_rail=True, replace_agents=True):
        """ if regen_rail then regenerate the rails.
            if replace_agents then regenerate the agents static.
            Relies on the rail_generator returning agent_static lists (pos, dir, target)
        """
        tRailAgents = self.rail_generator(self.width, self.height, self.get_num_agents(), self.num_resets)

        if regen_rail or self.rail is None:
            self.rail = tRailAgents[0]

        if replace_agents:
            self.agents_static = EnvAgentStatic.from_lists(*tRailAgents[1:4])

        # Take the agent static info and put (live) agents at the start positions
        # self.agents = EnvAgent.list_from_static(self.agents_static[:len(self.agents_handles)])
        self.restart_agents()

        self.num_resets += 1

        # for handle in self.agents_handles:
        #    self.dones[handle] = False
        self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)
        # perhaps dones should be part of each agent.

        # Reset the state of the observation builder with the new environment
        self.obs_builder.reset()
        self.observation_space = self.obs_builder.observation_space  # <-- change on reset?

        # Return the new observation vectors for each agent
        return self._get_observations()

    def step(self, action_dict):
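        """ Advance the episode by one time step.

            action_dict maps agent handles to RailEnvActions values; agents
            with no supplied action default to MOVE_FORWARD while moving and
            DO_NOTHING otherwise. Returns a gym-style tuple
            (observations, rewards, dones, info).
        """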
        alpha = 1.0
        beta = 1.0

        invalid_action_penalty = 0  # -2 GIACOMO: we decided that invalid actions will carry no penalty
        step_penalty = -1 * alpha
        global_reward = 1 * beta
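
        # Worked example (alpha = beta = 1): an agent not yet on its target
        # collects step_penalty = -1 on every time step; once all agents have
        # reached their targets, every agent receives global_reward = +1.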

        # Reset the step rewards
        self.rewards_dict = dict()
        # for handle in self.agents_handles:
        #    self.rewards_dict[handle] = 0
        for iAgent in range(self.get_num_agents()):
            self.rewards_dict[iAgent] = 0

        if self.dones["__all__"]:
            self.rewards_dict = {i: r + global_reward for i, r in self.rewards_dict.items()}
            return self._get_observations(), self.rewards_dict, self.dones, {}

        # for i in range(len(self.agents_handles)):
        for iAgent in range(self.get_num_agents()):
            # handle = self.agents_handles[i]
            transition_isValid = None
            agent = self.agents[iAgent]

            if iAgent not in action_dict:  # no action has been supplied for this agent
                if agent.moving:
                    # Keep moving
                    action_dict[iAgent] = RailEnvActions.MOVE_FORWARD
                else:
                    action_dict[iAgent] = RailEnvActions.DO_NOTHING

            if self.dones[iAgent]:  # this agent has already completed...
                # print("rail_env.py @", currentframe().f_back.f_lineno, " agent ", iAgent,
                #       "has already completed : why action will not be executed!!!!? ADRIAN")
                continue
            action = action_dict[iAgent]

            if action < 0 or action >= len(RailEnvActions):
                print('ERROR: illegal action=', action,
                      'for agent with index=', iAgent)
                return

            if action == RailEnvActions.DO_NOTHING and agent.moving:
                # Keep moving
                action_dict[iAgent] = RailEnvActions.MOVE_FORWARD
                action = RailEnvActions.MOVE_FORWARD

            if action == RailEnvActions.STOP_MOVING and agent.moving:
                action_dict[iAgent] = RailEnvActions.DO_NOTHING
                action = RailEnvActions.DO_NOTHING
                agent.moving = False
                # TODO: possibly, penalty for stopping!

            if not agent.moving and action in (
                    RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_FORWARD, RailEnvActions.MOVE_RIGHT):
                agent.moving = True
                # TODO: possibly, may add a penalty for starting, but the best is only for stopping (GIACOMO's opinion)

            if action != RailEnvActions.DO_NOTHING and action != RailEnvActions.STOP_MOVING:
                # pos = agent.position #  self.agents_position[i]
                # direction = agent.direction # self.agents_direction[i]

                # compute number of possible transitions in the current
                # cell used to check for invalid actions

                new_direction, transition_isValid = self.check_action(agent, action)

                new_position = get_new_position(agent.position, new_direction)
                # Is it a legal move?
                # 1) transition allows the new_direction in the cell,
                # 2) the new cell is not empty (case 0),
                # 3) the cell is free, i.e., no agent is currently in that cell

                # if (
                #        new_position[1] >= self.width or
                #        new_position[0] >= self.height or
                #        new_position[0] < 0 or new_position[1] < 0):
                #    new_cell_isValid = False

                # if self.rail.get_transitions(new_position) == 0:
                #     new_cell_isValid = False

                new_cell_isValid = (
                    np.array_equal(  # Check the new position is still in the grid
                        new_position,
                        np.clip(new_position, [0, 0], [self.height - 1, self.width - 1]))
                    and  # check the new position has some transitions (ie is not an empty cell)
                    self.rail.get_transitions(new_position) > 0)
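                # e.g. a candidate new_position of (-1, 3) clips to (0, 3), which
                # differs from the original, so a move off the top edge is rejected
                # (the same reasoning applies on all four edges).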

                # If transition validity hasn't been checked yet.
                if transition_isValid is None:
                    transition_isValid = self.rail.get_transition(
                        (*agent.position, agent.direction),
                        new_direction)

                # cell_isFree = True
                # for j in range(self.number_of_agents):
                #    if self.agents_position[j] == new_position:
                #        cell_isFree = False
                #        break
                # Check the new position is not the same as any of the existing agent positions
                # (including itself, for simplicity, since it is moving)
                cell_isFree = not np.any(
                    np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1))

                if all([new_cell_isValid, transition_isValid, cell_isFree]):
                    # move and change direction to face the new_direction that was
                    # performed
                    # self.agents_position[i] = new_position
                    # self.agents_direction[i] = new_direction
                    agent.old_direction = agent.direction
                    agent.old_position = agent.position
                    agent.position = new_position
                    agent.direction = new_direction
                else:
                    # the action was not valid, add penalty
                    self.rewards_dict[iAgent] += invalid_action_penalty

            # if agent is not in target position, add step penalty
            # if self.agents_position[i][0] == self.agents_target[i][0] and \
            #        self.agents_position[i][1] == self.agents_target[i][1]:
            #    self.dones[handle] = True
            if np.equal(agent.position, agent.target).all():
                self.dones[iAgent] = True
            else:
                self.rewards_dict[iAgent] += step_penalty

        # Check for end of episode + add global reward to all rewards!
        # num_agents_in_target_position = 0
        # for i in range(self.number_of_agents):
        #    if self.agents_position[i][0] == self.agents_target[i][0] and \
        #            self.agents_position[i][1] == self.agents_target[i][1]:
        #        num_agents_in_target_position += 1
        # if num_agents_in_target_position == self.number_of_agents:
        if np.all([np.array_equal(agent2.position, agent2.target) for agent2 in self.agents]):
            self.dones["__all__"] = True
            self.rewards_dict = {i: global_reward for i in self.rewards_dict}

        # Reset the step actions (in case some agent doesn't 'register_action'
        # on the next step)
        self.actions = [RailEnvActions.DO_NOTHING] * self.get_num_agents()
        return self._get_observations(), self.rewards_dict, self.dones, {}

    def check_action(self, agent, action):
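        """ Check whether a movement action is valid for agent in its current cell.

            Returns (new_direction, transition_isValid); transition_isValid is
            None when validity still has to be read from the grid transitions,
            which the caller (step) then does.
        """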
        transition_isValid = None
        possible_transitions = self.rail.get_transitions((*agent.position, agent.direction))
        num_transitions = np.count_nonzero(possible_transitions)

        new_direction = agent.direction
        # print(nbits,np.sum(possible_transitions))
        if action == RailEnvActions.MOVE_LEFT:
            new_direction = agent.direction - 1
            if num_transitions <= 1:
                transition_isValid = False

        elif action == RailEnvActions.MOVE_RIGHT:
            new_direction = agent.direction + 1
            if num_transitions <= 1:
                transition_isValid = False

        new_direction %= 4

        if action == RailEnvActions.MOVE_FORWARD:
            if num_transitions == 1:
                # dead-end, straight line or curved line;
                # new_direction is the only valid transition, so take it
                new_direction = np.argmax(possible_transitions)
                transition_isValid = True
        return new_direction, transition_isValid

    def _get_observations(self):
        self.obs_dict = {}
        self.debug_obs_dict = {}
        # for handle in self.agents_handles:
        for iAgent in range(self.get_num_agents()):
            self.obs_dict[iAgent] = self.obs_builder.get(iAgent)
        return self.obs_dict

    def render(self):
        # TODO:
        pass

    def get_full_state_msg(self):
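        """ Serialize the full environment state (grid plus static and live
            agents) into a msgpack blob. """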
        grid_data = self.rail.grid.tolist()
        agent_static_data = [agent.to_list() for agent in self.agents_static]
        agent_data = [agent.to_list() for agent in self.agents]

        msg_data = {
            "grid": grid_data,
            "agents_static": agent_static_data,
            "agents": agent_data}
        return msgpack.packb(msg_data, use_bin_type=True)

    def get_agent_state_msg(self):
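        """ Serialize only the live agents' state into a msgpack blob. """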
        agent_data = [agent.to_list() for agent in self.agents]
        msg_data = {
            "agents": agent_data}
        return msgpack.packb(msg_data, use_bin_type=True)

    def set_full_state_msg(self, msg_data):
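        """ Restore the full environment state from a msgpack blob produced by
            get_full_state_msg(). """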
        data = msgpack.unpackb(msg_data, use_list=False)
        self.rail.grid = np.array(data[b"grid"])
        self.agents_static = [EnvAgentStatic(d[0], d[1], d[2]) for d in data[b"agents_static"]]
        self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4]) for d in data[b"agents"]]
        # setup with loaded data
        self.height, self.width = self.rail.grid.shape
        self.rail.height = self.height
        self.rail.width = self.width
        # self.agents = [None] * self.get_num_agents()
        self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)

    def save(self, filename):
        with open(filename, "wb") as file_out:
            file_out.write(self.get_full_state_msg())

    def load(self, filename):
        with open(filename, "rb") as file_in:
            load_data = file_in.read()
            self.set_full_state_msg(load_data)
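

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the environment's API).
# Assumes the default random_rail_generator / TreeObsForRailEnv yield a valid
# 10x10 layout; actions are sampled uniformly from RailEnvActions.
if __name__ == "__main__":
    env = RailEnv(width=10, height=10, number_of_agents=2)
    obs = env.reset()
    for _ in range(20):
        action_dict = {i: np.random.randint(len(RailEnvActions))
                       for i in env.get_agent_handles()}
        obs, rewards, dones, _ = env.step(action_dict)
        if dones["__all__"]:
            break
    # env.save("rail_env_state.mpk")  # optional: snapshot full state via msgpack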