rail_env.py 27.5 KB
Newer Older
1
"""
2
Definition of the RailEnv environment.
3
"""
hagrid67's avatar
hagrid67 committed
4
# TODO:  _ this is a global method --> utils or remove later
u214892's avatar
u214892 committed
5
import warnings
6
from enum import IntEnum
u214892's avatar
u214892 committed
7
from typing import List
8

maljx's avatar
maljx committed
9
import msgpack
10
import msgpack_numpy as m
11
import numpy as np
12
13

from flatland.core.env import Environment
u214892's avatar
u214892 committed
14
from flatland.core.grid.grid4_utils import get_new_position
u214892's avatar
u214892 committed
15
from flatland.core.transition_map import GridTransitionMap
16
from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent
17
from flatland.envs.distance_map import DistanceMap
18
from flatland.envs.observations import TreeObsForRailEnv
u214892's avatar
u214892 committed
19
from flatland.envs.rail_generators import random_rail_generator, RailGenerator
20
from flatland.envs.schedule_generators import random_schedule_generator, ScheduleGenerator
21

22
23
m.patch()

24

spiglerg's avatar
spiglerg committed
25
class RailEnvActions(IntEnum):
    """Discrete action space of RailEnv (see RailEnv class docstring)."""
    DO_NOTHING = 0  # implies change of direction in a dead-end!
    MOVE_LEFT = 1
    MOVE_FORWARD = 2
    MOVE_RIGHT = 3
    STOP_MOVING = 4

    @staticmethod
    def to_char(a: int):
        """Map an action value to its single-letter display code."""
        action_letters = {
            0: 'B',
            1: 'L',
            2: 'F',
            3: 'R',
            4: 'S',
        }
        return action_letters[a]

u214892's avatar
u214892 committed
42

43
44
45
46
47
48
49
50
51
52
class RailEnv(Environment):
    """
    RailEnv environment class.

    RailEnv is an environment inspired by a (simplified version of) a rail
    network, in which agents (trains) have to navigate to their target
    locations in the shortest time possible, while at the same time cooperating
    to avoid bottlenecks.

    The valid actions in the environment are:
53
54
55
56
57
58

     -   0: do nothing (continue moving or stay still)
     -   1: turn left at switch and move to the next cell; if the agent was not moving, movement is started
     -   2: move to the next cell in front of the agent; if the agent was not moving, movement is started
     -   3: turn right at switch and move to the next cell; if the agent was not moving, movement is started
     -   4: stop moving
59
60
61
62

    Moving forward in a dead-end cell makes the agent turn 180 degrees and step
    to the cell it came from.

63

64
65
66
    The actions of the agents are executed in order of their handle to prevent
    deadlocks and to allow them to learn relative priorities.

67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
    Reward Function:

    It costs each agent a step_penalty for every time-step taken in the environment. Independent of the movement
    of the agent. Currently all other penalties such as penalty for stopping, starting and invalid actions are set to 0.

    alpha = 1
    beta = 1
    Reward function parameters:

    - invalid_action_penalty = 0
    - step_penalty = -alpha
    - global_reward = beta
    - stop_penalty = 0  # penalty for stopping a moving agent
    - start_penalty = 0  # penalty for starting a stopped agent

82
83
    Stochastic malfunctioning of trains:
    Trains in RailEnv can malfunction if they are halted too often (either by their own choice or because an invalid
spiglerg's avatar
spiglerg committed
84
85
    action or cell is selected.

86
87
88
    Every time an agent stops, an agent has a certain probability of malfunctioning. Malfunctions of trains follow a
    Poisson process with a certain rate. Not all trains will be affected by malfunctions during episodes to keep
    complexity manageable.
spiglerg's avatar
spiglerg committed
89
90
91
92

    TODO: currently, the parameters that control the stochasticity of the environment are hard-coded in init().
    For Round 2, they will be passed to the constructor as arguments, to allow for more flexibility.

93
    """
