"""
Definition of the RailEnv environment.
"""
import random
# TODO:  _ this is a global method --> utils or remove later
from enum import IntEnum
from typing import List, NamedTuple, Optional, Dict, Tuple

import numpy as np

from flatland.core.env import Environment
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.grid.grid4 import Grid4TransitionsEnum, Grid4Transitions
from flatland.core.grid.grid4_utils import get_new_position
from flatland.core.grid.grid_utils import IntVector2D
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.agent_utils import EnvAgent, RailAgentStatus
from flatland.envs.distance_map import DistanceMap
from flatland.envs.rail_env_action import RailEnvActions

# Need to use circular imports for persistence.
from flatland.envs import malfunction_generators as mal_gen
from flatland.envs import rail_generators as rail_gen
from flatland.envs import line_generators as line_gen
from flatland.envs.timetable_generators import timetable_generator
from flatland.envs import persistence
from flatland.envs import agent_chains as ac

from flatland.envs.observations import GlobalObsForRailEnv
from gym.utils import seeding

# Direct import of objects / classes does not work with circular imports.
# from flatland.envs.malfunction_generators import no_malfunction_generator, Malfunction, MalfunctionProcessData
# from flatland.envs.observations import GlobalObsForRailEnv
# from flatland.envs.rail_generators import random_rail_generator, RailGenerator
# from flatland.envs.line_generators import random_line_generator, LineGenerator


# Adrian Egli performance fix (the fast methods bring more than 50%)
def fast_isclose(a, b, rtol):
    # Note: rewritten as a symmetric closeness check; the original
    # `(a < (b + rtol)) or (a < (b - rtol))` reduced to `a < b + rtol`
    # and was not a closeness test.
    return (b - rtol) < a < (b + rtol)


def fast_clip(position: Tuple[int, int], min_value: Tuple[int, int], max_value: Tuple[int, int]) -> Tuple[int, int]:
    return (
        max(min_value[0], min(position[0], max_value[0])),
        max(min_value[1], min(position[1], max_value[1]))
    )


def fast_argmax(possible_transitions: Tuple[int, int, int, int]) -> int:
    if possible_transitions[0] == 1:
        return 0
    if possible_transitions[1] == 1:
        return 1
    if possible_transitions[2] == 1:
        return 2
    return 3


def fast_position_equal(pos_1: Tuple[int, int], pos_2: Tuple[int, int]) -> bool:
    return pos_1[0] == pos_2[0] and pos_1[1] == pos_2[1]


def fast_count_nonzero(possible_transitions: Tuple[int, int, int, int]) -> int:
    return possible_transitions[0] + possible_transitions[1] + possible_transitions[2] + possible_transitions[3]
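

# Illustrative semantics of the helpers above (a sketch; the 4-tuples mimic the
# (N, E, S, W) transition bitmasks used elsewhere in this file):
#
#     >>> fast_count_nonzero((1, 0, 1, 0))
#     2
#     >>> fast_argmax((0, 1, 0, 0))
#     1
#     >>> fast_clip((5, -2), (0, 0), (3, 3))
#     (3, 0)
#     >>> fast_isclose(0.9995, 1.0, rtol=1e-03)
#     True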


class RailEnv(Environment):
    """
    RailEnv environment class.

    RailEnv is an environment inspired by a (simplified version of) a rail
    network, in which agents (trains) have to navigate to their target
    locations in the shortest time possible, while at the same time cooperating
    to avoid bottlenecks.

    The valid actions in the environment are:

     -   0: do nothing (continue moving or stay still)
     -   1: turn left at switch and move to the next cell; if the agent was not moving, movement is started
     -   2: move to the next cell in front of the agent; if the agent was not moving, movement is started
     -   3: turn right at switch and move to the next cell; if the agent was not moving, movement is started
     -   4: stop moving

    Moving forward in a dead-end cell makes the agent turn 180 degrees and step
    to the cell it came from.
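
    A minimal interaction sketch (illustrative only; the constant MOVE_FORWARD
    policy is a stand-in, not part of this class):

        env = RailEnv(width=30, height=30, number_of_agents=2)
        obs, info = env.reset()
        while not env.dones["__all__"]:
            obs, rewards, dones, info = env.step(
                {h: RailEnvActions.MOVE_FORWARD for h in env.get_agent_handles()})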

    The actions of the agents are executed in order of their handle to prevent
    deadlocks and to allow them to learn relative priorities.

    Reward Function:

    It costs each agent a step_penalty for every time-step taken in the environment,
    independent of the agent's movement. Currently all other penalties, such as the
    penalties for stopping, starting and invalid actions, are set to 0.

    Reward function parameters:

    - alpha = 1
    - beta = 1
    - invalid_action_penalty = 0
    - step_penalty = -alpha
    - global_reward = beta
    - epsilon = avoid rounding errors
    - stop_penalty = 0  # penalty for stopping a moving agent
    - start_penalty = 0  # penalty for starting a stopped agent
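
    For example, with alpha = 1 an active agent with speed 0.5 accumulates
    step_penalty * speed = -0.5 per time-step. (The class attributes below set
    alpha = beta = 0 for the newer sparse, delay-based reward computed in step().)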

    Stochastic malfunctioning of trains:
    Trains in RailEnv can malfunction if they are halted too often (either by their own choice or because an invalid
    action or cell is selected).

    Every time an agent stops, it has a certain probability of malfunctioning. Malfunctions of trains follow a
    Poisson process with a certain rate. Not all trains will be affected by malfunctions during episodes, to keep
    complexity manageable.

    TODO: currently, the parameters that control the stochasticity of the environment are hard-coded in init().
    For Round 2, they will be passed to the constructor as arguments, to allow for more flexibility.
    """

    # Epsilon to avoid rounding errors
    epsilon = 0.01
    # NEW : REW: Sparse Reward
    alpha = 0
    beta = 0
    step_penalty = -1 * alpha
    global_reward = 1 * beta
    invalid_action_penalty = 0  # previously -2; GIACOMO: we decided that invalid actions will carry no penalty
    stop_penalty = 0  # penalty for stopping a moving agent
    start_penalty = 0  # penalty for starting a stopped agent

    cancellation_factor = 1
    cancellation_time_buffer = 0
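    # Note: at episode end, a never-departed agent is penalised with
    # -1 * cancellation_factor * (shortest-path travel time + cancellation_time_buffer)
    # (see the end-of-episode reward block in step() below).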

    def __init__(self,
                 width,
                 height,
                 rail_generator=None,
                 line_generator=None,  # : line_gen.LineGenerator = line_gen.random_line_generator(),
                 number_of_agents=2,
                 obs_builder_object: ObservationBuilder = GlobalObsForRailEnv(),
                 malfunction_generator_and_process_data=None,  # mal_gen.no_malfunction_generator(),
                 malfunction_generator=None,
                 remove_agents_at_target=True,
                 random_seed=1,
                 record_steps=False,
                 close_following=True
                 ):
        """
        Environment init.

        Parameters
        ----------
        rail_generator : function
            The rail_generator function is a function that takes the width,
            height and agents handles of a rail environment, along with the number of times
            the env has been reset, and returns a GridTransitionMap object and a list of
            starting positions, targets, and initial orientations for each agent handle.
            The rail_generator can pass a distance map in the hints or information for specific line_generators.
            Implementations can be found in flatland/envs/rail_generators.py
        line_generator : function
            The line_generator function is a function that takes the grid, the number of agents and optional hints
            and returns a list of starting positions, targets, initial orientations and speeds for all agent handles.
            Implementations can be found in flatland/envs/line_generators.py
        width : int
            The width of the rail map. Potentially in the future,
            a range of widths to sample from.
        height : int
            The height of the rail map. Potentially in the future,
            a range of heights to sample from.
        number_of_agents : int
            Number of agents to spawn on the map. Potentially in the future,
            a range of number of agents to sample from.
        obs_builder_object: ObservationBuilder object
            ObservationBuilder-derived object that builds observation
            vectors for each agent.
        malfunction_generator_and_process_data : tuple, optional
            Deprecated; use malfunction_generator instead.
        malfunction_generator : optional
            Malfunction generator used to induce stochastic agent breakdowns;
            defaults to mal_gen.NoMalfunctionGen() (no malfunctions).
        remove_agents_at_target : bool
            If remove_agents_at_target is set to true then the agents will be removed by placing them at
            RailEnv.DEPOT_POSITION when they have reached their target position.
        random_seed : int or None
            If None, it is ignored; otherwise the random generators are seeded with this number to ensure
            that stochastic operations are replicable across multiple runs.
        record_steps : bool
            If true, save agent positions, orientations and malfunctions per timestep in self.cur_episode.
        close_following : bool
            If true, use the close-following motion logic (ac.MotionCheck).
        """
        super().__init__()

        if malfunction_generator_and_process_data is not None:
            print("DEPRECATED - RailEnv arg: malfunction_and_process_data - use malfunction_generator")
            self.malfunction_generator, self.malfunction_process_data = malfunction_generator_and_process_data
        elif malfunction_generator is not None:
            self.malfunction_generator = malfunction_generator
            self.malfunction_process_data = self.malfunction_generator.get_process_data()
        else:
            # Replace default values here because we can't use default argument values due to cyclic imports.
            self.malfunction_generator = mal_gen.NoMalfunctionGen()
            self.malfunction_process_data = self.malfunction_generator.get_process_data()

        self.number_of_agents = number_of_agents

        # self.rail_generator: RailGenerator = rail_generator
        if rail_generator is None:
            rail_generator = rail_gen.sparse_rail_generator()
        self.rail_generator = rail_generator
        if line_generator is None:
            line_generator = line_gen.sparse_line_generator()
        self.line_generator = line_generator

        self.rail: Optional[GridTransitionMap] = None
        self.width = width
        self.height = height

        self.remove_agents_at_target = remove_agents_at_target

        self.rewards = [0] * number_of_agents
        self.done = False
        self.obs_builder = obs_builder_object
        self.obs_builder.set_env(self)

        self._max_episode_steps: Optional[int] = None
        self._elapsed_steps = 0

        self.dones = dict.fromkeys(list(range(number_of_agents)) + ["__all__"], False)

        self.obs_dict = {}
        self.rewards_dict = {}
        self.dev_obs_dict = {}
        self.dev_pred_dict = {}

        self.agents: List[EnvAgent] = []
        self.num_resets = 0
        self.distance_map = DistanceMap(self.agents, self.height, self.width)

        self.action_space = [5]

        self._seed()
        self.random_seed = random_seed
        if self.random_seed:
            self._seed(seed=random_seed)

        self.valid_positions = None

        # global numpy array of agent positions; re-initialised in reset() to hold
        # the handle of the agent occupying each cell (-1 if the cell is empty)
        self.agent_positions: np.ndarray = np.full((height, width), False)

        # save episode timesteps ie agent positions, orientations.  (not yet actions / observations)
        self.record_steps = record_steps  # whether to save timesteps
        # save timesteps in here: [[[row, col, dir, malfunction],...nAgents], ...nSteps]
        self.cur_episode = []
        self.list_actions = []  # save actions in here

        self.close_following = close_following  # use close following logic
        self.motionCheck = ac.MotionCheck()

    def _seed(self, seed=None):
        self.np_random, seed = seeding.np_random(seed)
        random.seed(seed)
        return [seed]

    # no more agent_handles
    def get_agent_handles(self):
        return range(self.get_num_agents())

    def get_num_agents(self) -> int:
        return len(self.agents)

    def add_agent(self, agent):
        """ Add static info for a single agent.
            Returns the index of the new agent.
        """
        self.agents.append(agent)
        return len(self.agents) - 1

    def set_agent_active(self, agent: EnvAgent):
        # Parentheses make the original operator precedence explicit: READY_TO_DEPART
        # activates unconditionally, WAITING only if the initial cell is free.
        # Dipam: Why is this code even there???
        if agent.status == RailAgentStatus.READY_TO_DEPART or \
            (agent.status == RailAgentStatus.WAITING and self.cell_free(agent.initial_position)):
            agent.status = RailAgentStatus.ACTIVE
            self._set_agent_to_initial_position(agent, agent.initial_position)

    def reset_agents(self):
        """ Reset the agents to their starting positions
        """
        for agent in self.agents:
            agent.reset()
        self.active_agents = [i for i in range(len(self.agents))]

    def action_required(self, agent):
        """
        Check if an agent needs to provide an action

        Parameters
        ----------
        agent: RailEnvAgent
            Agent we want to check

        Returns
        -------
        True: Agent needs to provide an action
        False: Agent cannot provide an action
        """
        return (agent.status == RailAgentStatus.READY_TO_DEPART or (
            agent.status == RailAgentStatus.ACTIVE and fast_isclose(agent.speed_data['position_fraction'], 0.0,
                                                                    rtol=1e-03)))
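
    # Illustrative use of action_required in a training loop (a sketch; `policy`
    # is a stand-in name, not part of this class):
    #
    #     actions = {h: policy(obs[h]) for h in env.get_agent_handles()
    #                if info["action_required"][h]}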

    def reset(self, regenerate_rail: bool = True, regenerate_schedule: bool = True, activate_agents: bool = False,
              random_seed: Optional[int] = None) -> Tuple[Dict, Dict]:
        """
        reset(regenerate_rail, regenerate_schedule, activate_agents, random_seed)

        The method resets the rail environment

        Parameters
        ----------
        regenerate_rail : bool, optional
            regenerate the rails
        regenerate_schedule : bool, optional
            regenerate the schedule and the static agents
        activate_agents : bool, optional
            activate the agents
        random_seed : int, optional
            random seed for the environment

        Returns
        -------
        observation_dict: Dict
            Dictionary with an observation for each agent
        info_dict: Dict with agent specific information
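
        Example (illustrative; keeps the current rail but redraws the schedule):

            obs, info = env.reset(regenerate_rail=False, regenerate_schedule=True)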

        """

        if random_seed:
            self._seed(random_seed)

        optionals = {}
        if regenerate_rail or self.rail is None:

            if "__call__" in dir(self.rail_generator):
                rail, optionals = self.rail_generator(
                    self.width, self.height, self.number_of_agents, self.num_resets, self.np_random)
            elif "generate" in dir(self.rail_generator):
                rail, optionals = self.rail_generator.generate(
                    self.width, self.height, self.number_of_agents, self.num_resets, self.np_random)
            else:
                raise ValueError("Could not invoke __call__ or generate on rail_generator")

            self.rail = rail
            self.height, self.width = self.rail.grid.shape

            # Do a new set_env call on the obs_builder to ensure
            # that obs_builder specific instantiations are made according to the
            # specifications of the current environment : like width, height, etc
            self.obs_builder.set_env(self)

        if optionals and 'distance_map' in optionals:
            self.distance_map.set(optionals['distance_map'])

        if regenerate_schedule or regenerate_rail or self.get_num_agents() == 0:
            agents_hints = None
            if optionals and 'agents_hints' in optionals:
                agents_hints = optionals['agents_hints']

            line = self.line_generator(self.rail, self.number_of_agents, agents_hints,
                                       self.num_resets, self.np_random)
            self.agents = EnvAgent.from_line(line)

            # Reset distance map - basically initializing
            self.distance_map.reset(self.agents, self.rail)

            # NEW : Time Schedule Generation
            timetable = timetable_generator(self.agents, self.distance_map,
                                            agents_hints, self.np_random)

            self._max_episode_steps = timetable.max_episode_steps

            for agent_i, agent in enumerate(self.agents):
                agent.earliest_departure = timetable.earliest_departures[agent_i]
                agent.latest_arrival = timetable.latest_arrivals[agent_i]
        else:
            self.distance_map.reset(self.agents, self.rail)

        # Agent Positions Map
        self.agent_positions = np.zeros((self.height, self.width), dtype=int) - 1

        # Reset agents to initial states
        self.reset_agents()

        for agent in self.agents:
            # Induce malfunctions
            if activate_agents:
                self.set_agent_active(agent)

            self._break_agent(agent)

            if agent.malfunction_data["malfunction"] > 0:
                agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.DO_NOTHING

            # Fix agents that finished their malfunction
            self._fix_agent_after_malfunction(agent)

        self.num_resets += 1
        self._elapsed_steps = 0

        # TODO perhaps dones should be part of each agent.
        self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)

        # Reset the state of the observation builder with the new environment
        self.obs_builder.reset()

        # Reset the malfunction generator
        if "generate" in dir(self.malfunction_generator):
            self.malfunction_generator.generate(reset=True)
        else:
            self.malfunction_generator(reset=True)

        # Empty the episode store of agent positions
        self.cur_episode = []

        info_dict: Dict = {
            'action_required': {i: self.action_required(agent) for i, agent in enumerate(self.agents)},
            'malfunction': {
                i: agent.malfunction_data['malfunction'] for i, agent in enumerate(self.agents)
            },
            'speed': {i: agent.speed_data['speed'] for i, agent in enumerate(self.agents)},
            'status': {i: agent.status for i, agent in enumerate(self.agents)}
        }
        # Return the new observation vectors for each agent
        observation_dict: Dict = self._get_observations()
        return observation_dict, info_dict

    def _fix_agent_after_malfunction(self, agent: EnvAgent):
        """
        Updates agent malfunction variables and fixes broken agents

        Parameters
        ----------
        agent
        """

        # Ignore agents that are OK
        if self._is_agent_ok(agent):
            return

        # Reduce number of malfunction steps left
        if agent.malfunction_data['malfunction'] > 1:
            agent.malfunction_data['malfunction'] -= 1
            return

        # Restart agents at the end of their malfunction
        agent.malfunction_data['malfunction'] -= 1
        if 'moving_before_malfunction' in agent.malfunction_data:
            agent.moving = agent.malfunction_data['moving_before_malfunction']
            return

    def _break_agent(self, agent: EnvAgent):
        """
        Malfunction generator that breaks agents at a given rate.

        Parameters
        ----------
        agent
        """

        if "generate" in dir(self.malfunction_generator):
            malfunction: mal_gen.Malfunction = self.malfunction_generator.generate(agent, self.np_random)
        else:
            malfunction: mal_gen.Malfunction = self.malfunction_generator(agent, self.np_random)

        if malfunction.num_broken_steps > 0:
            agent.malfunction_data['malfunction'] = malfunction.num_broken_steps
            agent.malfunction_data['moving_before_malfunction'] = agent.moving
            agent.malfunction_data['nr_malfunctions'] += 1

        return

    def step(self, action_dict_: Dict[int, RailEnvActions]):
        """
        Advances the environment by one time-step and updates the rewards for the agents.

        Parameters
        ----------
        action_dict_ : Dict[int,RailEnvActions]

        Returns
        -------
        Tuple of (observation_dict, rewards_dict, dones, info_dict)
        """
        self._elapsed_steps += 1

        # If we're done, set reward and info_dict and step() is done.
        if self.dones["__all__"]:
            raise Exception("Episode is done, cannot call step()")

        # Reset the step rewards
        self.rewards_dict = dict()
        info_dict = {
            "action_required": {},
            "malfunction": {},
            "speed": {},
            "status": {},
        }
        have_all_agents_ended = True  # boolean flag to check if all agents are done

        self.motionCheck = ac.MotionCheck()  # reset the motion check

        if not self.close_following:
            for i_agent, agent in enumerate(self.agents):
                # Reset the step rewards
                self.rewards_dict[i_agent] = 0

                # Induce malfunction before we do a step, thus a broken agent can't move in this step
                self._break_agent(agent)

                # Perform step on the agent
                self._step_agent(i_agent, action_dict_.get(i_agent))

                # manage the boolean flag to check if all agents are indeed done (or done_removed)
                have_all_agents_ended &= (agent.status in [RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED])

                # Build info dict
                info_dict["action_required"][i_agent] = self.action_required(agent)
                info_dict["malfunction"][i_agent] = agent.malfunction_data['malfunction']
                info_dict["speed"][i_agent] = agent.speed_data['speed']
                info_dict["status"][i_agent] = agent.status

                # Fix agents that finished their malfunction such that they can perform an action in the next step
                self._fix_agent_after_malfunction(agent)

        else:
            for i_agent, agent in enumerate(self.agents):
                # Reset the step rewards
                self.rewards_dict[i_agent] = 0

                # Induce malfunction before we do a step, thus a broken agent can't move in this step
                self._break_agent(agent)

                # Perform step on the agent
                self._step_agent_cf(i_agent, action_dict_.get(i_agent))

            # second loop: check for collisions / conflicts
            self.motionCheck.find_conflicts()

            # third loop: update positions
            for i_agent, agent in enumerate(self.agents):
                self._step_agent2_cf(i_agent)

                # manage the boolean flag to check if all agents are indeed done (or done_removed)
                have_all_agents_ended &= (agent.status in [RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED])

                # Build info dict
                info_dict["action_required"][i_agent] = self.action_required(agent)
                info_dict["malfunction"][i_agent] = agent.malfunction_data['malfunction']
                info_dict["speed"][i_agent] = agent.speed_data['speed']
                info_dict["status"][i_agent] = agent.status

                # Fix agents that finished their malfunction such that they can perform an action in the next step
                self._fix_agent_after_malfunction(agent)

        # NEW : REW: (END)
        if ((self._max_episode_steps is not None) and (self._elapsed_steps >= self._max_episode_steps)) \
            or have_all_agents_ended:

            for i_agent, agent in enumerate(self.agents):

                # agent done? (arrival_time is not None)
                if agent.status == RailAgentStatus.DONE or agent.status == RailAgentStatus.DONE_REMOVED:

                    # if agent arrived earlier or on time = 0
                    # if agent arrived later = -ve reward based on how late
                    reward = min(agent.latest_arrival - agent.arrival_time, 0)
                    self.rewards_dict[i_agent] += reward

                # Agents not done (arrival_time is None)
                else:

                    # CANCELLED check (never departed)
                    if (agent.status == RailAgentStatus.READY_TO_DEPART):
                        reward = -1 * self.cancellation_factor * \
                            (agent.get_travel_time_on_shortest_path(self.distance_map)
                             + self.cancellation_time_buffer)  # buffer substituted for the literal 0, as the original comment indicated
                        self.rewards_dict[i_agent] += reward

                    # Departed but never reached
                    if (agent.status == RailAgentStatus.ACTIVE):
                        reward = agent.get_current_delay(self._elapsed_steps, self.distance_map)
                        self.rewards_dict[i_agent] += reward

                self.dones[i_agent] = True

            self.dones["__all__"] = True

        if self.record_steps:
            self.record_timestep(action_dict_)

        return self._get_observations(), self.rewards_dict, self.dones, info_dict

    def _step_agent(self, i_agent, action: Optional[RailEnvActions] = None):
        """
        Performs a step on a single agent, applying the step, start and stop penalties,
        in the following sub-steps:
        - malfunction
        - action handling if at the beginning of cell
        - movement

        Parameters
        ----------
        i_agent : int
        action : Optional[RailEnvActions]

        """
        agent = self.agents[i_agent]
        if agent.status in [RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED]:  # this agent has already completed...
            return

        # agent becomes active through a MOVE_* action, provided its initial cell is free
        if agent.status == RailAgentStatus.READY_TO_DEPART:
            initial_cell_free = self.cell_free(agent.initial_position)
            is_action_starting = action in [
                RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT, RailEnvActions.MOVE_FORWARD]

            if is_action_starting and initial_cell_free:
                agent.status = RailAgentStatus.ACTIVE
                self._set_agent_to_initial_position(agent, agent.initial_position)
                self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
                return
            else:
                # TODO: Here we need to check for the departure time in future releases with full schedules
                self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
                return

        agent.old_direction = agent.direction
        agent.old_position = agent.position

        # if agent is broken, actions are ignored and agent does not move.
        # full step penalty in this case
        if agent.malfunction_data['malfunction'] > 0:
            self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
            return

        # Is the agent at the beginning of the cell? Then, it can take an action.
        # As long as the agent is malfunctioning or stopped at the beginning of the cell,
        # different actions may be taken!
        if fast_isclose(agent.speed_data['position_fraction'], 0.0, rtol=1e-03):
            # No action has been supplied for this agent -> set DO_NOTHING as default
            if action is None:
                action = RailEnvActions.DO_NOTHING

            if action < 0 or action >= len(RailEnvActions):
                print('ERROR: illegal action=', action,
                      'for agent with index=', i_agent,
                      '"DO NOTHING" will be executed instead')
                action = RailEnvActions.DO_NOTHING

            if action == RailEnvActions.DO_NOTHING and agent.moving:
                # Keep moving
                action = RailEnvActions.MOVE_FORWARD

            if action == RailEnvActions.STOP_MOVING and agent.moving:
                # Only allow halting an agent on entering new cells.
                agent.moving = False
                self.rewards_dict[i_agent] += self.stop_penalty

            if not agent.moving and not (
                action == RailEnvActions.DO_NOTHING or
                action == RailEnvActions.STOP_MOVING):
                # Allow agent to start with any forward or direction action
                agent.moving = True
                self.rewards_dict[i_agent] += self.start_penalty

            # Store the action if action is moving
            # If not moving, the action will be stored when the agent starts moving again.
            if agent.moving:
                _action_stored = False
                _, new_cell_valid, new_direction, new_position, transition_valid = \
                    self._check_action_on_agent(action, agent)

                if all([new_cell_valid, transition_valid]):
                    agent.speed_data['transition_action_on_cellexit'] = action
                    _action_stored = True
                else:
                    # But, if the chosen invalid action was LEFT/RIGHT, and the agent is moving,
                    # try to keep moving forward!
                    if (action == RailEnvActions.MOVE_LEFT or action == RailEnvActions.MOVE_RIGHT):
                        _, new_cell_valid, new_direction, new_position, transition_valid = \
                            self._check_action_on_agent(RailEnvActions.MOVE_FORWARD, agent)

                        if all([new_cell_valid, transition_valid]):
                            agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.MOVE_FORWARD
                            _action_stored = True

                if not _action_stored:
                    # If the agent cannot move due to an invalid transition, we set its state to not moving
                    self.rewards_dict[i_agent] += self.invalid_action_penalty
                    self.rewards_dict[i_agent] += self.stop_penalty
                    agent.moving = False

        # Now perform a movement.
        # If agent.moving, increment the position_fraction by the speed of the agent
        # If the new position fraction is >= 1, reset to 0, and perform the stored
        #   transition_action_on_cellexit if the cell is free.
        if agent.moving:
            agent.speed_data['position_fraction'] += agent.speed_data['speed']
            if agent.speed_data['position_fraction'] > 1.0 or fast_isclose(agent.speed_data['position_fraction'], 1.0,
                                                                           rtol=1e-03):
                # Perform stored action to transition to the next cell as soon as cell is free
                # Notice that we've already checked new_cell_valid and transition valid when we stored the action,
                # so we only have to check cell_free now!

                # Traditional check that next cell is free
                # cell and transition validity was checked when we stored transition_action_on_cellexit!
                cell_free, new_cell_valid, new_direction, new_position, transition_valid = self._check_action_on_agent(
                    agent.speed_data['transition_action_on_cellexit'], agent)

                # N.B. validity of new_cell and transition should have been verified before the action was stored!
                assert new_cell_valid
                assert transition_valid
                if cell_free:
                    self._move_agent_to_new_position(agent, new_position)
                    agent.direction = new_direction
                    agent.speed_data['position_fraction'] = 0.0

            # has the agent reached its target?
            if np.equal(agent.position, agent.target).all():
                agent.status = RailAgentStatus.DONE
                self.dones[i_agent] = True
                self.active_agents.remove(i_agent)
                agent.moving = False
                self._remove_agent_from_scene(agent)
            else:
                self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
        else:
            # step penalty if not moving (stopped now or before)
            self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']

    def _step_agent_cf(self, i_agent, action: Optional[RailEnvActions] = None):
        """ "close following" version of step_agent.
        """
        agent = self.agents[i_agent]
        if agent.status in [RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED]:  # this agent has already completed...
            return

        # NEW : STEP: WAITING > WAITING or WAITING > READY_TO_DEPART
        if (agent.status == RailAgentStatus.WAITING):
            if (self._elapsed_steps >= agent.earliest_departure):
                agent.status = RailAgentStatus.READY_TO_DEPART
            self.motionCheck.addAgent(i_agent, None, None)
            return

        # agent becomes active through a MOVE_* action; whether its initial cell
        # is actually free is resolved later by motionCheck
        if agent.status == RailAgentStatus.READY_TO_DEPART:
            is_action_starting = action in [
                RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT, RailEnvActions.MOVE_FORWARD]

            if is_action_starting:  # agent is trying to start
                self.motionCheck.addAgent(i_agent, None, agent.initial_position)
            else:  # agent wants to remain unstarted
                self.motionCheck.addAgent(i_agent, None, None)
            return

        agent.old_direction = agent.direction
        agent.old_position = agent.position

        # if agent is broken, actions are ignored and agent does not move.
        # full step penalty in this case
        # TODO: this means that deadlocked agents which suffer a malfunction are marked as
        # stopped rather than deadlocked.
        if agent.malfunction_data['malfunction'] > 0:
            self.motionCheck.addAgent(i_agent, agent.position, agent.position)
            # agent will get penalty in step_agent2_cf
            # self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
            return

        # Is the agent at the beginning of the cell? Then, it can take an action.
        # As long as the agent is malfunctioning or stopped at the beginning of the cell,
        # different actions may be taken!
        if np.isclose(agent.speed_data['position_fraction'], 0.0, rtol=1e-03):
            # No action has been supplied for this agent -> set DO_NOTHING as default
            if action is None:
                action = RailEnvActions.DO_NOTHING

            if action < 0 or action >= len(RailEnvActions):
                print('ERROR: illegal action=', action,
                      'for agent with index=', i_agent,
                      '"DO NOTHING" will be executed instead')
                action = RailEnvActions.DO_NOTHING

            if action == RailEnvActions.DO_NOTHING and agent.moving:
                # Keep moving
                action = RailEnvActions.MOVE_FORWARD

            if action == RailEnvActions.STOP_MOVING and agent.moving:
                # Only allow halting an agent on entering new cells.
                agent.moving = False
                self.rewards_dict[i_agent] += self.stop_penalty

            if not agent.moving and not (
                action == RailEnvActions.DO_NOTHING or
                action == RailEnvActions.STOP_MOVING):
                # Allow agent to start with any forward or direction action
                agent.moving = True
                self.rewards_dict[i_agent] += self.start_penalty

            # Store the action if action is moving
            # If not moving, the action will be stored when the agent starts moving again.
            new_position = None
            if agent.moving:
                _action_stored = False
                _, new_cell_valid, new_direction, new_position, transition_valid = \
                    self._check_action_on_agent(action, agent)

                if all([new_cell_valid, transition_valid]):
                    agent.speed_data['transition_action_on_cellexit'] = action
                    _action_stored = True
                else:
                    # But, if the chosen invalid action was LEFT/RIGHT, and the agent is moving,
                    # try to keep moving forward!
                    if (action == RailEnvActions.MOVE_LEFT or action == RailEnvActions.MOVE_RIGHT):
                        _, new_cell_valid, new_direction, new_position, transition_valid = \
                            self._check_action_on_agent(RailEnvActions.MOVE_FORWARD, agent)

                        if all([new_cell_valid, transition_valid]):
                            agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.MOVE_FORWARD
                            _action_stored = True

                if not _action_stored:
                    # If the agent cannot move due to an invalid transition, we set its state to not moving
                    self.rewards_dict[i_agent] += self.invalid_action_penalty
                    self.rewards_dict[i_agent] += self.stop_penalty
                    agent.moving = False
                    self.motionCheck.addAgent(i_agent, agent.position, agent.position)
                    return

            if new_position is None:
                self.motionCheck.addAgent(i_agent, agent.position, agent.position)
                if agent.moving:
                    print("Agent", i_agent, "new_pos none, but moving")

        # Check the position fraction (pos_frac)
        if agent.moving:
            agent.speed_data['position_fraction'] += agent.speed_data['speed']
            if agent.speed_data['position_fraction'] > 0.999:
                stored_action = agent.speed_data["transition_action_on_cellexit"]

                # find the next cell using the stored action
                _, new_cell_valid, new_direction, new_position, transition_valid = \
                    self._check_action_on_agent(stored_action, agent)

                # if it's valid, record it as the new position
                if all([new_cell_valid, transition_valid]):
                    self.motionCheck.addAgent(i_agent, agent.position, new_position)
                else:  # if the action wasn't valid then record the agent as stationary
                    self.motionCheck.addAgent(i_agent, agent.position, agent.position)
            else:  # This agent hasn't yet crossed the cell
                self.motionCheck.addAgent(i_agent, agent.position, agent.position)

    def _step_agent2_cf(self, i_agent):
        agent = self.agents[i_agent]

        # NEW : REW: (WAITING) no reward during WAITING...
        if agent.status in [RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED, RailAgentStatus.WAITING]:
            return

        (move, rc_next) = self.motionCheck.check_motion(i_agent, agent.position)

        if agent.position is not None:
            sbTrans = format(self.rail.grid[agent.position], "016b")
            trans_block = sbTrans[agent.direction * 4: agent.direction * 4 + 4]
            if (trans_block == "0000"):
                print(i_agent, agent.position, agent.direction, sbTrans, trans_block)

        # if agent cannot enter env, then we should have move=False

        if move:
            if agent.position is None:  # agent is entering the env
                # print(i_agent, "writing new pos ", rc_next, " into agent position (None)")
                agent.position = rc_next
                agent.status = RailAgentStatus.ACTIVE
                agent.speed_data['position_fraction'] = 0.0

            else:  # normal agent move
                cell_free, new_cell_valid, new_direction, new_position, transition_valid = self._check_action_on_agent(
                    agent.speed_data['transition_action_on_cellexit'], agent)

                if not all([transition_valid, new_cell_valid]):
                    print(f"ERROR: step_agent2 invalid transition ag {i_agent} dir {new_direction} pos {agent.position} next {rc_next}")

                if new_position != rc_next:
                    print(f"ERROR: agent {i_agent} new_pos {new_position} != rc_next {rc_next}  " +
                          f"pos {agent.position} dir {agent.direction} new_dir {new_direction} " +
                          f"stored action: {agent.speed_data['transition_action_on_cellexit']}")

                sbTrans = format(self.rail.grid[agent.position], "016b")
                trans_block = sbTrans[agent.direction * 4: agent.direction * 4 + 4]
                if (trans_block == "0000"):
                    print("ERROR: ", i_agent, agent.position, agent.direction, sbTrans, trans_block)

                agent.position = rc_next
                agent.direction = new_direction
                agent.speed_data['position_fraction'] = 0.0

            # NEW : STEP: Check DONE  before / after LA & Check if RUNNING before / after LA
            # has the agent reached its target?
            if np.equal(agent.position, agent.target).all():
                # arrived before or after Latest Arrival
                agent.status = RailAgentStatus.DONE
                self.dones[i_agent] = True
                self.active_agents.remove(i_agent)
                agent.moving = False
                agent.arrival_time = self._elapsed_steps
                self._remove_agent_from_scene(agent)

            else:  # not reached its target and moving
                # running before Latest Arrival
                if (self._elapsed_steps <= agent.latest_arrival):
                    self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
                else:  # running after Latest Arrival
                    self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']  # + # NEGATIVE REWARD? per step?
        else:
            # stopped (!move) before Latest Arrival
            if (self._elapsed_steps <= agent.latest_arrival):
                self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
            else:  # stopped (!move) after Latest Arrival
                self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']  # + # NEGATIVE REWARD? per step?

    def _set_agent_to_initial_position(self, agent: EnvAgent, new_position: IntVector2D):
        """
        Sets the agent to its initial position. Updates the agent object and the position
        of the agent inside the global agent_position numpy array

        Parameters
        ----------
        agent: EnvAgent object
        new_position: IntVector2D
        """
        agent.position = new_position
        self.agent_positions[agent.position] = agent.handle

    def _move_agent_to_new_position(self, agent: EnvAgent, new_position: IntVector2D):
        """
        Move the agent to a new position. Updates the agent object and the position
        of the agent inside the global agent_position numpy array

        Parameters
        ----------
        agent: EnvAgent object
        new_position: IntVector2D
        """
        agent.position = new_position
        self.agent_positions[agent.old_position] = -1
        self.agent_positions[agent.position] = agent.handle

    def _remove_agent_from_scene(self, agent: EnvAgent):
        """
        Remove the agent from the scene. Updates the agent object and the position
        of the agent inside the global agent_position numpy array

        Parameters
        ----------
        agent: EnvAgent object
        """
        self.agent_positions[agent.position] = -1
        if self.remove_agents_at_target:
            agent.position = None
            # setting old_position to None here stops the DONE agents from appearing in the rendered image
            agent.old_position = None
            agent.status = RailAgentStatus.DONE_REMOVED

    def _check_action_on_agent(self, action: RailEnvActions, agent: EnvAgent):
        """

        Parameters
        ----------
        action : RailEnvActions
        agent : EnvAgent

        Returns
        -------
        Tuple of (cell_free, new_cell_valid, new_direction, new_position, transition_valid);
        the move is legal if:
            1) the transition allows the new_direction in the cell,
            2) the new cell is not empty (case 0),
            3) the cell is free, i.e., no agent is currently in that cell
        """
        # compute number of possible transitions in the current
        # cell used to check for invalid actions
        new_direction, transition_valid = self.check_action(agent, action)
        new_position = get_new_position(agent.position, new_direction)

        new_cell_valid = (
            fast_position_equal(  # Check the new position is still in the grid
                new_position,
                fast_clip(new_position, [0, 0], [self.height - 1, self.width - 1]))
            and  # check the new position has some transitions (ie is not an empty cell)
            self.rail.get_full_transitions(*new_position) > 0)

        # If transition validity hasn't been checked yet.
        if transition_valid is None:
            transition_valid = self.rail.get_transition(
                (*agent.position, agent.direction),
                new_direction)