rail_env.py 46.5 KB
Newer Older
1
"""
2
Definition of the RailEnv environment.
3
"""
4
import random
hagrid67's avatar
hagrid67 committed
5
# TODO:  _ this is a global method --> utils or remove later
6
from enum import IntEnum
7
from typing import List, NamedTuple, Optional, Dict, Tuple
8

9
import numpy as np
10

11

12
from flatland.core.env import Environment
13
from flatland.core.env_observation_builder import ObservationBuilder
u214892's avatar
u214892 committed
14
from flatland.core.grid.grid4 import Grid4TransitionsEnum, Grid4Transitions
u214892's avatar
u214892 committed
15
from flatland.core.grid.grid4_utils import get_new_position
16
from flatland.core.grid.grid_utils import IntVector2D
u214892's avatar
u214892 committed
17
from flatland.core.transition_map import GridTransitionMap
u229589's avatar
u229589 committed
18
from flatland.envs.agent_utils import EnvAgent, RailAgentStatus
19
from flatland.envs.distance_map import DistanceMap
hagrid67's avatar
hagrid67 committed
20 21

# Need to use circular imports for persistence.
22 23 24
from flatland.envs import malfunction_generators as mal_gen
from flatland.envs import rail_generators as rail_gen
from flatland.envs import schedule_generators as sched_gen
hagrid67's avatar
hagrid67 committed
25
from flatland.envs import persistence
26
from flatland.envs import agent_chains as ac
hagrid67's avatar
hagrid67 committed
27

28 29 30
from flatland.envs.observations import GlobalObsForRailEnv
from gym.utils import seeding

31 32 33 34 35
# Direct import of objects / classes does not work with circular imports.
# from flatland.envs.malfunction_generators import no_malfunction_generator, Malfunction, MalfunctionProcessData
# from flatland.envs.observations import GlobalObsForRailEnv
# from flatland.envs.rail_generators import random_rail_generator, RailGenerator
# from flatland.envs.schedule_generators import random_schedule_generator, ScheduleGenerator
hagrid67's avatar
hagrid67 committed
36 37


38

39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68
# Adrian Egli performance fix (the fast methods brings more than 50%)
def fast_isclose(a, b, rtol):
    """Cheap replacement for ``np.isclose``: True iff |a - b| < rtol.

    The previous form ``(a < b + rtol) or (a < b - rtol)`` reduces to
    ``a < b + rtol`` for rtol > 0, which is a one-sided test, not a
    closeness test.  All call sites in this file compare a non-negative
    ``position_fraction`` against 0.0, for which the symmetric test below
    gives the same results while also being correct in general.
    """
    return abs(a - b) < rtol


def fast_clip(position: (int, int), min_value: (int, int), max_value: (int, int)) -> Tuple[int, int]:
    """Clamp a 2D position component-wise into [min_value, max_value].

    Fixed return annotation: the function returns a (row, col) tuple,
    not a bool as previously annotated.
    """
    return (
        max(min_value[0], min(position[0], max_value[0])),
        max(min_value[1], min(position[1], max_value[1]))
    )


def fast_argmax(possible_transitions: (int, int, int, int)) -> int:
    """Return the index of the first 1 in a 4-tuple of 0/1 transition bits.

    Falls through to 3 when none of the first three entries is 1 (matching
    ``np.argmax`` on a bit-vector whose only set bit is the last one).
    Fixed return annotation: the function returns an int index, not a bool.
    """
    if possible_transitions[0] == 1:
        return 0
    if possible_transitions[1] == 1:
        return 1
    if possible_transitions[2] == 1:
        return 2
    return 3


def fast_position_equal(pos_1: (int, int), pos_2: (int, int)) -> bool:
    """Cheap replacement for ``np.array_equal`` on two (row, col) pairs."""
    same_row = pos_1[0] == pos_2[0]
    same_col = pos_1[1] == pos_2[1]
    return same_row and same_col


def fast_count_nonzero(possible_transitions: (int, int, int, int)) -> int:
    """Cheap replacement for ``np.count_nonzero`` on a 4-tuple of 0/1 bits."""
    return possible_transitions[0] + possible_transitions[1] + possible_transitions[2] + possible_transitions[3]


spiglerg's avatar
spiglerg committed
69
class RailEnvActions(IntEnum):
    """The five discrete actions an agent can take each step."""
    DO_NOTHING = 0  # implies change of direction in a dead-end!
    MOVE_LEFT = 1
    MOVE_FORWARD = 2
    MOVE_RIGHT = 3
    STOP_MOVING = 4

    @staticmethod
    def to_char(a: int):
        """Map an action value to its one-letter code (e.g. for compact logs)."""
        return {
            0: 'B',
            1: 'L',
            2: 'F',
            3: 'R',
            4: 'S',
        }[a]

u214892's avatar
u214892 committed
86

u214892's avatar
u214892 committed
87 88 89 90 91
# Grid coordinate of a cell: row ``r`` and column ``c``.
RailEnvGridPos = NamedTuple('RailEnvGridPos', [('r', int), ('c', int)])
# One candidate move: the action to take plus the resulting cell and heading.
RailEnvNextAction = NamedTuple('RailEnvNextAction', [('action', RailEnvActions), ('next_position', RailEnvGridPos),
                                                     ('next_direction', Grid4TransitionsEnum)])