u214892's avatar
u214892 committed
94
95
96
97
98
99
100
101
102
    # Reward scaling factors referenced by the class docstring.
    alpha = 1.0
    beta = 1.0
    # Epsilon to avoid rounding errors
    epsilon = 0.01
    invalid_action_penalty = 0  # previously -2; GIACOMO: we decided that invalid actions will carry no penalty
    step_penalty = -1 * alpha  # charged each time-step, scaled by the agent's speed in step()
    global_reward = 1 * beta  # added to all rewards once every agent has reached its target
    stop_penalty = 0  # penalty for stopping a moving agent
    start_penalty = 0  # penalty for starting a stopped agent
103
104
105
106

    def __init__(self,
                 width,
                 height,
                 rail_generator: RailGenerator = random_rail_generator(),
                 schedule_generator: ScheduleGenerator = random_schedule_generator(),
                 number_of_agents=1,
                 obs_builder_object=TreeObsForRailEnv(max_depth=2),
                 max_episode_steps=None,
                 stochastic_data=None
                 ):
        """
        Environment init.

        Parameters
        ----------
        width : int
            The width of the rail map. Potentially in the future,
            a range of widths to sample from.
        height : int
            The height of the rail map. Potentially in the future,
            a range of heights to sample from.
        rail_generator : RailGenerator
            Function taking (width, height, num_agents, num_resets) and returning
            a GridTransitionMap plus an optional dict of hints (e.g. a distance
            map, or information for specific schedule_generators).
            Implementations can be found in flatland/envs/rail_generators.py
            NOTE: the default is created once at import time and shared by every
            environment that relies on it.
        schedule_generator : ScheduleGenerator
            Function taking the grid, the number of agents and optional hints,
            returning starting positions, targets, initial orientations and
            speeds for all agent handles.
            Implementations can be found in flatland/envs/schedule_generators.py
        number_of_agents : int
            Number of agents to spawn on the map. Potentially in the future,
            a range of number of agents to sample from.
        obs_builder_object : ObservationBuilder
            ObservationBuilder-derived object that builds observation vectors
            for each agent.
            NOTE: the default instance is created once at import time and shared;
            pass a fresh instance when creating several environments.
        max_episode_steps : int or None
            If set, all agents are marked done once this many steps have elapsed.
        stochastic_data : dict or None
            Malfunction parameters with keys 'prop_malfunction',
            'malfunction_rate', 'min_duration', 'max_duration'.
            None disables malfunctions entirely.
        """
        super().__init__()

        self.rail_generator: RailGenerator = rail_generator
        self.schedule_generator: ScheduleGenerator = schedule_generator
        self.rail: GridTransitionMap = None
        self.width = width
        self.height = height

        self.rewards = [0] * number_of_agents
        self.done = False
        self.obs_builder = obs_builder_object
        self.obs_builder._set_env(self)

        self._max_episode_steps = max_episode_steps
        self._elapsed_steps = 0

        self.dones = dict.fromkeys(list(range(number_of_agents)) + ["__all__"], False)

        self.obs_dict = {}
        self.rewards_dict = {}
        self.dev_obs_dict = {}
        self.dev_pred_dict = {}

        self.agents: List[EnvAgent] = [None] * number_of_agents  # live agents
        self.agents_static: List[EnvAgentStatic] = [None] * number_of_agents  # static agent information
        self.num_resets = 0
        self.distance_map = DistanceMap(self.agents, self.height, self.width)

        self.action_space = [1]
        self.observation_space = self.obs_builder.observation_space  # updated on resets?

        # Stochastic train malfunctioning parameters
        if stochastic_data is not None:
            prop_malfunction = stochastic_data['prop_malfunction']
            mean_malfunction_rate = stochastic_data['malfunction_rate']
            malfunction_min_duration = stochastic_data['min_duration']
            malfunction_max_duration = stochastic_data['max_duration']
        else:
            prop_malfunction = 0.
            mean_malfunction_rate = 0.
            malfunction_min_duration = 0.
            malfunction_max_duration = 0.

        # percentage of malfunctioning trains
        self.proportion_malfunctioning_trains = prop_malfunction

        # Mean malfunction in number of stops
        self.mean_malfunction_rate = mean_malfunction_rate

        # Uniform distribution parameters for malfunction duration
        self.min_number_of_steps_broken = malfunction_min_duration
        self.max_number_of_steps_broken = malfunction_max_duration

        # Reset environment (generates rail and agents)
        self.reset()
        self.num_resets = 0  # yes, set it to zero again!

        self.valid_positions = None

