rail_env.py 16 KB
Newer Older
1
2
3
4
5
6
"""
Definition of the RailEnv environment and related level-generation functions.

Generator functions are functions that take width, height and num_resets as arguments and return
a GridTransitionMap object.
"""
hagrid67's avatar
hagrid67 committed
7
8
# TODO:  _ this is a global method --> utils or remove later
# from inspect import currentframe
9

maljx's avatar
maljx committed
10
import msgpack
11
import numpy as np
12
13

from flatland.core.env import Environment
14
from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent
15
16
17
18
from flatland.envs.env_utils import get_new_position
from flatland.envs.generators import random_rail_generator
from flatland.envs.observations import TreeObsForRailEnv

19

20
21
# from flatland.core.transitions import Grid8Transitions, RailEnvTransitions
# from flatland.core.transition_map import GridTransitionMap
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51


class RailEnv(Environment):
    """
    RailEnv environment class.

    RailEnv is an environment inspired by a (simplified version of) a rail
    network, in which agents (trains) have to navigate to their target
    locations in the shortest time possible, while at the same time cooperating
    to avoid bottlenecks.

    The valid actions in the environment are:
        0: do nothing
        1: turn left and move to the next cell
        2: move to the next cell in front of the agent
        3: turn right and move to the next cell

    Moving forward in a dead-end cell makes the agent turn 180 degrees and step
    to the cell it came from.

    The actions of the agents are executed in order of their handle to prevent
    deadlocks and to allow them to learn relative priorities.

    TODO: WRITE ABOUT THE REWARD FUNCTION, and possibly allow for alpha and
    beta to be passed as parameters to __init__().
    """

    def __init__(self,
                 width,
                 height,
                 rail_generator=None,
                 number_of_agents=1,
                 obs_builder_object=None,
                 prediction_builder_object=None
                 ):
        """
        Environment init.

        Parameters
        -------
        width : int
            The width of the rail map. Potentially in the future,
            a range of widths to sample from.
        height : int
            The height of the rail map. Potentially in the future,
            a range of heights to sample from.
        rail_generator : function, optional
            Function that takes the width, height and agents handles of a rail
            environment, along with the number of times the env has been reset,
            and returns a GridTransitionMap object and a list of starting
            positions, targets, and initial orientations for each agent handle.
            Defaults to a fresh ``random_rail_generator()``.
        number_of_agents : int
            Number of agents to spawn on the map. Potentially in the future,
            a range of number of agents to sample from.
        obs_builder_object : ObservationBuilder, optional
            ObservationBuilder-derived object that builds observation vectors
            for each agent. Defaults to a fresh ``TreeObsForRailEnv(max_depth=2)``.
        prediction_builder_object : optional
            Optional prediction builder used by ``predict()``.
        """
        # Build the defaults lazily. Evaluating them in the signature creates ONE
        # shared generator / observation builder at import time; a shared obs
        # builder would be rebound (via _set_env) to whichever env was created
        # last, silently breaking earlier instances.
        if rail_generator is None:
            rail_generator = random_rail_generator()
        if obs_builder_object is None:
            obs_builder_object = TreeObsForRailEnv(max_depth=2)

        self.rail_generator = rail_generator
        self.rail = None  # GridTransitionMap; populated by reset()
        self.width = width
        self.height = height

        # use get_num_agents() instead of a number_of_agents attribute
        self.obs_builder = obs_builder_object
        self.obs_builder._set_env(self)

        self.prediction_builder = prediction_builder_object
        if self.prediction_builder:
            self.prediction_builder._set_env(self)

        self.action_space = [1]
        self.observation_space = self.obs_builder.observation_space  # updated on resets?

        self.actions = [0] * number_of_agents
        self.rewards = [0] * number_of_agents
        self.done = False

        # Per-agent done flags plus the aggregate "__all__" key.
        self.dones = dict.fromkeys(list(range(number_of_agents)) + ["__all__"], False)

        self.obs_dict = {}
        self.rewards_dict = {}
        # scratch space for observation-builder debugging output
        self.dev_obs_dict = {}

        self.agents = [None] * number_of_agents  # live agents
        self.agents_static = [None] * number_of_agents  # static agent information
        self.num_resets = 0
        self.reset()
        self.num_resets = 0  # yes, set it to zero again!

        self.valid_positions = None

129
    # no more agent_handles
    def get_agent_handles(self):
        """Return an iterable over the handles (indices) of all agents."""
        num_agents = self.get_num_agents()
        return range(num_agents)

    def get_num_agents(self, static=True):
        """Count agents: the static roster by default, the live list otherwise."""
        roster = self.agents_static if static else self.agents
        return len(roster)
138

hagrid67's avatar
hagrid67 committed
139
140
141
142
143
144
145
    def add_agent_static(self, agent_static):
        """ Add static info for a single agent.
            Returns the index of the new agent.
        """
        self.agents_static.append(agent_static)
        new_index = len(self.agents_static) - 1
        return new_index