92 93 94 95 96 97 98 99 100 101
class RailEnv(Environment):
    """
    RailEnv environment class.

    RailEnv is an environment inspired by a (simplified version of) a rail
    network, in which agents (trains) have to navigate to their target
    locations in the shortest time possible, while at the same time cooperating
    to avoid bottlenecks.

    The valid actions in the environment are:

     -   0: do nothing (continue moving or stay still)
     -   1: turn left at switch and move to the next cell; if the agent was not moving, movement is started
     -   2: move to the next cell in front of the agent; if the agent was not moving, movement is started
     -   3: turn right at switch and move to the next cell; if the agent was not moving, movement is started
     -   4: stop moving

    Moving forward in a dead-end cell makes the agent turn 180 degrees and step
    to the cell it came from.

    The actions of the agents are executed in order of their handle to prevent
    deadlocks and to allow them to learn relative priorities.

    Reward Function:

    It costs each agent a step_penalty for every time-step taken in the environment. Independent of the movement
    of the agent. Currently all other penalties such as penalty for stopping, starting and invalid actions are set to 0.

    alpha = 1
    beta = 1
    Reward function parameters:

    - invalid_action_penalty = 0
    - step_penalty = -alpha
    - global_reward = beta
    - epsilon = avoid rounding errors
    - stop_penalty = 0  # penalty for stopping a moving agent
    - start_penalty = 0  # penalty for starting a stopped agent

    Stochastic malfunctioning of trains:
    Trains in RailEnv can malfunction if they are halted too often (either by their own choice or because an invalid
    action or cell is selected.

    Every time an agent stops, an agent has a certain probability of malfunctioning. Malfunctions of trains follow a
    poisson process with a certain rate. Not all trains will be affected by malfunctions during episodes to keep
    complexity managable.

    TODO: currently, the parameters that control the stochasticity of the environment are hard-coded in init().
    For Round 2, they will be passed to the constructor as arguments, to allow for more flexibility.

    """
    # Reward-shaping coefficients (see class docstring).
    alpha = 1.0
    beta = 1.0
    # Epsilon to avoid rounding errors
    epsilon = 0.01
    invalid_action_penalty = 0  # previously -2; GIACOMO: we decided that invalid actions will carry no penalty
    step_penalty = -1 * alpha
    global_reward = 1 * beta
    stop_penalty = 0  # penalty for stopping a moving agent
    start_penalty = 0  # penalty for starting a stopped agent
153 154 155 156

    def __init__(self,
                 width,
                 height,
                 rail_generator=None,
                 schedule_generator=None,  # : sched_gen.ScheduleGenerator = sched_gen.random_schedule_generator(),
                 number_of_agents=1,
                 obs_builder_object: ObservationBuilder = GlobalObsForRailEnv(),
                 malfunction_generator_and_process_data=None,  # mal_gen.no_malfunction_generator(),
                 malfunction_generator=None,
                 remove_agents_at_target=True,
                 random_seed=1,
                 record_steps=False,
                 close_following=True
                 ):
        """
        Environment init.

        Parameters
        ----------
        rail_generator : function
            The rail_generator function is a function that takes the width,
            height and agents handles of a  rail environment, along with the number of times
            the env has been reset, and returns a GridTransitionMap object and a list of
            starting positions, targets, and initial orientations for agent handle.
            The rail_generator can pass a distance map in the hints or information for specific schedule_generators.
            Implementations can be found in flatland/envs/rail_generators.py
        schedule_generator : function
            The schedule_generator function is a function that takes the grid, the number of agents and optional hints
            and returns a list of starting positions, targets, initial orientations and speed for all agent handles.
            Implementations can be found in flatland/envs/schedule_generators.py
        width : int
            The width of the rail map. Potentially in the future,
            a range of widths to sample from.
        height : int
            The height of the rail map. Potentially in the future,
            a range of heights to sample from.
        number_of_agents : int
            Number of agents to spawn on the map. Potentially in the future,
            a range of number of agents to sample from.
        obs_builder_object: ObservationBuilder object
            ObservationBuilder-derived object that takes builds observation
            vectors for each agent.
        remove_agents_at_target : bool
            If remove_agents_at_target is set to true then the agents will be removed by placing to
            RailEnv.DEPOT_POSITION when the agent has reach it's target position.
        random_seed : int or None
            if None, then its ignored, else the random generators are seeded with this number to ensure
            that stochastic operations are replicable across multiple operations
        """
        super().__init__()

        # Malfunction generator: explicit (deprecated tuple form), new-style
        # object, or the no-op default.  Defaults are resolved here instead of
        # in the signature because of the circular imports at module level.
        if malfunction_generator_and_process_data is not None:
            print("DEPRECATED - RailEnv arg: malfunction_and_process_data - use malfunction_generator")
            self.malfunction_generator, self.malfunction_process_data = malfunction_generator_and_process_data
        elif malfunction_generator is not None:
            self.malfunction_generator = malfunction_generator
            self.malfunction_process_data = self.malfunction_generator.get_process_data()
        else:
            self.malfunction_generator = mal_gen.NoMalfunctionGen()
            self.malfunction_process_data = self.malfunction_generator.get_process_data()

        # Rail / schedule generators: defaults resolved here (cyclic imports).
        if rail_generator is None:
            rail_generator = rail_gen.random_rail_generator()
        self.rail_generator = rail_generator
        if schedule_generator is None:
            schedule_generator = sched_gen.random_schedule_generator()
        self.schedule_generator = schedule_generator

        self.rail: Optional[GridTransitionMap] = None
        self.width = width
        self.height = height

        self.remove_agents_at_target = remove_agents_at_target

        self.rewards = [0] * number_of_agents
        self.done = False
        self.obs_builder = obs_builder_object
        self.obs_builder.set_env(self)

        self._max_episode_steps: Optional[int] = None
        self._elapsed_steps = 0

        self.dones = dict.fromkeys(list(range(number_of_agents)) + ["__all__"], False)

        self.obs_dict = {}
        self.rewards_dict = {}
        self.dev_obs_dict = {}
        self.dev_pred_dict = {}

        self.agents: List[EnvAgent] = []
        self.number_of_agents = number_of_agents
        self.num_resets = 0
        self.distance_map = DistanceMap(self.agents, self.height, self.width)

        self.action_space = [5]

        # Seed once; re-seed deterministically if a seed was requested.
        # (The original code called self._seed() twice in a row — the first
        # call was redundant and has been removed.)
        self._seed()
        self.random_seed = random_seed
        if self.random_seed:
            self._seed(seed=random_seed)

        self.valid_positions = None

        # global numpy array of agents position, True means that there is an agent at that cell
        self.agent_positions: np.ndarray = np.full((height, width), False)

        # save episode timesteps ie agent positions, orientations.  (not yet actions / observations)
        self.record_steps = record_steps  # whether to save timesteps
        # save timesteps in here: [[[row, col, dir, malfunction],...nAgents], ...nSteps]
        self.cur_episode = []
        self.list_actions = []  # save actions in here

        self.close_following = close_following  # use close following logic
        self.motionCheck = ac.MotionCheck()