207
    # no more agent_handles
208
    def get_agent_handles(self):
209
210
211
212
213
214
215
        return range(self.get_num_agents())

    def get_num_agents(self, static=True):
        if static:
            return len(self.agents_static)
        else:
            return len(self.agents)
216

hagrid67's avatar
hagrid67 committed
217
218
219
220
221
222
223
    def add_agent_static(self, agent_static):
        """ Add static info for a single agent.
            Returns the index of the new agent.
        """
        self.agents_static.append(agent_static)
        return len(self.agents_static) - 1

224
225
    def restart_agents(self):
        """Rebuild the live agent list from the starting state stored in agents_static."""
        static_info = self.agents_static
        self.agents = EnvAgent.list_from_static(static_info)

    def reset(self, regen_rail=True, replace_agents=True):
        """ Regenerate the rail grid and/or agents and reset all episode state.

            If regen_rail then regenerate the rails (via self.rail_generator).
            If replace_agents then regenerate the agents static (via self.schedule_generator).
            Relies on the rail_generator returning agent_static lists (pos, dir, target).

            Returns the observation dict for the first step of the new episode.
        """

        # TODO can we not put 'self.rail_generator(..)' into 'if regen_rail or self.rail is None' condition?
        # NOTE(review): the generator runs unconditionally, even when its result is
        # discarded below (regen_rail False with an existing rail).
        rail, optionals = self.rail_generator(self.width, self.height, self.get_num_agents(), self.num_resets)

        # The generator may pre-compute a distance map and pass it along as a hint.
        if optionals and 'distance_map' in optionals:
            self.distance_map.set(optionals['distance_map'])

        if regen_rail or self.rail is None:
            self.rail = rail
            self.height, self.width = self.rail.grid.shape
            # Sanity-check the generated grid: each cell's transitions must be
            # consistent with its neighbours'.
            for r in range(self.height):
                for c in range(self.width):
                    rc_pos = (r, c)
                    check = self.rail.cell_neighbours_valid(rc_pos, True)
                    if not check:
                        warnings.warn("Invalid grid at {} -> {}".format(rc_pos, check))

        if replace_agents:
            # Forward any hints from the rail generator to the schedule generator.
            agents_hints = None
            if optionals and 'agents_hints' in optionals:
                agents_hints = optionals['agents_hints']
            self.agents_static = EnvAgentStatic.from_lists(
                *self.schedule_generator(self.rail, self.get_num_agents(), agents_hints))
        self.restart_agents()

        for i_agent in range(self.get_num_agents()):
            agent = self.agents[i_agent]

            # A proportion of agent in the environment will receive a positive malfunction rate
            if np.random.random() < self.proportion_malfunctioning_trains:
                agent.malfunction_data['malfunction_rate'] = self.mean_malfunction_rate

            agent.malfunction_data['malfunction'] = 0

            # Initialize the countdown to this agent's first malfunction.
            self._agent_new_malfunction(i_agent, RailEnvActions.DO_NOTHING)

        self.num_resets += 1
        self._elapsed_steps = 0

        # TODO perhaps dones should be part of each agent.
        self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)

        # Reset the state of the observation builder with the new environment
        self.obs_builder.reset()
        self.observation_space = self.obs_builder.observation_space  # <-- change on reset?

        # Return the new observation vectors for each agent
        return self._get_observations()

u214892's avatar
u214892 committed
283
284
285
286
    def _agent_new_malfunction(self, i_agent, action) -> bool:
        """
        Returns true if the agent enters into malfunction. (False, if not broken down or already broken down before).
        """
u214892's avatar
u214892 committed
287
288
        agent = self.agents[i_agent]

289
        # Decrease counter for next event
290
291
        if agent.malfunction_data['malfunction_rate'] > 0:
            agent.malfunction_data['next_malfunction'] -= 1
292

293
        # Only agents that have a positive rate for malfunctions and are not currently broken are considered
