"""
Definition of the RailEnv environment.
"""
# TODO:  _ this is a global method --> utils or remove later
import warnings
from enum import IntEnum
from typing import List, NamedTuple, Optional, Tuple, Dict

import msgpack
import msgpack_numpy as m
import numpy as np

from flatland.core.env import Environment
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.grid.grid4 import Grid4TransitionsEnum, Grid4Transitions
from flatland.core.grid.grid4_utils import get_new_position
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent, RailAgentStatus
from flatland.envs.distance_map import DistanceMap
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.rail_generators import random_rail_generator, RailGenerator
from flatland.envs.schedule_generators import random_schedule_generator, ScheduleGenerator

m.patch()


class RailEnvActions(IntEnum):
    DO_NOTHING = 0  # implies change of direction in a dead-end!
    MOVE_LEFT = 1
    MOVE_FORWARD = 2
    MOVE_RIGHT = 3
    STOP_MOVING = 4

    @staticmethod
    def to_char(a: int):
        return {
            0: 'B',
            1: 'L',
            2: 'F',
            3: 'R',
            4: 'S',
        }[a]


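# Illustrative sketch (not part of the public API): RailEnvActions members are
# IntEnum values, so they compare equal to plain ints, and to_char renders them
# compactly for logging.
def _example_action_encoding() -> str:
    actions = [RailEnvActions.MOVE_FORWARD, RailEnvActions.STOP_MOVING]
    # Yields 'FS': 'F' for MOVE_FORWARD, 'S' for STOP_MOVING.
    return ''.join(RailEnvActions.to_char(int(a)) for a in actions)
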
RailEnvGridPos = NamedTuple('RailEnvGridPos', [('r', int), ('c', int)])
RailEnvNextAction = NamedTuple('RailEnvNextAction', [('action', RailEnvActions), ('next_position', RailEnvGridPos),
                                                     ('next_direction', Grid4TransitionsEnum)])


class RailEnv(Environment):
    """
    RailEnv environment class.

    RailEnv is an environment inspired by a (simplified version of) a rail
    network, in which agents (trains) have to navigate to their target
    locations in the shortest time possible, while at the same time cooperating
    to avoid bottlenecks.

    The valid actions in the environment are:

     -   0: do nothing (continue moving or stay still)
     -   1: turn left at switch and move to the next cell; if the agent was not moving, movement is started
     -   2: move to the next cell in front of the agent; if the agent was not moving, movement is started
     -   3: turn right at switch and move to the next cell; if the agent was not moving, movement is started
     -   4: stop moving

    Moving forward in a dead-end cell makes the agent turn 180 degrees and step
    to the cell it came from.


    The actions of the agents are executed in order of their handle to prevent
    deadlocks and to allow them to learn relative priorities.

    Reward Function:

    Each agent incurs a step_penalty for every time-step taken in the environment, independent of whether the
    agent moves. Currently, all other penalties, such as the penalties for stopping, starting and invalid actions,
    are set to 0.

    Reward function parameters:

    - alpha = 1
    - beta = 1
    - invalid_action_penalty = 0
    - step_penalty = -alpha
    - global_reward = beta
    - stop_penalty = 0  # penalty for stopping a moving agent
    - start_penalty = 0  # penalty for starting a stopped agent

    Stochastic malfunctioning of trains:
    Trains in RailEnv can malfunction if they are halted too often (either by their own choice or because an
    invalid action or cell is selected).

    Every time an agent stops, it has a certain probability of malfunctioning. Malfunctions of trains follow a
    Poisson process with a certain rate. Not all trains will be affected by malfunctions during episodes, in order
    to keep the complexity manageable.

    TODO: currently, the parameters that control the stochasticity of the environment are hard-coded in init().
    For Round 2, they will be passed to the constructor as arguments, to allow for more flexibility.

    """
    alpha = 1.0
    beta = 1.0
    # Epsilon to avoid rounding errors
    epsilon = 0.01
    invalid_action_penalty = 0  # previously -2; GIACOMO: we decided that invalid actions will carry no penalty
    step_penalty = -1 * alpha
    global_reward = 1 * beta
    stop_penalty = 0  # penalty for stopping a moving agent
    start_penalty = 0  # penalty for starting a stopped agent
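
    # Worked example (illustrative): with the defaults above, an active agent
    # moving at speed 0.5 collects step_penalty * speed = -0.5 per time-step;
    # once all agents have reached their targets, each agent instead receives
    # global_reward = 1 on every subsequent step() call.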

    def __init__(self,
                 width,
                 height,
                 rail_generator: RailGenerator = random_rail_generator(),
                 schedule_generator: ScheduleGenerator = random_schedule_generator(),
                 number_of_agents=1,
                 obs_builder_object: ObservationBuilder = TreeObsForRailEnv(max_depth=2),
                 max_episode_steps=None,
                 stochastic_data=None
                 ):
        """
        Environment init.

        Parameters
        ----------
        rail_generator : function
            The rail_generator function is a function that takes the width,
            height and agent handles of a rail environment, along with the number of times
            the env has been reset, and returns a GridTransitionMap object and a list of
            starting positions, targets, and initial orientations for each agent handle.
            The rail_generator can pass a distance map in the hints or information for specific schedule_generators.
            Implementations can be found in flatland/envs/rail_generators.py
        schedule_generator : function
            The schedule_generator function is a function that takes the grid, the number of agents and optional hints
            and returns a list of starting positions, targets, initial orientations and speed for all agent handles.
            Implementations can be found in flatland/envs/schedule_generators.py
        width : int
            The width of the rail map. Potentially in the future,
            a range of widths to sample from.
        height : int
            The height of the rail map. Potentially in the future,
            a range of heights to sample from.
        number_of_agents : int
            Number of agents to spawn on the map. Potentially in the future,
            a range of number of agents to sample from.
        obs_builder_object: ObservationBuilder object
            ObservationBuilder-derived object that builds observation
            vectors for each agent.
        max_episode_steps : int or None
            Maximum number of time-steps per episode. If None, episodes are not
            truncated by the environment.
        stochastic_data : dict or None
            Parameters controlling stochastic malfunctions ('prop_malfunction',
            'malfunction_rate', 'min_duration', 'max_duration'). If None, no
            malfunctions occur.
        """
        super().__init__()

        self.rail_generator: RailGenerator = rail_generator
        self.schedule_generator: ScheduleGenerator = schedule_generator
        self.rail: Optional[GridTransitionMap] = None
        self.width = width
        self.height = height

        self.rewards = [0] * number_of_agents
        self.done = False
        self.obs_builder = obs_builder_object
        self.obs_builder.set_env(self)

        self._max_episode_steps = max_episode_steps
        self._elapsed_steps = 0

        self.dones = dict.fromkeys(list(range(number_of_agents)) + ["__all__"], False)

        self.obs_dict = {}
        self.rewards_dict = {}
        self.dev_obs_dict = {}
        self.dev_pred_dict = {}

        self.agents: List[EnvAgent] = [None] * number_of_agents  # live agents
        self.agents_static: List[EnvAgentStatic] = [None] * number_of_agents  # static agent information
        self.num_resets = 0
        self.distance_map = DistanceMap(self.agents, self.height, self.width)

        self.action_space = [1]

        # Stochastic train malfunctioning parameters
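        # Expected shape of stochastic_data (an illustrative example, not a
        # schema enforced here):
        #   {'prop_malfunction': 0.1,   # fraction of trains that can malfunction
        #    'malfunction_rate': 30,    # mean number of stops between malfunctions
        #    'min_duration': 3,         # shortest malfunction, in time-steps
        #    'max_duration': 20}        # longest malfunction, in time-steps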
        if stochastic_data is not None:
            prop_malfunction = stochastic_data['prop_malfunction']
            mean_malfunction_rate = stochastic_data['malfunction_rate']
            malfunction_min_duration = stochastic_data['min_duration']
            malfunction_max_duration = stochastic_data['max_duration']
        else:
            prop_malfunction = 0.
            mean_malfunction_rate = 0.
            malfunction_min_duration = 0.
            malfunction_max_duration = 0.

        # percentage of malfunctioning trains
        self.proportion_malfunctioning_trains = prop_malfunction

        # Mean malfunction in number of stops
        self.mean_malfunction_rate = mean_malfunction_rate

        # Uniform distribution parameters for malfunction duration
        self.min_number_of_steps_broken = malfunction_min_duration
        self.max_number_of_steps_broken = malfunction_max_duration

        # Reset environment
        self.reset()
        self.num_resets = 0  # yes, set it to zero again!

        self.valid_positions = None

    # no more agent_handles
    def get_agent_handles(self):
        return range(self.get_num_agents())

    def get_num_agents(self, static=True):
        if static:
            return len(self.agents_static)
        else:
            return len(self.agents)

    def add_agent_static(self, agent_static):
        """ Add static info for a single agent.
            Returns the index of the new agent.
        """
        self.agents_static.append(agent_static)
        return len(self.agents_static) - 1

    def restart_agents(self):
        """ Reset the agents to their starting positions defined in agents_static
        """
        self.agents = EnvAgent.list_from_static(self.agents_static)

    def reset(self, regen_rail=True, replace_agents=True):
        """ If regen_rail, regenerate the rails.
            If replace_agents, regenerate the static agents.
            Relies on the rail_generator returning agent_static lists (pos, dir, target).
        """

        # TODO https://gitlab.aicrowd.com/flatland/flatland/issues/172
        #  can we not put 'self.rail_generator(..)' into 'if regen_rail or self.rail is None' condition?
        rail, optionals = self.rail_generator(self.width, self.height, self.get_num_agents(), self.num_resets)

        if regen_rail or self.rail is None:
            self.rail = rail
            self.height, self.width = self.rail.grid.shape
            for r in range(self.height):
                for c in range(self.width):
                    rc_pos = (r, c)
                    check = self.rail.cell_neighbours_valid(rc_pos, True)
                    if not check:
                        warnings.warn("Invalid grid at {} -> {}".format(rc_pos, check))
        # TODO https://gitlab.aicrowd.com/flatland/flatland/issues/172
        #  hacky: we must re-compute the distance map and not use the initial distance_map loaded from file by
        #  rail_from_file!!!
        elif optionals and 'distance_map' in optionals:
            self.distance_map.set(optionals['distance_map'])

        if replace_agents:
            agents_hints = None
            if optionals and 'agents_hints' in optionals:
                agents_hints = optionals['agents_hints']

            # TODO https://gitlab.aicrowd.com/flatland/flatland/issues/185
            #  why do we need static agents? could we do this more elegantly?
            self.agents_static = EnvAgentStatic.from_lists(
                *self.schedule_generator(self.rail, self.get_num_agents(), agents_hints))
        self.restart_agents()

        for i_agent in range(self.get_num_agents()):
            agent = self.agents[i_agent]

            # A proportion of agents in the environment will receive a positive malfunction rate
            if np.random.random() < self.proportion_malfunctioning_trains:
                agent.malfunction_data['malfunction_rate'] = self.mean_malfunction_rate

            agent.malfunction_data['malfunction'] = 0

            initial_malfunction = self._agent_malfunction(i_agent)

            if initial_malfunction:
                agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.DO_NOTHING

        self.num_resets += 1
        self._elapsed_steps = 0

        # TODO perhaps dones should be part of each agent.
        self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)

        # Reset the state of the observation builder with the new environment
        self.obs_builder.reset()
        self.distance_map.reset(self.agents, self.rail)

        # Return the new observation vectors for each agent
        return self._get_observations()

    def _agent_malfunction(self, i_agent) -> bool:
        """
        Returns true if the agent enters into malfunction. (False, if not broken down or already broken down before).
        """
        agent = self.agents[i_agent]

        # Decrease counter for next event
        if agent.malfunction_data['malfunction_rate'] > 0 and agent.malfunction_data['next_malfunction'] > 0:
            agent.malfunction_data['next_malfunction'] -= 1

        # Only agents that have a positive rate for malfunctions and are not currently broken are considered
        # If counter has come to zero --> Agent has malfunction
        # set next malfunction time and duration of current malfunction
        if agent.malfunction_data['malfunction_rate'] > 0 >= agent.malfunction_data['malfunction'] and \
            agent.malfunction_data['next_malfunction'] <= 0:
            # Increase number of malfunctions
            agent.malfunction_data['nr_malfunctions'] += 1

            # Next malfunction in number of stops
            next_breakdown = int(
                np.random.exponential(scale=agent.malfunction_data['malfunction_rate']))
            agent.malfunction_data['next_malfunction'] = next_breakdown

            # Duration of current malfunction
            num_broken_steps = np.random.randint(self.min_number_of_steps_broken,
                                                 self.max_number_of_steps_broken + 1) + 1
            agent.malfunction_data['malfunction'] = num_broken_steps
            agent.malfunction_data['moving_before_malfunction'] = agent.moving

            return True
        else:
            # The train was broken before...
            if agent.malfunction_data['malfunction'] > 0:

                # Last step of malfunction --> Agent starts moving again after getting fixed
                if agent.malfunction_data['malfunction'] < 2:
                    agent.malfunction_data['malfunction'] -= 1

                    # restore moving state before malfunction without further penalty
                    self.agents[i_agent].moving = agent.malfunction_data['moving_before_malfunction']

                else:
                    agent.malfunction_data['malfunction'] -= 1

                    # Nothing left to do with broken agent
                    return True
        return False

    def step(self, action_dict_: Dict[int, RailEnvActions]):
        self._elapsed_steps += 1

        # Reset the step rewards
        self.rewards_dict = dict()
        for i_agent in range(self.get_num_agents()):
            self.rewards_dict[i_agent] = 0

        # If we're done, set reward and info_dict and step() is done.
        if self.dones["__all__"]:
            self.rewards_dict = {i: self.global_reward for i in range(self.get_num_agents())}
            info_dict = {
                'action_required': {i: False for i in range(self.get_num_agents())},
                'malfunction': {i: 0 for i in range(self.get_num_agents())},
                'speed': {i: 0 for i in range(self.get_num_agents())},
                'status': {i: agent.status for i, agent in enumerate(self.agents)}
            }
            return self._get_observations(), self.rewards_dict, self.dones, info_dict

        # Perform step on all agents
        for i_agent in range(self.get_num_agents()):
            self._step_agent(i_agent, action_dict_.get(i_agent))

        # Check for end of episode + set global reward to all rewards!
        if np.all([np.array_equal(agent.position, agent.target) for agent in self.agents]):
            self.dones["__all__"] = True
            self.rewards_dict = {i: self.global_reward for i in range(self.get_num_agents())}

        if (self._max_episode_steps is not None) and (self._elapsed_steps >= self._max_episode_steps):
            self.dones["__all__"] = True
            for i in range(self.get_num_agents()):
                self.agents[i].status = RailAgentStatus.DONE
                self.dones[i] = True

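        # 'action_required' is True exactly when an agent sits at a cell border
        # (position_fraction == 0.0), i.e., when a newly submitted action can
        # still influence its next transition.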
        info_dict = {
            'action_required': {i: self.agents[i].speed_data['position_fraction'] == 0.0 for i in
                                range(self.get_num_agents())},
            'malfunction': {
                i: self.agents[i].malfunction_data['malfunction'] for i in range(self.get_num_agents())
            },
            'speed': {i: self.agents[i].speed_data['speed'] for i in range(self.get_num_agents())},
            'status': {i: agent.status for i, agent in enumerate(self.agents)}
        }

        return self._get_observations(), self.rewards_dict, self.dones, info_dict

    def _step_agent(self, i_agent, action: Optional[RailEnvActions] = None):
        """
        Performs a step for a single agent, applying step, start and stop penalties, in the following sub-steps:
        - malfunction
        - action handling if at the beginning of the cell
        - movement

        Parameters
        ----------
        i_agent : int
        action : Optional[RailEnvActions]

        """
        agent = self.agents[i_agent]
        if agent.status == RailAgentStatus.DONE:  # this agent has already completed...
            return

        # The agent becomes active through a MOVE_* action (the cell_free check below is currently disabled).
        if agent.status == RailAgentStatus.READY_TO_DEPART:
            if action in [RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT,
                          RailEnvActions.MOVE_FORWARD]:  # and self.cell_free(agent.position):
                agent.status = RailAgentStatus.ACTIVE
            else:
                return

        agent.old_direction = agent.direction
        agent.old_position = agent.position

        # is the agent malfunctioning?
        malfunction = self._agent_malfunction(i_agent)

        # if agent is broken, actions are ignored and agent does not move.
        # full step penalty in this case
        if malfunction:
            self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
            return

        # Is the agent at the beginning of the cell? Then, it can take an action.
        # As long as the agent is malfunctioning or stopped at the beginning of the cell, different actions may be taken!
        if agent.speed_data['position_fraction'] == 0.0:
            # No action has been supplied for this agent -> set DO_NOTHING as default
            if action is None:
                action = RailEnvActions.DO_NOTHING

            if action < 0 or action >= len(RailEnvActions):
                print('ERROR: illegal action=', action,
                      'for agent with index=', i_agent,
                      '"DO NOTHING" will be executed instead')
                action = RailEnvActions.DO_NOTHING

            if action == RailEnvActions.DO_NOTHING and agent.moving:
                # Keep moving
                action = RailEnvActions.MOVE_FORWARD

            if action == RailEnvActions.STOP_MOVING and agent.moving:
                # Only allow halting an agent on entering new cells.
                agent.moving = False
                self.rewards_dict[i_agent] += self.stop_penalty

            if not agent.moving and not (
                action == RailEnvActions.DO_NOTHING or action == RailEnvActions.STOP_MOVING):
                # Allow the agent to start moving with any MOVE_* (left/forward/right) action
                agent.moving = True
                self.rewards_dict[i_agent] += self.start_penalty

            # Store the action if action is moving
            # If not moving, the action will be stored when the agent starts moving again.
            if agent.moving:
                _action_stored = False
                _, new_cell_valid, new_direction, new_position, transition_valid = \
                    self._check_action_on_agent(action, agent)

                if all([new_cell_valid, transition_valid]):
                    agent.speed_data['transition_action_on_cellexit'] = action
                    _action_stored = True
                else:
                    # But, if the chosen invalid action was LEFT/RIGHT, and the agent is moving,
                    # try to keep moving forward!
                    if (action == RailEnvActions.MOVE_LEFT or action == RailEnvActions.MOVE_RIGHT):
                        _, new_cell_valid, new_direction, new_position, transition_valid = \
                            self._check_action_on_agent(RailEnvActions.MOVE_FORWARD, agent)

                        if all([new_cell_valid, transition_valid]):
                            agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.MOVE_FORWARD
                            _action_stored = True

                if not _action_stored:
                    # If the agent cannot move due to an invalid transition, we set its state to not moving
                    self.rewards_dict[i_agent] += self.invalid_action_penalty
                    self.rewards_dict[i_agent] += self.stop_penalty
                    agent.moving = False

        # Now perform a movement.
        # If agent.moving, increment the position_fraction by the speed of the agent
        # If the new position fraction is >= 1, reset to 0, and perform the stored
        #   transition_action_on_cellexit if the cell is free.
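        # (For example, an agent with speed 0.25 accumulates 0.25 per time-step
        # and therefore needs four steps in a cell before the stored action fires.)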
        if agent.moving:
            agent.speed_data['position_fraction'] += agent.speed_data['speed']
            if agent.speed_data['position_fraction'] >= 1.0:
                # Perform the stored action to transition to the next cell as soon as that cell is free.
                # new_cell_valid and transition_valid were already checked when the action was stored,
                # so only cell_free still needs to be checked here.
                cell_free, new_cell_valid, new_direction, new_position, transition_valid = self._check_action_on_agent(
                    agent.speed_data['transition_action_on_cellexit'], agent)
                assert new_cell_valid
                assert transition_valid
                if cell_free:
                    agent.position = new_position
                    agent.direction = new_direction
                    agent.speed_data['position_fraction'] = 0.0

            # has the agent reached its target?
            if np.equal(agent.position, agent.target).all():
                agent.status = RailAgentStatus.DONE
                self.dones[i_agent] = True
                agent.moving = False
            else:
                self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
        else:
            # step penalty if not moving (stopped now or before)
            self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']

    def _check_action_on_agent(self, action: RailEnvActions, agent: EnvAgent):
        """

        Parameters
        ----------
        action : RailEnvActions
        agent : EnvAgent

        Returns
        -------
        Tuple[bool, bool, Grid4TransitionsEnum, Tuple[int, int], bool]
            cell_free: the new cell is free, i.e., no agent currently occupies it
            new_cell_valid: the new position is inside the grid and is not an empty cell
            new_direction: the direction the agent would face in the new cell
            new_position: the new cell
            transition_valid: the current cell and direction allow a transition towards new_direction

        """
        # Compute the number of possible transitions in the current cell,
        # used to check for invalid actions.
        new_direction, transition_valid = self.check_action(agent, action)
        new_position = get_new_position(agent.position, new_direction)

        new_cell_valid = (
            np.array_equal(  # Check the new position is still in the grid
                new_position,
                np.clip(new_position, [0, 0], [self.height - 1, self.width - 1]))
            and  # check the new position has some transitions (ie is not an empty cell)
            self.rail.get_full_transitions(*new_position) > 0)

        # If transition validity hasn't been checked yet.
        if transition_valid is None:
            transition_valid = self.rail.get_transition(
                (*agent.position, agent.direction),
                new_direction)

        # Check the new position is not the same as any of the existing agent positions
        # (including itself, for simplicity, since it is moving)
        cell_free = self.cell_free(new_position)
        return cell_free, new_cell_valid, new_direction, new_position, transition_valid

    def cell_free(self, position):
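        # Compare `position` element-wise with every agent's position; a row of
        # all-True means some agent (possibly the moving agent itself) occupies
        # that cell.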
        return not np.any(np.equal(position, [agent.position for agent in self.agents]).all(1))

    def check_action(self, agent: EnvAgent, action: RailEnvActions):
        """

        Parameters
        ----------
        agent : EnvAgent
        action : RailEnvActions

        Returns
        -------
        Tuple[Grid4TransitionsEnum, Optional[bool]]
            the new direction and whether the transition is valid; the second
            element is None if validity still has to be checked against the rail
        """
        transition_valid = None
        possible_transitions = self.rail.get_transitions(*agent.position, agent.direction)
        num_transitions = np.count_nonzero(possible_transitions)

        new_direction = agent.direction
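        # Directions follow Grid4TransitionsEnum (N=0, E=1, S=2, W=3), so a left
        # turn is direction - 1 and a right turn direction + 1, taken modulo 4 below.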
        if action == RailEnvActions.MOVE_LEFT:
            new_direction = agent.direction - 1
            if num_transitions <= 1:
                transition_valid = False

        elif action == RailEnvActions.MOVE_RIGHT:
            new_direction = agent.direction + 1
            if num_transitions <= 1:
                transition_valid = False

        new_direction %= 4

        if action == RailEnvActions.MOVE_FORWARD and num_transitions == 1:
            # - dead-end, straight line or curved line;
            # new_direction will be the only valid transition
            # - take only available transition
            new_direction = np.argmax(possible_transitions)
            transition_valid = True
        return new_direction, transition_valid

    def _get_observations(self):
        self.obs_dict = self.obs_builder.get_many(list(range(self.get_num_agents())))
        return self.obs_dict

    def get_valid_directions_on_grid(self, row: int, col: int) -> List[int]:
        return Grid4Transitions.get_entry_directions(self.rail.get_full_transitions(row, col))

    def get_full_state_msg(self):
        grid_data = self.rail.grid.tolist()
        agent_static_data = [agent.to_list() for agent in self.agents_static]
        agent_data = [agent.to_list() for agent in self.agents]
        msg_data = {
            "grid": grid_data,
            "agents_static": agent_static_data,
            "agents": agent_data}
        return msgpack.packb(msg_data, use_bin_type=True)

    def get_agent_state_msg(self):
        agent_data = [agent.to_list() for agent in self.agents]
        msg_data = {
            "agents": agent_data}
        return msgpack.packb(msg_data, use_bin_type=True)

    def set_full_state_msg(self, msg_data):
        data = msgpack.unpackb(msg_data, use_list=False, encoding='utf-8')
        self.rail.grid = np.array(data["grid"])
        # agents are always reset as not moving
        self.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data["agents_static"]]
        self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8]) for d in data["agents"]]
        # setup with loaded data
        self.height, self.width = self.rail.grid.shape
        self.rail.height = self.height
        self.rail.width = self.width
        self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)

    def set_full_state_dist_msg(self, msg_data):
        data = msgpack.unpackb(msg_data, use_list=False, encoding='utf-8')
        self.rail.grid = np.array(data["grid"])
        # agents are always reset as not moving
        self.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data["agents_static"]]
        self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8]) for d in data["agents"]]
        if "distance_map" in data.keys():
            self.distance_map.set(data["distance_map"])
        # setup with loaded data
        self.height, self.width = self.rail.grid.shape
        self.rail.height = self.height
        self.rail.width = self.width
        self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)

    def get_full_state_dist_msg(self):
        grid_data = self.rail.grid.tolist()
        agent_static_data = [agent.to_list() for agent in self.agents_static]
        agent_data = [agent.to_list() for agent in self.agents]
        distance_map_data = self.distance_map.get()
        msg_data = {
            "grid": grid_data,
            "agents_static": agent_static_data,
            "agents": agent_data,
            "distance_map": distance_map_data}

        return msgpack.packb(msg_data, use_bin_type=True)

    def save(self, filename):
        distance_map = self.distance_map.get()
        with open(filename, "wb") as file_out:
            if distance_map is not None and len(distance_map) > 0:
                file_out.write(self.get_full_state_dist_msg())
            else:
                file_out.write(self.get_full_state_msg())

    def load(self, filename):
        with open(filename, "rb") as file_in:
            load_data = file_in.read()
            self.set_full_state_dist_msg(load_data)

    def load_pkl(self, pkl_data):
        self.set_full_state_msg(pkl_data)

    def load_resource(self, package, resource):
        from importlib_resources import read_binary
        load_data = read_binary(package, resource)
        self.set_full_state_msg(load_data)
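

# Minimal usage sketch (illustrative only): builds an environment with the
# default random rail/schedule generators and drives every agent forward for
# one time-step.
if __name__ == "__main__":
    example_env = RailEnv(width=20, height=20, number_of_agents=2)
    example_obs = example_env.reset()
    example_obs, example_rewards, example_dones, example_info = example_env.step(
        {handle: RailEnvActions.MOVE_FORWARD for handle in example_env.get_agent_handles()})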