275 276
    def _seed(self, seed=None):
        """Seed this env's RNG via gym's seeding helper.

        NOTE(review): also seeds the *global* ``random`` module as a side
        effect — kept for backward compatibility with existing experiments.
        """
        self.np_random, seed = seeding.np_random(seed)
        random.seed(seed)
        return [seed]

280
    # no more agent_handles
281
    def get_agent_handles(self):
        """Return an iterable over the handles (0..n-1) of all agents."""
        return range(self.get_num_agents())

u229589's avatar
u229589 committed
284 285
    def get_num_agents(self) -> int:
        """Return the number of agents currently registered in the env."""
        return len(self.agents)
286

u229589's avatar
u229589 committed
287
    def add_agent(self, agent):
        """ Add static info for a single agent.
            Returns the index of the new agent.
        """
        self.agents.append(agent)
        return len(self.agents) - 1
hagrid67's avatar
hagrid67 committed
293

294
    def set_agent_active(self, agent: EnvAgent):
        """Activate a READY_TO_DEPART agent and place it on its initial cell,
        but only if that cell is currently free."""
        if agent.status == RailAgentStatus.READY_TO_DEPART and self.cell_free(agent.initial_position):
            agent.status = RailAgentStatus.ACTIVE
            self._set_agent_to_initial_position(agent, agent.initial_position)
u214892's avatar
u214892 committed
298

Erik Nygren's avatar
Erik Nygren committed
299
    def reset_agents(self):
        """ Reset the agents to their starting positions
        """
        for agent in self.agents:
            agent.reset()
        # every agent starts out in the active set
        self.active_agents = list(range(len(self.agents)))
Erik Nygren's avatar
Erik Nygren committed
305

Erik Nygren's avatar
Erik Nygren committed
306 307 308 309 310 311 312 313 314 315 316 317 318 319 320
    def action_required(self, agent):
        """
        Check if an agent needs to provide an action

        Parameters
        ----------
        agent: RailEnvAgent
        Agent we want to check

        Returns
        -------
        True: Agent needs to provide an action
        False: Agent cannot provide an action
        """
        # An action is consumed when the agent is waiting to depart, or when
        # it is active and at the very start of a cell (position_fraction ~ 0).
        return (agent.status == RailAgentStatus.READY_TO_DEPART or (
            agent.status == RailAgentStatus.ACTIVE and fast_isclose(agent.speed_data['position_fraction'], 0.0,
                                                                    rtol=1e-03)))
Erik Nygren's avatar
Erik Nygren committed
323

324
    def reset(self, regenerate_rail: bool = True, regenerate_schedule: bool = True, activate_agents: bool = False,
              random_seed: Optional[int] = None) -> Tuple[Dict, Dict]:
        """
        reset(regenerate_rail, regenerate_schedule, activate_agents, random_seed)

        The method resets the rail environment

        Parameters
        ----------
        regenerate_rail : bool, optional
            regenerate the rails
        regenerate_schedule : bool, optional
            regenerate the schedule and the static agents
        activate_agents : bool, optional
            activate the agents
        random_seed : int, optional
            random seed for environment (annotation fixed: this is an int, not a bool)

        Returns
        -------
        observation_dict: Dict
            Dictionary with an observation for each agent
        info_dict: Dict with agent specific information

        """
        if random_seed:
            self._seed(random_seed)

        optionals = {}
        if regenerate_rail or self.rail is None:
            # The rail generator may be a plain callable or an object
            # exposing a generate() method.
            if "__call__" in dir(self.rail_generator):
                rail, optionals = self.rail_generator(
                    self.width, self.height, self.number_of_agents, self.num_resets, self.np_random)
            elif "generate" in dir(self.rail_generator):
                rail, optionals = self.rail_generator.generate(
                    self.width, self.height, self.number_of_agents, self.num_resets, self.np_random)
            else:
                raise ValueError("Could not invoke __call__ or generate on rail_generator")

            self.rail = rail
            self.height, self.width = self.rail.grid.shape

            # Do a new set_env call on the obs_builder to ensure
            # that obs_builder specific instantiations are made according to the
            # specifications of the current environment : like width, height, etc
            self.obs_builder.set_env(self)

        if optionals and 'distance_map' in optionals:
            self.distance_map.set(optionals['distance_map'])

        if regenerate_schedule or regenerate_rail or self.get_num_agents() == 0:
            agents_hints = None
            if optionals and 'agents_hints' in optionals:
                agents_hints = optionals['agents_hints']

            schedule = self.schedule_generator(self.rail, self.number_of_agents, agents_hints, self.num_resets,
                                               self.np_random)
            self.agents = EnvAgent.from_schedule(schedule)

            # Get max number of allowed time steps from schedule generator
            # Look at the specific schedule generator used to see where this number comes from
            self._max_episode_steps = schedule.max_episode_steps

        # -1 means "no agent at this cell"
        self.agent_positions = np.zeros((self.height, self.width), dtype=int) - 1

        # Reset agents to initial
        self.reset_agents()

        for agent in self.agents:
            # Induce malfunctions
            if activate_agents:
                self.set_agent_active(agent)

            self._break_agent(agent)

            if agent.malfunction_data["malfunction"] > 0:
                agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.DO_NOTHING

            # Fix agents that finished their malfunction
            self._fix_agent_after_malfunction(agent)

        self.num_resets += 1
        self._elapsed_steps = 0

        # TODO perhaps dones should be part of each agent.
        self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)

        # Reset the state of the observation builder with the new environment
        self.obs_builder.reset()
        self.distance_map.reset(self.agents, self.rail)

        # Reset the malfunction generator
        if "generate" in dir(self.malfunction_generator):
            self.malfunction_generator.generate(reset=True)
        else:
            self.malfunction_generator(reset=True)

        # Empty the episode store of agent positions
        self.cur_episode = []

        info_dict: Dict = {
            'action_required': {i: self.action_required(agent) for i, agent in enumerate(self.agents)},
            'malfunction': {
                i: agent.malfunction_data['malfunction'] for i, agent in enumerate(self.agents)
            },
            'speed': {i: agent.speed_data['speed'] for i, agent in enumerate(self.agents)},
            'status': {i: agent.status for i, agent in enumerate(self.agents)}
        }
        # Return the new observation vectors for each agent
        observation_dict: Dict = self._get_observations()
        return observation_dict, info_dict