u214892's avatar
u214892 committed
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
        # If counter has come to zero --> Agent has malfunction
        # set next malfunction time and duration of current malfunction
        if agent.malfunction_data['malfunction_rate'] > 0 >= agent.malfunction_data['malfunction'] and \
            agent.malfunction_data['next_malfunction'] <= 0:
            # Increase number of malfunctions
            agent.malfunction_data['nr_malfunctions'] += 1

            # Next malfunction in number of stops
            next_breakdown = int(
                np.random.exponential(scale=agent.malfunction_data['malfunction_rate']))
            agent.malfunction_data['next_malfunction'] = next_breakdown

            # Duration of current malfunction
            num_broken_steps = np.random.randint(self.min_number_of_steps_broken,
                                                 self.max_number_of_steps_broken + 1) + 1
            agent.malfunction_data['malfunction'] = num_broken_steps

            return True
u214892's avatar
u214892 committed
312
        return False
u214892's avatar
u214892 committed
313

u214892's avatar
u214892 committed
314
    # TODO refactor to decrease length of this method!
spiglerg's avatar
spiglerg committed
315
    def step(self, action_dict_):
        """ Advance the environment by one time-step.

        Parameters
        ----------
        action_dict_ : dict
            Maps agent handle -> action (a RailEnvActions value). Agents with
            no entry default to DO_NOTHING.

        Returns
        -------
        (observations, rewards_dict, dones, info_dict)
            info_dict carries per-agent 'action_required', 'malfunction' and
            'speed' entries.
        """
        self._elapsed_steps += 1

        # Reset the step rewards
        self.rewards_dict = dict()
        for i_agent in range(self.get_num_agents()):
            self.rewards_dict[i_agent] = 0

        # Episode already finished: keep handing out the global reward only.
        if self.dones["__all__"]:
            self.rewards_dict = {i: r + self.global_reward for i, r in self.rewards_dict.items()}
            info_dict = {
                'action_required': {i: False for i in range(self.get_num_agents())},
                'malfunction': {i: 0 for i in range(self.get_num_agents())},
                'speed': {i: 0 for i in range(self.get_num_agents())}
            }
            return self._get_observations(), self.rewards_dict, self.dones, info_dict

        for i_agent in range(self.get_num_agents()):

            if self.dones[i_agent]:  # this agent has already completed...
                continue

            agent = self.agents[i_agent]
            agent.old_direction = agent.direction
            agent.old_position = agent.position

            # No action has been supplied for this agent -> set DO_NOTHING as default
            if i_agent not in action_dict_:
                action = RailEnvActions.DO_NOTHING
            else:
                action = action_dict_[i_agent]

            # Valid actions are 0 .. len(RailEnvActions) - 1. The previous check
            # used '>' and wrongly accepted action == len(RailEnvActions).
            if action < 0 or action >= len(RailEnvActions):
                print('ERROR: illegal action=', action,
                      'for agent with index=', i_agent,
                      '"DO NOTHING" will be executed instead')
                action = RailEnvActions.DO_NOTHING

            # Check if agent breaks at this step
            new_malfunction = self._agent_new_malfunction(i_agent, action)

            # Is the agent at the beginning of the cell? Then, it can take an action
            # Design choice (Erik+Christian):
            #  as long as we're broken down at the beginning of the cell, we can choose other actions!
            if agent.speed_data['position_fraction'] == 0.0:
                if action == RailEnvActions.DO_NOTHING and agent.moving:
                    # Keep moving
                    action = RailEnvActions.MOVE_FORWARD

                if action == RailEnvActions.STOP_MOVING and agent.moving:
                    # Only allow halting an agent on entering new cells.
                    agent.moving = False
                    self.rewards_dict[i_agent] += self.stop_penalty

                if not agent.moving and not (
                    action == RailEnvActions.DO_NOTHING or action == RailEnvActions.STOP_MOVING):
                    # Allow agent to start with any forward or direction action
                    agent.moving = True
                    self.rewards_dict[i_agent] += self.start_penalty

                # Store the action to be executed when the agent leaves this cell.
                if agent.moving and action not in [RailEnvActions.DO_NOTHING, RailEnvActions.STOP_MOVING]:
                    _, new_cell_valid, new_direction, new_position, transition_valid = \
                        self._check_action_on_agent(action, agent)

                    if all([new_cell_valid, transition_valid]):
                        agent.speed_data['transition_action_on_cellexit'] = action
                    else:
                        # But, if the chosen invalid action was LEFT/RIGHT, and the agent is moving,
                        # try to keep moving forward!
                        if (action == RailEnvActions.MOVE_LEFT or action == RailEnvActions.MOVE_RIGHT):
                            _, new_cell_valid, new_direction, new_position, transition_valid = \
                                self._check_action_on_agent(RailEnvActions.MOVE_FORWARD, agent)

                            if all([new_cell_valid, transition_valid]):
                                agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.MOVE_FORWARD
                            else:
                                # If the agent cannot move due to an invalid transition, we set its state to not moving
                                self.rewards_dict[i_agent] += self.invalid_action_penalty
                                self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
                                self.rewards_dict[i_agent] += self.stop_penalty
                                agent.moving = False

                        else:
                            # If the agent cannot move due to an invalid transition, we set its state to not moving
                            self.rewards_dict[i_agent] += self.invalid_action_penalty
                            self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
                            self.rewards_dict[i_agent] += self.stop_penalty
                            agent.moving = False

            # if we've just broken in this step, nothing else to do
            if new_malfunction:
                continue

            # The train was broken before...
            if agent.malfunction_data['malfunction'] > 0:

                # Last step of malfunction --> Agent starts moving again after getting fixed
                if agent.malfunction_data['malfunction'] < 2:
                    agent.malfunction_data['malfunction'] -= 1
                    self.agents[i_agent].moving = True
                    action = RailEnvActions.DO_NOTHING

                else:
                    agent.malfunction_data['malfunction'] -= 1

                    # Broken agents are stopped
                    self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
                    self.agents[i_agent].moving = False

                    # Nothing left to do with broken agent
                    continue

            # Now perform a movement.
            # If agent.moving, increment the position_fraction by the speed of the agent
            # If the new position fraction is >= 1, reset to 0, and perform the stored
            #   transition_action_on_cellexit if the cell is free.
            if agent.moving:
                agent.speed_data['position_fraction'] += agent.speed_data['speed']
                if agent.speed_data['position_fraction'] >= 1.0:
                    # Perform stored action to transition to the next cell as soon as cell is free.
                    # Cell and transition validity were already checked when we stored
                    # transition_action_on_cellexit, so we only have to check cell_free now!
                    cell_free, new_cell_valid, new_direction, new_position, transition_valid = self._check_action_on_agent(
                        agent.speed_data['transition_action_on_cellexit'], agent)

                    if cell_free:
                        agent.position = new_position
                        agent.direction = new_direction
                        agent.speed_data['position_fraction'] = 0.0

            if np.equal(agent.position, agent.target).all():
                self.dones[i_agent] = True
                agent.moving = False
            else:
                self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']

        # Check for end of episode + add global reward to all rewards!
        if np.all([np.array_equal(agent2.position, agent2.target) for agent2 in self.agents]):
            self.dones["__all__"] = True
            self.rewards_dict = {i: 0 * r + self.global_reward for i, r in self.rewards_dict.items()}

        # Enforce the episode step limit: everything is done once it is reached.
        if (self._max_episode_steps is not None) and (self._elapsed_steps >= self._max_episode_steps):
            self.dones["__all__"] = True
            for k in self.dones.keys():
                self.dones[k] = True

        action_required_agents = {
            i: self.agents[i].speed_data['position_fraction'] == 0.0 for i in range(self.get_num_agents())
        }
        malfunction_agents = {
            i: self.agents[i].malfunction_data['malfunction'] for i in range(self.get_num_agents())
        }
        speed_agents = {i: self.agents[i].speed_data['speed'] for i in range(self.get_num_agents())}

        info_dict = {
            'action_required': action_required_agents,
            'malfunction': malfunction_agents,
            'speed': speed_agents
        }

        return self._get_observations(), self.rewards_dict, self.dones, info_dict
