"""
Definition of the RailEnv environment and related level-generation functions.

Generator functions are functions that take width, height, number of agents and num_resets
as arguments and return a GridTransitionMap object plus per-agent lists of starting
positions, targets and initial orientations.
"""
import numpy as np
import msgpack

from flatland.core.env import Environment
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.generators import random_rail_generator
from flatland.envs.env_utils import get_new_position
from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent

# from flatland.core.transitions import Grid8Transitions, RailEnvTransitions
# from flatland.core.transition_map import GridTransitionMap


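# A minimal sketch of the rail_generator contract used by RailEnv below
# (illustrative only: the function name is hypothetical, and the exact ordering
# of the three agent lists is whatever EnvAgentStatic.from_lists expects):
#
# def my_rail_generator(width, height, num_agents, num_resets):
#     rail = ...                # a GridTransitionMap of shape (height, width)
#     agents_position = [...]   # one (row, col) starting cell per agent
#     agents_direction = [...]  # one initial orientation per agent
#     agents_target = [...]     # one (row, col) target cell per agent
#     return rail, agents_position, agents_direction, agents_target

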
class RailEnv(Environment):
    """
    RailEnv environment class.

    RailEnv is an environment inspired by a (simplified version of a) rail
    network, in which agents (trains) have to navigate to their target
    locations in the shortest time possible, while at the same time cooperating
    to avoid bottlenecks.

    The valid actions in the environment are:
        0: do nothing
        1: turn left and move to the next cell
        2: move to the next cell in front of the agent
        3: turn right and move to the next cell

    Moving forward in a dead-end cell makes the agent turn 180 degrees and step
    to the cell it came from.

    The actions of the agents are executed in order of their handle to prevent
    deadlocks and to allow them to learn relative priorities.

    TODO: WRITE ABOUT THE REWARD FUNCTION, and possibly allow for alpha and
    beta to be passed as parameters to __init__().
    """

    def __init__(self,
                 width,
                 height,
                 rail_generator=random_rail_generator(),
                 number_of_agents=1,
                 obs_builder_object=TreeObsForRailEnv(max_depth=2)):
        """
        Environment init.

        Parameters
        ----------
        rail_generator : function
            The rail_generator function is a function that takes the width,
            height and number of agents of a rail environment, along with the number of
            times the env has been reset, and returns a GridTransitionMap object and a list
            of starting positions, targets, and initial orientations for each agent.
            Implemented functions are:
                random_rail_generator : generate a random rail of given size
                rail_from_GridTransitionMap_generator(rail_map) : generate a rail from
                                        a GridTransitionMap object
                rail_from_manual_specifications_generator(rail_spec) : generate a rail from
                                        a rail specifications array
                TODO: generate_rail_from_saved_list or from list of ndarray bitmaps ---
        width : int
            The width of the rail map. Potentially in the future,
            a range of widths to sample from.
        height : int
            The height of the rail map. Potentially in the future,
            a range of heights to sample from.
        number_of_agents : int
            Number of agents to spawn on the map. Potentially in the future,
            a range of number of agents to sample from.
        obs_builder_object: ObservationBuilder object
            ObservationBuilder-derived object that builds observation
            vectors for each agent.
        """

        self.rail_generator = rail_generator
        self.rail = None
        self.width = width
        self.height = height

        # use get_num_agents() instead
        # self.number_of_agents = number_of_agents

        self.obs_builder = obs_builder_object
        self.obs_builder._set_env(self)

        self.actions = [0] * number_of_agents
        self.rewards = [0] * number_of_agents
        self.done = False

        self.dones = dict.fromkeys(list(range(number_of_agents)) + ["__all__"], False)

        self.obs_dict = {}
        self.rewards_dict = {}
        self.dev_obs_dict = {}
        # self.agents_handles = list(range(self.number_of_agents))

        # self.agents_position = []
        # self.agents_target = []
        # self.agents_direction = []
        self.agents = [None] * number_of_agents  # live agents
        self.agents_static = [None] * number_of_agents  # static agent information
        self.num_resets = 0
        self.reset()
        self.num_resets = 0  # yes, set it to zero again, since reset() increments it

        self.valid_positions = None

    # no more agent_handles
    def get_agent_handles(self):
        return range(self.get_num_agents())

    def get_num_agents(self, static=True):
        if static:
            return len(self.agents_static)
        else:
            return len(self.agents)

    def add_agent_static(self, agent_static):
        """ Add static info for a single agent.
            Returns the index of the new agent.
        """
        self.agents_static.append(agent_static)
        return len(self.agents_static) - 1

    def restart_agents(self):
        """ Reset the agents to their starting positions defined in agents_static
        """
        self.agents = EnvAgent.list_from_static(self.agents_static)

    def reset(self, regen_rail=True, replace_agents=True):
        """ if regen_rail then regenerate the rails.
            if replace_agents then regenerate the agents static.
            Relies on the rail_generator returning agent_static lists (pos, dir, target)
        """
        tRailAgents = self.rail_generator(self.width, self.height, self.get_num_agents(), self.num_resets)

        if regen_rail or self.rail is None:
            self.rail = tRailAgents[0]

        if replace_agents:
            self.agents_static = EnvAgentStatic.from_lists(*tRailAgents[1:4])

        # Take the agent static info and put (live) agents at the start positions
        # self.agents = EnvAgent.list_from_static(self.agents_static[:len(self.agents_handles)])
        self.restart_agents()

        self.num_resets += 1

        # for handle in self.agents_handles:
        #    self.dones[handle] = False
        self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)
        # perhaps dones should be part of each agent.

        # Reset the state of the observation builder with the new environment
        self.obs_builder.reset()

        # Return the new observation vectors for each agent
        return self._get_observations()

    def step(self, action_dict):
        alpha = 1.0
        beta = 1.0

        invalid_action_penalty = -2
        step_penalty = -1 * alpha
        global_reward = 1 * beta
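
        # Reward scheme implemented below (cf. the TODO in the class
        # docstring): each agent that has not reached its target receives
        # step_penalty (-1 * alpha) on every time step, plus
        # invalid_action_penalty (-2) if its chosen move cannot be executed;
        # when all agents have reached their targets, every agent receives
        # global_reward (1 * beta).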

        # Reset the step rewards
        self.rewards_dict = dict()
        # for handle in self.agents_handles:
        #    self.rewards_dict[handle] = 0
        for iAgent in range(self.get_num_agents()):
            self.rewards_dict[iAgent] = 0

        if self.dones["__all__"]:
            self.rewards_dict = {i: r + global_reward for i, r in self.rewards_dict.items()}
            return self._get_observations(), self.rewards_dict, self.dones, {}

        # for i in range(len(self.agents_handles)):
        for iAgent in range(self.get_num_agents()):
            # handle = self.agents_handles[i]
            transition_isValid = None
            agent = self.agents[iAgent]

            if iAgent not in action_dict:  # no action has been supplied for this agent
                continue

            if self.dones[iAgent]:  # this agent has already completed...
                continue
            action = action_dict[iAgent]

            if action < 0 or action > 3:
                print('ERROR: illegal action=', action,
                      'for agent with index=', iAgent)
                # skip the illegal action instead of aborting the whole step,
                # so that step() always returns (obs, rewards, dones, info)
                continue

            if action > 0:
                # pos = agent.position #  self.agents_position[i]
                # direction = agent.direction # self.agents_direction[i]

                # compute number of possible transitions in the current
                # cell used to check for invalid actions

                new_direction, transition_isValid = self.check_action(agent, action)

                new_position = get_new_position(agent.position, new_direction)
                # Is it a legal move?
                # 1) transition allows the new_direction in the cell,
                # 2) the new cell is not empty (case 0),
                # 3) the cell is free, i.e., no agent is currently in that cell

                # if (
                #        new_position[1] >= self.width or
                #        new_position[0] >= self.height or
                #        new_position[0] < 0 or new_position[1] < 0):
                #    new_cell_isValid = False

                # if self.rail.get_transitions(new_position) == 0:
                #     new_cell_isValid = False

                new_cell_isValid = (
                    np.array_equal(  # Check the new position is still in the grid
                        new_position,
                        np.clip(new_position, [0, 0], [self.height - 1, self.width - 1]))
                    and  # check the new position has some transitions (ie is not an empty cell)
                    self.rail.get_transitions(new_position) > 0)

                # If transition validity hasn't been checked yet.
                if transition_isValid is None:
                    transition_isValid = self.rail.get_transition(
                        (*agent.position, agent.direction),
                        new_direction)

                # cell_isFree = True
                # for j in range(self.number_of_agents):
                #    if self.agents_position[j] == new_position:
                #        cell_isFree = False
                #        break
                # Check the new position is not the same as any of the existing agent positions
                # (including itself, for simplicity, since it is moving)
                cell_isFree = not np.any(
                    np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1))

                if all([new_cell_isValid, transition_isValid, cell_isFree]):
                    # move and change direction to face the new_direction that was
                    # performed
                    # self.agents_position[i] = new_position
                    # self.agents_direction[i] = new_direction
                    agent.position = new_position
                    agent.old_direction = agent.direction
                    agent.direction = new_direction
                else:
                    # the action was not valid, add penalty
                    self.rewards_dict[iAgent] += invalid_action_penalty

            # if agent is not in target position, add step penalty
            # if self.agents_position[i][0] == self.agents_target[i][0] and \
            #        self.agents_position[i][1] == self.agents_target[i][1]:
            #    self.dones[handle] = True
            if np.equal(agent.position, agent.target).all():
                self.dones[iAgent] = True
            else:
                self.rewards_dict[iAgent] += step_penalty

        # Check for end of episode + add global reward to all rewards!
        # num_agents_in_target_position = 0
        # for i in range(self.number_of_agents):
        #    if self.agents_position[i][0] == self.agents_target[i][0] and \
        #            self.agents_position[i][1] == self.agents_target[i][1]:
        #        num_agents_in_target_position += 1
        # if num_agents_in_target_position == self.number_of_agents:
        if np.all([np.array_equal(agent2.position, agent2.target) for agent2 in self.agents]):
            self.dones["__all__"] = True
            self.rewards_dict = {i: global_reward for i in self.rewards_dict}

        # Reset the step actions (in case some agent doesn't 'register_action'
        # on the next step)
        self.actions = [0] * self.get_num_agents()

        return self._get_observations(), self.rewards_dict, self.dones, {}

    def check_action(self, agent, action):
        transition_isValid = None
        possible_transitions = self.rail.get_transitions((*agent.position, agent.direction))
        num_transitions = np.count_nonzero(possible_transitions)

        new_direction = agent.direction
        if action == 1:  # turn left
            new_direction = agent.direction - 1
            if num_transitions <= 1:
                # a cell with at most one outgoing transition cannot be turned in
                transition_isValid = False

        elif action == 3:  # turn right
            new_direction = agent.direction + 1
            if num_transitions <= 1:
                transition_isValid = False

        new_direction %= 4

        if action == 2:  # move forward
            if num_transitions == 1:
                # - dead-end, straight line or curved line;
                # new_direction will be the only valid transition
                # - take only available transition
                new_direction = np.argmax(possible_transitions)
                transition_isValid = True
        return new_direction, transition_isValid
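
    # Worked example of the check_action() contract, assuming the usual
    # flatland heading convention (0 = North, 1 = East, 2 = South, 3 = West;
    # the convention is an assumption, as it is not stated in this file): an
    # agent facing East (direction 1) taking action 1 (turn left) is proposed
    # new_direction 0 (North), and the turn is flagged invalid right away if
    # the current cell offers at most one outgoing transition. For action 2
    # (move forward) in a cell with exactly one outgoing transition (straight,
    # curve or dead-end):
    #
    #     new_direction, valid = env.check_action(agent, 2)
    #     # valid is True; new_direction is the index of the only transition,
    #     # and step() still validates the move against the grid and agents.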

    def _get_observations(self):
        self.obs_dict = {}
        self.debug_obs_dict = {}
        # for handle in self.agents_handles:
        for iAgent in range(self.get_num_agents()):
            self.obs_dict[iAgent] = self.obs_builder.get(iAgent)
        return self.obs_dict

    def render(self):
        # TODO:
        pass

    def get_full_state_msg(self):
        grid_data = self.rail.grid.tolist()
        agent_static_data = [agent.to_list() for agent in self.agents_static]
        agent_data = [agent.to_list() for agent in self.agents]

        msg_data = {
            "grid": grid_data,
            "agents_static": agent_static_data,
            "agents": agent_data}
        return msgpack.packb(msg_data, use_bin_type=True)

    def get_agent_state_msg(self):
        agent_data = [agent.to_list() for agent in self.agents]
        msg_data = {
            "agents": agent_data}
        return msgpack.packb(msg_data, use_bin_type=True)

    def set_full_state_msg(self, msg_data):
        data = msgpack.unpackb(msg_data, use_list=False)
        self.rail.grid = np.array(data[b"grid"])
        self.agents_static = [EnvAgentStatic(d[0], d[1], d[2]) for d in data[b"agents_static"]]
        self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4]) for d in data[b"agents"]]
        # setup with loaded data
        self.height, self.width = self.rail.grid.shape
        self.rail.height = self.height
        self.rail.width = self.width
        # self.agents = [None] * self.get_num_agents()
        self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)

    def save(self, filename):
        with open(filename, "wb") as file_out:
            file_out.write(self.get_full_state_msg())

    def load(self, filename):
        with open(filename, "rb") as file_in:
            load_data = file_in.read()
            self.set_full_state_msg(load_data)
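

# Minimal usage sketch: build a small random environment, step every agent
# forward once, then round-trip the full state through save()/load(). The
# width/height/agent-count values and the file name are arbitrary examples.
if __name__ == "__main__":
    env = RailEnv(width=10, height=10,
                  rail_generator=random_rail_generator(),
                  number_of_agents=2)
    obs = env.reset()
    obs, rewards, dones, _ = env.step({h: 2 for h in env.get_agent_handles()})
    print("rewards:", rewards, "dones:", dones)

    env.save("railenv_state.mpk")  # msgpack-encoded full state
    env.load("railenv_state.mpk")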