437

438
    def _fix_agent_after_malfunction(self, agent: EnvAgent):
        """
        Updates agent malfunction variables and fixes broken agents

        Parameters
        ----------
        agent
        """
        # Ignore agents that are OK
        if self._is_agent_ok(agent):
            return

        # Reduce number of malfunction steps left
        if agent.malfunction_data['malfunction'] > 1:
            agent.malfunction_data['malfunction'] -= 1
            return

        # Restart agents at the end of their malfunction, restoring the
        # moving flag they had before breaking down.
        agent.malfunction_data['malfunction'] -= 1
        if 'moving_before_malfunction' in agent.malfunction_data:
            agent.moving = agent.malfunction_data['moving_before_malfunction']
            return
461

462
    def _break_agent(self, agent: EnvAgent):
        """
        Malfunction generator that breaks agents at a given rate.

        Parameters
        ----------
        agent

        """
        # Support both object-style (generate()) and callable generators.
        if "generate" in dir(self.malfunction_generator):
            malfunction: mal_gen.Malfunction = self.malfunction_generator.generate(agent, self.np_random)
        else:
            malfunction: mal_gen.Malfunction = self.malfunction_generator(agent, self.np_random)

        if malfunction.num_broken_steps > 0:
            # Record the breakdown and remember whether the agent was moving
            # so _fix_agent_after_malfunction can restore that state.
            agent.malfunction_data['malfunction'] = malfunction.num_broken_steps
            agent.malfunction_data['moving_before_malfunction'] = agent.moving
            agent.malfunction_data['nr_malfunctions'] += 1

        return
u214892's avatar
u214892 committed
483

484
    def step(self, action_dict_: Dict[int, RailEnvActions]):
        """
        Advance the environment by one time step and update agent rewards.

        Parameters
        ----------
        action_dict_ : Dict[int,RailEnvActions]
            Maps agent handle -> action for this step; missing handles
            default to no action (None).

        Returns
        -------
        (observations, rewards_dict, dones, info_dict) for all agents.
        """
        self._elapsed_steps += 1

        # If we're done, set reward and info_dict and step() is done.
        if self.dones["__all__"]:
            self.rewards_dict = {}
            info_dict = {
                "action_required": {},
                "malfunction": {},
                "speed": {},
                "status": {},
            }
            for i_agent, agent in enumerate(self.agents):
                self.rewards_dict[i_agent] = self.global_reward
                info_dict["action_required"][i_agent] = False
                info_dict["malfunction"][i_agent] = 0
                info_dict["speed"][i_agent] = 0
                info_dict["status"][i_agent] = agent.status

            return self._get_observations(), self.rewards_dict, self.dones, info_dict

        # Reset the step rewards
        self.rewards_dict = dict()
        info_dict = {
            "action_required": {},
            "malfunction": {},
            "speed": {},
            "status": {},
        }
        have_all_agents_ended = True  # boolean flag to check if all agents are done

        self.motionCheck = ac.MotionCheck()  # reset the motion check

        if not self.close_following:
            # Legacy path: each agent is stepped and book-kept in one pass.
            for i_agent, agent in enumerate(self.agents):
                # Reset the step rewards
                self.rewards_dict[i_agent] = 0

                # Induce malfunction before we do a step, thus a broken agent can't move in this step
                self._break_agent(agent)

                # Perform step on the agent
                self._step_agent(i_agent, action_dict_.get(i_agent))

                # manage the boolean flag to check if all agents are indeed done (or done_removed)
                have_all_agents_ended &= (agent.status in [RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED])

                # Build info dict
                info_dict["action_required"][i_agent] = self.action_required(agent)
                info_dict["malfunction"][i_agent] = agent.malfunction_data['malfunction']
                info_dict["speed"][i_agent] = agent.speed_data['speed']
                info_dict["status"][i_agent] = agent.status

                # Fix agents that finished their malfunction such that they can perform an action in the next step
                self._fix_agent_after_malfunction(agent)

        else:
            # Close-following path: three passes so conflicts between agents
            # can be resolved before positions are committed.
            for i_agent, agent in enumerate(self.agents):
                # Reset the step rewards
                self.rewards_dict[i_agent] = 0

                # Induce malfunction before we do a step, thus a broken agent can't move in this step
                self._break_agent(agent)

                # Perform step on the agent
                self._step_agent_cf(i_agent, action_dict_.get(i_agent))

            # second loop: check for collisions / conflicts
            self.motionCheck.find_conflicts()

            # third loop: update positions
            for i_agent, agent in enumerate(self.agents):
                self._step_agent2_cf(i_agent)

                # manage the boolean flag to check if all agents are indeed done (or done_removed)
                have_all_agents_ended &= (agent.status in [RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED])

                # Build info dict
                info_dict["action_required"][i_agent] = self.action_required(agent)
                info_dict["malfunction"][i_agent] = agent.malfunction_data['malfunction']
                info_dict["speed"][i_agent] = agent.speed_data['speed']
                info_dict["status"][i_agent] = agent.status

                # Fix agents that finished their malfunction such that they can perform an action in the next step
                self._fix_agent_after_malfunction(agent)

        # Check for end of episode + set global reward to all rewards!
        if have_all_agents_ended:
            self.dones["__all__"] = True
            self.rewards_dict = {i: self.global_reward for i in range(self.get_num_agents())}
        if (self._max_episode_steps is not None) and (self._elapsed_steps >= self._max_episode_steps):
            self.dones["__all__"] = True
            for i_agent in range(self.get_num_agents()):
                self.dones[i_agent] = True
        if self.record_steps:
            self.record_timestep(action_dict_)

        return self._get_observations(), self.rewards_dict, self.dones, info_dict