480

u214892's avatar
u214892 committed
481
    def _check_action_on_agent(self, action, agent):
        """Evaluate *action* for *agent* without changing any state.

        Returns the tuple
        (cell_free, new_cell_valid, new_direction, new_position, transition_valid).
        """
        # Resolve the action into a tentative heading; transition_valid may still
        # be None here (undecided) and is then resolved against the rail grid.
        new_direction, transition_valid = self.check_action(agent, action)
        new_position = get_new_position(agent.position, new_direction)

        # The target cell is valid iff it lies inside the grid ...
        bounds = [self.height - 1, self.width - 1]
        inside_grid = np.array_equal(new_position, np.clip(new_position, [0, 0], bounds))
        # ... and contains at least one rail transition (0 == empty cell).
        new_cell_valid = inside_grid and self.rail.get_full_transitions(*new_position) > 0

        # If transition validity hasn't been decided by check_action, ask the grid.
        if transition_valid is None:
            transition_valid = self.rail.get_transition(
                (*agent.position, agent.direction), new_direction)

        # The target cell is free iff no agent occupies it
        # (including this agent itself, for simplicity, since it is moving).
        occupied_positions = [other.position for other in self.agents]
        cell_free = not np.any(np.equal(new_position, occupied_positions).all(1))
        return cell_free, new_cell_valid, new_direction, new_position, transition_valid