146
147
    def restart_agents(self):
        """ Reset the agents to their starting positions defined in agents_static
        """
        # Rebuild the live agent list from the static agent descriptions.
        self.agents = EnvAgent.list_from_static(self.agents_static)

    def reset(self, regen_rail=True, replace_agents=True):
        """ if regen_rail then regenerate the rails.
            if replace_agents then regenerate the agents static.
            Relies on the rail_generator returning agent_static lists (pos, dir, target)
        """
        # Generator returns (rail, positions, directions, targets).
        rail_and_agents = self.rail_generator(self.width, self.height, self.get_num_agents(), self.num_resets)

        if regen_rail or self.rail is None:
            self.rail = rail_and_agents[0]

        if replace_agents:
            self.agents_static = EnvAgentStatic.from_lists(*rail_and_agents[1:4])

        # Take the agent static info and put (live) agents at the start positions
        self.restart_agents()

        self.num_resets += 1

        # perhaps dones should be part of each agent.
        self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)

        # Reset the state of the observation builder with the new environment
        self.obs_builder.reset()
        self.observation_space = self.obs_builder.observation_space  # <-- change on reset?

        # Return the new observation vectors for each agent
        return self._get_observations()

    def step(self, action_dict):
        """
        Advance the environment by one time step.

        Parameters
        -------
        action_dict : dict
            Maps agent handle -> action (0: do nothing, 1: turn left,
            2: move forward, 3: turn right). Agents without an entry keep still.

        Returns
        -------
        (observations, rewards, dones, info) — the first three are dicts keyed
        by agent handle; `dones` additionally carries the "__all__" key.
        """
        alpha = 1.0
        beta = 1.0

        invalid_action_penalty = -2
        step_penalty = -1 * alpha
        global_reward = 1 * beta

        # Reset the step rewards
        self.rewards_dict = dict()
        for iAgent in range(self.get_num_agents()):
            self.rewards_dict[iAgent] = 0

        if self.dones["__all__"]:
            # Episode already finished: hand out only the global reward.
            # Fixed: the old list comprehension iterated the dict's KEYS and
            # replaced rewards_dict with a list.
            self.rewards_dict = {i: r + global_reward for i, r in self.rewards_dict.items()}
            return self._get_observations(), self.rewards_dict, self.dones, {}

        for iAgent in range(self.get_num_agents()):
            agent = self.agents[iAgent]

            if iAgent not in action_dict:  # no action has been supplied for this agent
                continue

            if self.dones[iAgent]:  # this agent has already completed...
                continue
            action = action_dict[iAgent]

            if action < 0 or action > 3:
                print('ERROR: illegal action=', action,
                      'for agent with index=', iAgent)
                return

            if action > 0:
                # Fixed: _check_action_on_agent takes only (action, agent); the
                # stale third argument raised a TypeError at call time.
                cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = \
                    self._check_action_on_agent(action, agent)

                if all([new_cell_isValid, transition_isValid, cell_isFree]):
                    # move and change direction to face the new_direction that
                    # was performed
                    agent.old_direction = agent.direction
                    agent.old_position = agent.position
                    agent.position = new_position
                    agent.direction = new_direction
                else:
                    # the action was not valid, add penalty
                    self.rewards_dict[iAgent] += invalid_action_penalty

            # if agent is not in target position, add step penalty
            if np.equal(agent.position, agent.target).all():
                self.dones[iAgent] = True
            else:
                self.rewards_dict[iAgent] += step_penalty

        # Check for end of episode + add global reward to all rewards!
        if np.all([np.array_equal(agent2.position, agent2.target) for agent2 in self.agents]):
            self.dones["__all__"] = True
            # Fixed: keep rewards_dict a dict instead of collapsing it to a
            # list built from the dict's keys.
            self.rewards_dict = {i: global_reward for i in self.rewards_dict}

        # Reset the step actions (in case some agent doesn't 'register_action'
        # on the next step)
        self.actions = [0] * self.get_num_agents()
        return self._get_observations(), self.rewards_dict, self.dones, {}

u214892's avatar
u214892 committed
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
    def _check_action_on_agent(self, action, agent, transition_isValid=None):
        """
        Check whether `action` is executable for `agent` from its current cell.

        Parameters
        -------
        action : int
            Movement action (1: turn left, 2: forward, 3: turn right).
        agent : EnvAgent
            The agent attempting the move.
        transition_isValid : optional
            Accepted for backward compatibility with call sites that passed a
            pre-seeded validity flag; the value is always recomputed here.

        Returns
        -------
        (cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid)
        """
        # compute number of possible transitions in the current
        # cell used to check for invalid actions
        new_direction, transition_isValid = self.check_action(agent, action)
        new_position = get_new_position(agent.position, new_direction)
        # Is it a legal move?
        # 1) transition allows the new_direction in the cell,
        # 2) the new cell is not empty (case 0),
        # 3) the cell is free, i.e., no agent is currently in that cell
        new_cell_isValid = (
            np.array_equal(  # Check the new position is still in the grid
                new_position,
                np.clip(new_position, [0, 0], [self.height - 1, self.width - 1]))
            and  # check the new position has some transitions (ie is not an empty cell)
            self.rail.get_transitions(new_position) > 0)
        # If transition validity hasn't been checked yet (check_action returns
        # None when the caller must verify the transition itself).
        if transition_isValid is None:
            transition_isValid = self.rail.get_transition(
                (*agent.position, agent.direction),
                new_direction)
        # Check the new position is not the same as any of the existing agent positions
        # (including itself, for simplicity, since it is moving)
        cell_isFree = not np.any(
            np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1))
        return cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid

u214892's avatar
u214892 committed
304
305
306
307
308
309
    def predict(self):
        """Return the prediction builder's output, or {} if none is configured."""
        return self.prediction_builder.get() if self.prediction_builder else {}


hagrid67's avatar
hagrid67 committed
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
    def check_action(self, agent, action):
        """
        Map an action onto the agent's tentative new heading.

        Returns (new_direction, transition_isValid): transition_isValid is
        False when a turn is impossible from a single-transition cell, True
        when a forward move is forced along the only available transition,
        and None when the caller still has to verify the transition itself.
        """
        possible_transitions = self.rail.get_transitions((*agent.position, agent.direction))
        num_transitions = np.count_nonzero(possible_transitions)

        transition_isValid = None
        new_direction = agent.direction

        if action == 1:  # turn left
            new_direction = agent.direction - 1
            if num_transitions <= 1:
                transition_isValid = False
        elif action == 3:  # turn right
            new_direction = agent.direction + 1
            if num_transitions <= 1:
                transition_isValid = False

        new_direction %= 4

        if action == 2 and num_transitions == 1:
            # dead-end, straight line or curved line: the only available
            # transition dictates the new direction
            new_direction = np.argmax(possible_transitions)
            transition_isValid = True

        return new_direction, transition_isValid

338
339
    def _get_observations(self):
        """Collect one observation per agent handle from the observation builder."""
        self.obs_dict = {}
        # NOTE(review): debug_obs_dict appears unused elsewhere; __init__ defines
        # dev_obs_dict instead — confirm which one is intended.
        self.debug_obs_dict = {}
        for handle in range(self.get_num_agents()):
            self.obs_dict[handle] = self.obs_builder.get(handle)
        return self.obs_dict
345

u214892's avatar
u214892 committed
346
347
348
349
350
    def _get_predictions(self):
        """Placeholder: currently always returns {}, even with a prediction builder."""
        if not self.prediction_builder:
            return {}
        # TODO(review): presumably meant to return self.prediction_builder.get();
        # kept as-is to preserve current behavior — confirm intent.
        return {}

351
352
353
    def render(self):
        """Rendering hook; not implemented yet."""
        # TODO: implement rendering
        pass
354

maljx's avatar
maljx committed
355
356
357
358
    def get_full_state_msg(self):
        """Serialize the full env state (grid, static and live agents) with msgpack.

        Returns the packed bytes produced by ``msgpack.packb(..., use_bin_type=True)``.
        """
        grid_data = self.rail.grid.tolist()
        agent_static_data = [agent.to_list() for agent in self.agents_static]
        agent_data = [agent.to_list() for agent in self.agents]

        # Fixed: removed three msgpack.packb(...) calls whose return values were
        # discarded — pure dead work; only the combined message below is used.
        msg_data = {
            "grid": grid_data,
            "agents_static": agent_static_data,
            "agents": agent_data}
        return msgpack.packb(msg_data, use_bin_type=True)

    def get_agent_state_msg(self):
        """Serialize only the live agents' state with msgpack."""
        msg_data = {"agents": [agent.to_list() for agent in self.agents]}
        return msgpack.packb(msg_data, use_bin_type=True)

    def set_full_state_msg(self, msg_data):
        """Restore the env state from a blob produced by get_full_state_msg()."""
        # use_list=False keeps sequences as tuples and dict keys as bytes.
        data = msgpack.unpackb(msg_data, use_list=False)
        self.rail.grid = np.array(data[b"grid"])
        self.agents_static = [EnvAgentStatic(*d[:3]) for d in data[b"agents_static"]]
        self.agents = [EnvAgent(*d[:5]) for d in data[b"agents"]]
        # setup with loaded data
        self.height, self.width = self.rail.grid.shape
        self.rail.height = self.height
        self.rail.width = self.width
        self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)

    def save(self, filename):
        """Write the full serialized env state to `filename`."""
        state = self.get_full_state_msg()
        with open(filename, "wb") as file_out:
            file_out.write(state)

    def load(self, filename):
        """Read `filename` and restore the env state from its contents."""
        with open(filename, "rb") as file_in:
            self.set_full_state_msg(file_in.read())