591

592
    def _step_agent(self, i_agent, action: Optional[RailEnvActions] = None):
593 594 595 596 597
        """
        Performs a step and step, start and stop penalty on a single agent in the following sub steps:
        - malfunction
        - action handling if at the beginning of cell
        - movement
598

599 600 601 602 603 604
        Parameters
        ----------
        i_agent : int
        action_dict_ : Dict[int,RailEnvActions]

        """
u214892's avatar
u214892 committed
605
        agent = self.agents[i_agent]
u214892's avatar
u214892 committed
606
        if agent.status in [RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED]:  # this agent has already completed...
607 608
            return

u214892's avatar
u214892 committed
609 610
        # agent gets active by a MOVE_* action and if c
        if agent.status == RailAgentStatus.READY_TO_DEPART:
611 612 613 614
            initial_cell_free = self.cell_free(agent.initial_position)
            is_action_starting = action in [
                RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT, RailEnvActions.MOVE_FORWARD]

u214892's avatar
u214892 committed
615
            if action in [RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT,
616
                          RailEnvActions.MOVE_FORWARD] and self.cell_free(agent.initial_position):
u214892's avatar
u214892 committed
617
                agent.status = RailAgentStatus.ACTIVE
618
                self._set_agent_to_initial_position(agent, agent.initial_position)
619
                self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
620
                return
u214892's avatar
u214892 committed
621
            else:
622 623
                # TODO: Here we need to check for the departure time in future releases with full schedules
                self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
u214892's avatar
u214892 committed
624 625
                return

626 627 628
        agent.old_direction = agent.direction
        agent.old_position = agent.position

629 630
        # if agent is broken, actions are ignored and agent does not move.
        # full step penalty in this case
631
        if agent.malfunction_data['malfunction'] > 0:
632 633 634 635
            self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
            return

        # Is the agent at the beginning of the cell? Then, it can take an action.
636
        # As long as the agent is malfunctioning or stopped at the beginning of the cell,
637
        # different actions may be taken!
638
        if fast_isclose(agent.speed_data['position_fraction'], 0.0, rtol=1e-03):
639
            # No action has been supplied for this agent -> set DO_NOTHING as default
640
            if action is None:
641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658
                action = RailEnvActions.DO_NOTHING

            if action < 0 or action > len(RailEnvActions):
                print('ERROR: illegal action=', action,
                      'for agent with index=', i_agent,
                      '"DO NOTHING" will be executed instead')
                action = RailEnvActions.DO_NOTHING

            if action == RailEnvActions.DO_NOTHING and agent.moving:
                # Keep moving
                action = RailEnvActions.MOVE_FORWARD

            if action == RailEnvActions.STOP_MOVING and agent.moving:
                # Only allow halting an agent on entering new cells.
                agent.moving = False
                self.rewards_dict[i_agent] += self.stop_penalty

            if not agent.moving and not (
659 660
                action == RailEnvActions.DO_NOTHING or
                action == RailEnvActions.STOP_MOVING):
661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697
                # Allow agent to start with any forward or direction action
                agent.moving = True
                self.rewards_dict[i_agent] += self.start_penalty

            # Store the action if action is moving
            # If not moving, the action will be stored when the agent starts moving again.
            if agent.moving:
                _action_stored = False
                _, new_cell_valid, new_direction, new_position, transition_valid = \
                    self._check_action_on_agent(action, agent)

                if all([new_cell_valid, transition_valid]):
                    agent.speed_data['transition_action_on_cellexit'] = action
                    _action_stored = True
                else:
                    # But, if the chosen invalid action was LEFT/RIGHT, and the agent is moving,
                    # try to keep moving forward!
                    if (action == RailEnvActions.MOVE_LEFT or action == RailEnvActions.MOVE_RIGHT):
                        _, new_cell_valid, new_direction, new_position, transition_valid = \
                            self._check_action_on_agent(RailEnvActions.MOVE_FORWARD, agent)

                        if all([new_cell_valid, transition_valid]):
                            agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.MOVE_FORWARD
                            _action_stored = True

                if not _action_stored:
                    # If the agent cannot move due to an invalid transition, we set its state to not moving
                    self.rewards_dict[i_agent] += self.invalid_action_penalty
                    self.rewards_dict[i_agent] += self.stop_penalty
                    agent.moving = False

        # Now perform a movement.
        # If agent.moving, increment the position_fraction by the speed of the agent
        # If the new position fraction is >= 1, reset to 0, and perform the stored
        #   transition_action_on_cellexit if the cell is free.
        if agent.moving:
            agent.speed_data['position_fraction'] += agent.speed_data['speed']