spiglerg's avatar
spiglerg committed
510

hagrid67's avatar
hagrid67 committed
511
    def check_action(self, agent, action):
        """Resolve *action* into a tentative new heading for *agent*.

        Returns (new_direction, transition_valid) where transition_valid is
        True/False when it could be decided locally, or None when it must still
        be checked against the rail grid by the caller.
        """
        transition_valid = None
        possible_transitions = self.rail.get_transitions(*agent.position, agent.direction)
        num_transitions = np.count_nonzero(possible_transitions)

        new_direction = agent.direction
        if action == RailEnvActions.MOVE_LEFT:
            new_direction -= 1
            # A cell with a single exit cannot be a switch, so turning is invalid.
            if num_transitions <= 1:
                transition_valid = False
        elif action == RailEnvActions.MOVE_RIGHT:
            new_direction += 1
            if num_transitions <= 1:
                transition_valid = False

        new_direction %= 4

        if action == RailEnvActions.MOVE_FORWARD and num_transitions == 1:
            # Dead-end, straight line or curved line: take the only available
            # transition; new_direction is the single set bit.
            new_direction = np.argmax(possible_transitions)
            transition_valid = True
        return new_direction, transition_valid
hagrid67's avatar
hagrid67 committed
537

538
    def _get_observations(self):
539
        self.obs_dict = self.obs_builder.get_many(list(range(self.get_num_agents())))
540
        return self.obs_dict
541

maljx's avatar
maljx committed
542
543
544
545
    def get_full_state_msg(self):
        """Serialize the full environment state into a msgpack blob.

        Returns
        -------
        bytes
            msgpack-encoded dict with keys "grid", "agents_static" and "agents".
        """
        grid_data = self.rail.grid.tolist()
        agent_static_data = [agent.to_list() for agent in self.agents_static]
        agent_data = [agent.to_list() for agent in self.agents]
        # The three standalone msgpack.packb(...) calls whose results were
        # discarded have been removed: they serialized the same data a second
        # time for no effect.
        msg_data = {
            "grid": grid_data,
            "agents_static": agent_static_data,
            "agents": agent_data}
        return msgpack.packb(msg_data, use_bin_type=True)

    def get_agent_state_msg(self):
        """Serialize only the live agents (lightweight state sync) into a msgpack blob."""
        msg_data = {
            "agents": [agent.to_list() for agent in self.agents]}
        return msgpack.packb(msg_data, use_bin_type=True)

    def set_full_state_msg(self, msg_data):
        """Restore the environment from a blob produced by get_full_state_msg."""
        data = msgpack.unpackb(msg_data, use_list=False, encoding='utf-8')
        self.rail.grid = np.array(data["grid"])
        # agents are always restored in the stopped state
        self.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False)
                              for d in data["agents_static"]]
        self.agents = [EnvAgent(*d[:9]) for d in data["agents"]]

        # Derive the environment dimensions from the loaded grid.
        self.height, self.width = self.rail.grid.shape
        self.rail.height = self.height
        self.rail.width = self.width
        self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)

Erik Nygren's avatar
Erik Nygren committed
573
    def set_full_state_dist_msg(self, msg_data):
        """Restore the environment, including the distance map when present.

        Parameters
        ----------
        msg_data : bytes
            msgpack-encoded state as produced by
            :meth:`get_full_state_dist_msg` (or :meth:`get_full_state_msg`,
            in which case no distance map is restored).
        """
        data = msgpack.unpackb(msg_data, use_list=False, encoding='utf-8')
        self.rail.grid = np.array(data["grid"])
        # Agents are always restored in the stopped state.
        self.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data["agents_static"]]
        self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8]) for d in data["agents"]]
        # Older snapshots may lack a distance map; only restore it when present.
        # (Membership test on the mapping itself — `.keys()` was redundant.)
        if "distance_map" in data:
            self.distance_map.set(data["distance_map"])
        # Re-derive the environment dimensions from the loaded grid.
        self.height, self.width = self.rail.grid.shape
        self.rail.height = self.height
        self.rail.width = self.width
        self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)

    def get_full_state_dist_msg(self):
        """Serialize grid, agents and the distance map with msgpack.

        Returns
        -------
        bytes
            msgpack-encoded dict with keys ``"grid"``, ``"agents_static"``,
            ``"agents"`` and ``"distance_map"``, suitable for
            :meth:`set_full_state_dist_msg`.
        """
        grid_data = self.rail.grid.tolist()
        agent_static_data = [agent.to_list() for agent in self.agents_static]
        agent_data = [agent.to_list() for agent in self.agents]
        distance_map_data = self.distance_map.get()
        # Pack everything in a single message.  The previous per-item
        # msgpack.packb calls discarded their return values and were dead code.
        msg_data = {
            "grid": grid_data,
            "agents_static": agent_static_data,
            "agents": agent_data,
            "distance_map": distance_map_data}
        return msgpack.packb(msg_data, use_bin_type=True)

maljx's avatar
maljx committed
604
    def save(self, filename):
        """Write the environment state to *filename*.

        The distance map is included when one has been computed (i.e. it is
        non-None and non-empty); otherwise only grid + agents are saved.

        Parameters
        ----------
        filename : str
            Path of the output file (written in binary mode).
        """
        # Fetch once instead of three separate calls; the original also
        # duplicated the "no distance map" write branch.
        distance_map = self.distance_map.get()
        if distance_map is not None and len(distance_map) > 0:
            state = self.get_full_state_dist_msg()
        else:
            state = self.get_full_state_msg()
        with open(filename, "wb") as file_out:
            file_out.write(state)
maljx's avatar
maljx committed
615
616

    def load(self, filename):
        """Restore the environment from a file written by :meth:`save`.

        Parameters
        ----------
        filename : str
            Path of the snapshot file (read in binary mode).
        """
        with open(filename, "rb") as file_in:
            self.set_full_state_dist_msg(file_in.read())
u214892's avatar
u214892 committed
620

621
622
623
    def load_pkl(self, pkl_data):
        """Restore the environment from an in-memory state blob (no distance map)."""
        self.set_full_state_msg(pkl_data)

u214892's avatar
u214892 committed
624
625
626
627
    def load_resource(self, package, resource):
        """Restore the environment from a state blob shipped as package data.

        Parameters
        ----------
        package : str
            Dotted package name containing the resource.
        resource : str
            Resource file name within the package.
        """
        # Local import keeps the importlib_resources backport optional for
        # callers that never load packaged snapshots.
        from importlib_resources import read_binary
        self.set_full_state_msg(read_binary(package, resource))
628
629

    def compute_distance_map(self):
        """Recompute the distance map and refresh the target-location lookup.

        Updates ``self.distance_map`` from the current agents and rail, then
        rebuilds the observation builder's per-cell target lookup table.
        """
        self.distance_map.compute(self.agents, self.rail)
        # Rebuild the lookup of cells that contain an agent's target.
        location_has_target = {}
        for agent in self.agents:
            location_has_target[tuple(agent.target)] = 1
        self.obs_builder.location_has_target = location_has_target
633