698 699
            if agent.speed_data['position_fraction'] > 1.0 or fast_isclose(agent.speed_data['position_fraction'], 1.0,
                                                                           rtol=1e-03):
700 701 702 703
                # Perform stored action to transition to the next cell as soon as cell is free
                # Notice that we've already checked new_cell_valid and transition valid when we stored the action,
                # so we only have to check cell_free now!

704
                # Traditional check that next cell is free
705 706
                # cell and transition validity was checked when we stored transition_action_on_cellexit!
                cell_free, new_cell_valid, new_direction, new_position, transition_valid = self._check_action_on_agent(
707
                    agent.speed_data['transition_action_on_cellexit'], agent)
708 709 710 711 712

                # N.B. validity of new_cell and transition should have been verified before the action was stored!
                assert new_cell_valid
                assert transition_valid
                if cell_free:
713
                    self._move_agent_to_new_position(agent, new_position)
714 715
                    agent.direction = new_direction
                    agent.speed_data['position_fraction'] = 0.0
716 717 718 719 720 721 722 723 724 725 726 727 728 729 730

            # has the agent reached its target?
            if np.equal(agent.position, agent.target).all():
                agent.status = RailAgentStatus.DONE
                self.dones[i_agent] = True
                self.active_agents.remove(i_agent)
                agent.moving = False
                self._remove_agent_from_scene(agent)
            else:
                self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
        else:
            # step penalty if not moving (stopped now or before)
            self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']

    def _step_agent_cf(self, i_agent, action: Optional[RailEnvActions] = None):
        """ "close following" version of step_agent.

        Unlike ``_step_agent``, this method does not move the agent itself;
        it only registers the agent's intended move (or intention to stay
        put) with ``self.motionCheck``.  The actual movement, rewards and
        status changes happen afterwards in ``_step_agent2_cf``, once all
        intentions are known and conflicts can be resolved.

        Parameters
        ----------
        i_agent : int
            Handle (index) of the agent to step.
        action : Optional[RailEnvActions]
            Action supplied for this agent; None is treated as DO_NOTHING.
        """
        agent = self.agents[i_agent]
        if agent.status in [RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED]:  # this agent has already completed...
            return

        # agent gets active by a MOVE_* action
        if agent.status == RailAgentStatus.READY_TO_DEPART:
            is_action_starting = action in [
                RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT, RailEnvActions.MOVE_FORWARD]

            if is_action_starting:  # agent is trying to start
                self.motionCheck.addAgent(i_agent, None, agent.initial_position)
            else:  # agent wants to remain unstarted
                self.motionCheck.addAgent(i_agent, None, None)
            return

        agent.old_direction = agent.direction
        agent.old_position = agent.position

        # if agent is broken, actions are ignored and agent does not move.
        # full step penalty in this case
        # TODO: this means that deadlocked agents which suffer a malfunction are marked as
        # stopped rather than deadlocked.
        if agent.malfunction_data['malfunction'] > 0:
            # register as stationary; the step penalty is applied in _step_agent2_cf
            self.motionCheck.addAgent(i_agent, agent.position, agent.position)
            return

        # Is the agent at the beginning of the cell? Then, it can take an action.
        # As long as the agent is malfunctioning or stopped at the beginning of the cell,
        # different actions may be taken!
        # NOTE: fast_isclose for consistency (and speed) with _step_agent.
        if fast_isclose(agent.speed_data['position_fraction'], 0.0, rtol=1e-03):
            # No action has been supplied for this agent -> set DO_NOTHING as default
            if action is None:
                action = RailEnvActions.DO_NOTHING

            # Guard against out-of-range actions (valid values are 0 .. len-1,
            # hence >= rather than > here).
            if action < 0 or action >= len(RailEnvActions):
                print('ERROR: illegal action=', action,
                      'for agent with index=', i_agent,
                      '"DO NOTHING" will be executed instead')
                action = RailEnvActions.DO_NOTHING

            if action == RailEnvActions.DO_NOTHING and agent.moving:
                # Keep moving
                action = RailEnvActions.MOVE_FORWARD

            if action == RailEnvActions.STOP_MOVING and agent.moving:
                # Only allow halting an agent on entering new cells.
                agent.moving = False
                self.rewards_dict[i_agent] += self.stop_penalty

            if not agent.moving and not (
                action == RailEnvActions.DO_NOTHING or
                action == RailEnvActions.STOP_MOVING):
                # Allow agent to start with any forward or direction action
                agent.moving = True
                self.rewards_dict[i_agent] += self.start_penalty

            # Store the action if action is moving
            # If not moving, the action will be stored when the agent starts moving again.
            new_position = None
            if agent.moving:
                _action_stored = False
                _, new_cell_valid, new_direction, new_position, transition_valid = \
                    self._check_action_on_agent(action, agent)

                if all([new_cell_valid, transition_valid]):
                    agent.speed_data['transition_action_on_cellexit'] = action
                    _action_stored = True
                else:
                    # But, if the chosen invalid action was LEFT/RIGHT, and the agent is moving,
                    # try to keep moving forward!
                    if (action == RailEnvActions.MOVE_LEFT or action == RailEnvActions.MOVE_RIGHT):
                        _, new_cell_valid, new_direction, new_position, transition_valid = \
                            self._check_action_on_agent(RailEnvActions.MOVE_FORWARD, agent)

                        if all([new_cell_valid, transition_valid]):
                            agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.MOVE_FORWARD
                            _action_stored = True

                if not _action_stored:
                    # If the agent cannot move due to an invalid transition, we set its state to not moving
                    self.rewards_dict[i_agent] += self.invalid_action_penalty
                    self.rewards_dict[i_agent] += self.stop_penalty
                    agent.moving = False
                    self.motionCheck.addAgent(i_agent, agent.position, agent.position)
                    return

            if new_position is None:
                self.motionCheck.addAgent(i_agent, agent.position, agent.position)
                if agent.moving:
                    print("Agent", i_agent, "new_pos none, but moving")

        # Check the pos_frac position fraction
        if agent.moving:
            agent.speed_data['position_fraction'] += agent.speed_data['speed']
            if agent.speed_data['position_fraction'] > 0.999:
                stored_action = agent.speed_data["transition_action_on_cellexit"]

                # find the next cell using the stored action
                _, new_cell_valid, new_direction, new_position, transition_valid = \
                    self._check_action_on_agent(stored_action, agent)

                # if it's valid, record it as the new position
                if all([new_cell_valid, transition_valid]):
                    self.motionCheck.addAgent(i_agent, agent.position, new_position)
                else:  # if the action wasn't valid then record the agent as stationary
                    self.motionCheck.addAgent(i_agent, agent.position, agent.position)
            else:  # This agent hasn't yet crossed the cell
                self.motionCheck.addAgent(i_agent, agent.position, agent.position)

    def _step_agent2_cf(self, i_agent):
        """ Second phase of the "close following" step: execute the move that
        was registered with ``self.motionCheck`` in ``_step_agent_cf``, after
        all agents' intentions have been collected and conflicts resolved.

        Parameters
        ----------
        i_agent : int
            Handle (index) of the agent to move.
        """
        agent = self.agents[i_agent]

        if agent.status in [RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED]:
            return

        (move, rc_next) = self.motionCheck.check_motion(i_agent, agent.position)

        if agent.position is not None:
            # sanity check: each direction owns 4 bits of the cell's 16-bit
            # transition word; all-zero means the agent is on a cell it
            # cannot leave in its current direction
            sbTrans = format(self.rail.grid[agent.position], "016b")
            trans_block = sbTrans[agent.direction * 4: agent.direction * 4 + 4]
            if (trans_block == "0000"):
                print(i_agent, agent.position, agent.direction, sbTrans, trans_block)

        # if agent cannot enter env, then we should have move=False

        if move:
            if agent.position is None:  # agent is entering the env
                # print(i_agent, "writing new pos ", rc_next, " into agent position (None)")
                agent.position = rc_next
                agent.status = RailAgentStatus.ACTIVE
                agent.speed_data['position_fraction'] = 0.0

            else:  # normal agent move
                cell_free, new_cell_valid, new_direction, new_position, transition_valid = self._check_action_on_agent(
                    agent.speed_data['transition_action_on_cellexit'], agent)

                if not all([transition_valid, new_cell_valid]):
                    print(f"ERROR: step_agent2 invalid transition ag {i_agent} dir {new_direction} pos {agent.position} next {rc_next}")

                # the position resolved by motionCheck should agree with the stored action
                if new_position != rc_next:
                    print(f"ERROR: agent {i_agent} new_pos {new_position} != rc_next {rc_next}  " +
                          f"pos {agent.position} dir {agent.direction} new_dir {new_direction} " +
                          f"stored action: {agent.speed_data['transition_action_on_cellexit']}")

                sbTrans = format(self.rail.grid[agent.position], "016b")
                trans_block = sbTrans[agent.direction * 4: agent.direction * 4 + 4]
                if (trans_block == "0000"):
                    print("ERROR: ", i_agent, agent.position, agent.direction, sbTrans, trans_block)

                agent.position = rc_next
                agent.direction = new_direction
                agent.speed_data['position_fraction'] = 0.0

            # has the agent reached its target?
            if np.equal(agent.position, agent.target).all():
                agent.status = RailAgentStatus.DONE
                self.dones[i_agent] = True
                self.active_agents.remove(i_agent)
                agent.moving = False
                self._remove_agent_from_scene(agent)
            else:
                self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
        else:
            # step penalty if not moving (stopped now or before)
            self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
    def _set_agent_to_initial_position(self, agent: EnvAgent, new_position: IntVector2D):
902 903 904 905 906 907 908 909 910
        """
        Sets the agent to its initial position. Updates the agent object and the position
        of the agent inside the global agent_position numpy array

        Parameters
        -------
        agent: EnvAgent object
        new_position: IntVector2D
        """
911
        agent.position = new_position
912
        self.agent_positions[agent.position] = agent.handle
    def _move_agent_to_new_position(self, agent: EnvAgent, new_position: IntVector2D):
915 916 917 918 919 920 921 922 923
        """
        Move the agent to the a new position. Updates the agent object and the position
        of the agent inside the global agent_position numpy array

        Parameters
        -------
        agent: EnvAgent object
        new_position: IntVector2D
        """
924
        agent.position = new_position
925 926
        self.agent_positions[agent.old_position] = -1
        self.agent_positions[agent.position] = agent.handle
    def _remove_agent_from_scene(self, agent: EnvAgent):
929 930 931 932 933 934 935 936
        """
        Remove the agent from the scene. Updates the agent object and the position
        of the agent inside the global agent_position numpy array

        Parameters
        -------
        agent: EnvAgent object
        """
937
        self.agent_positions[agent.position] = -1
938 939
        if self.remove_agents_at_target:
            agent.position = None
940
            # setting old_position to None here stops the DONE agents from appearing in the rendered image
941
            agent.old_position = None
942 943
            agent.status = RailAgentStatus.DONE_REMOVED

    def _check_action_on_agent(self, action: RailEnvActions, agent: EnvAgent):
        """
        Evaluate what would happen if ``agent`` executed ``action``, without
        changing any state.

        Parameters
        ----------
        action : RailEnvActions
        agent : EnvAgent

        Returns
        -------
        cell_free : bool
            No agent currently occupies the target cell (always False when
            the target cell is invalid).
        new_cell_valid : bool
            The target cell lies inside the grid and is not an empty cell
            (it has at least one transition).
        new_direction : int
            Heading the agent would have after the action.
        new_position : Tuple[int, int]
            Cell the agent would move into.
        transition_valid : bool
            The current cell allows leaving towards ``new_direction``.
        """
        # Resolve the action into a tentative heading; transition_valid may
        # come back as None if check_action alone could not decide.
        new_direction, transition_valid = self.check_action(agent, action)
        new_position = get_new_position(agent.position, new_direction)

        # The target cell is valid iff it is inside the grid (clipping it to
        # the grid bounds leaves it unchanged) ...
        inside_grid = fast_position_equal(
            new_position,
            fast_clip(new_position, [0, 0], [self.height - 1, self.width - 1]))
        # ... and it has some transitions, i.e. it is not an empty cell.
        new_cell_valid = inside_grid and self.rail.get_full_transitions(*new_position) > 0

        # If transition validity hasn't been decided yet, ask the rail grid.
        if transition_valid is None:
            transition_valid = self.rail.get_transition(
                (*agent.position, agent.direction),
                new_direction)

        # Only call cell_free() for cells inside the scene; cells outside
        # are never free.
        cell_free = self.cell_free(new_position) if new_cell_valid else False

        return cell_free, new_cell_valid, new_direction, new_position, transition_valid
    def record_timestep(self, dActions):
        ''' Record the positions and orientations of all agents in memory, in the cur_episode
        '''
        snapshot = []
        n_agents = self.get_num_agents()
        for handle in range(n_agents):
            agent = self.agents[handle]
            # int casts avoid numpy scalar types, which can break msgpack;
            # agents that have not yet entered the grid (position None, env v2)
            # are recorded at (0, 0)
            if agent.position is None:
                row, col = 0, 0
            else:
                row, col = int(agent.position[0]), int(agent.position[1])
            snapshot.append([
                row, col,
                int(agent.direction),
                agent.malfunction_data["malfunction"],
                int(agent.status),
                int(agent.position in self.motionCheck.svDeadlocked)])

        self.cur_episode.append(snapshot)
        self.list_actions.append(dActions)
    def cell_free(self, position: IntVector2D) -> bool:
        """
        Check whether no agent occupies the given cell.

        Parameters:
        --------
        position : Tuple[int, int]

        Returns
        -------
        bool
            True iff the cell is empty: -1 in the global ``agent_positions``
            array marks an unoccupied cell, any other value is an agent handle.
        """
        occupant = self.agent_positions[position]
        return occupant == -1
    def check_action(self, agent: EnvAgent, action: RailEnvActions):
        """
        Resolve an action into the heading the agent would take.

        Parameters
        ----------
        agent : EnvAgent
        action : RailEnvActions

        Returns
        -------
        Tuple[Grid4TransitionsEnum, Optional[bool]]
            The new direction, plus whether the transition is known valid
            (True), known invalid (False), or still undecided (None — the
            caller then checks it against the rail grid).
        """
        transition_valid = None
        possible_transitions = self.rail.get_transitions(*agent.position, agent.direction)
        num_transitions = fast_count_nonzero(possible_transitions)

        if action == RailEnvActions.MOVE_LEFT or action == RailEnvActions.MOVE_RIGHT:
            # turning is only possible where the cell offers a choice
            # (more than one outgoing transition)
            turn = -1 if action == RailEnvActions.MOVE_LEFT else 1
            new_direction = (agent.direction + turn) % 4
            if num_transitions <= 1:
                transition_valid = False
        else:
            new_direction = agent.direction % 4

        if action == RailEnvActions.MOVE_FORWARD and num_transitions == 1:
            # dead-end, straight line or curved line: there is exactly one
            # way out, so take the single available transition
            new_direction = fast_argmax(possible_transitions)
            transition_valid = True
        return new_direction, transition_valid
    def _get_observations(self):
1070 1071 1072 1073 1074 1075 1076
        """
        Utility which returns the observations for an agent with respect to environment

        Returns
        ------
        Dict object
        """
1077
        # print(f"_get_obs - num agents: {self.get_num_agents()} {list(range(self.get_num_agents()))}")
1078
        self.obs_dict = self.obs_builder.get_many(list(range(self.get_num_agents())))
1079
        return self.obs_dict
    def get_valid_directions_on_grid(self, row: int, col: int) -> List[int]:
        """
        Returns directions in which the agent can move

        Parameters:
        ---------
        row : int
        col : int

        Returns:
        -------
        List[int]
        """
        # decode the cell's 16-bit transition word into its entry directions
        transition_bits = self.rail.get_full_transitions(row, col)
        return Grid4Transitions.get_entry_directions(transition_bits)