"""
Definition of the RailEnv environment.
"""
import random
from typing import List, Optional, Dict, Tuple

import numpy as np
from gym.utils import seeding
from dataclasses import dataclass

from flatland.core.env import Environment
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.grid.grid4 import Grid4Transitions
from flatland.core.grid.grid4_utils import get_new_position
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.distance_map import DistanceMap
from flatland.envs.rail_env_action import RailEnvActions

from flatland.envs import malfunction_generators as mal_gen
from flatland.envs import rail_generators as rail_gen
from flatland.envs import line_generators as line_gen
from flatland.envs.timetable_generators import timetable_generator
from flatland.envs import persistence
from flatland.envs import agent_chains as ac

from flatland.envs.observations import GlobalObsForRailEnv

from flatland.envs.step_utils.states import TrainState, StateTransitionSignals
from flatland.envs.step_utils import transition_utils
from flatland.envs.step_utils import action_preprocessing


class RailEnv(Environment):
    """
    RailEnv environment class.

    RailEnv is an environment inspired by a (simplified version of) a rail
    network, in which agents (trains) have to navigate to their target
    locations in the shortest time possible, while at the same time cooperating
    to avoid bottlenecks.

    The valid actions in the environment are:
45
46
47
48
49
50

     -   0: do nothing (continue moving or stay still)
     -   1: turn left at switch and move to the next cell; if the agent was not moving, movement is started
     -   2: move to the next cell in front of the agent; if the agent was not moving, movement is started
     -   3: turn right at switch and move to the next cell; if the agent was not moving, movement is started
     -   4: stop moving
51
52
53
54

    Moving forward in a dead-end cell makes the agent turn 180 degrees and step
    to the cell it came from.

55

56
57
58
    The actions of the agents are executed in order of their handle to prevent
    deadlocks and to allow them to learn relative priorities.

59
60
61
62
63
64
65
66
67
68
69
70
    Reward Function:

    It costs each agent a step_penalty for every time-step taken in the environment. Independent of the movement
    of the agent. Currently all other penalties such as penalty for stopping, starting and invalid actions are set to 0.

    alpha = 1
    beta = 1
    Reward function parameters:

    - invalid_action_penalty = 0
    - step_penalty = -alpha
    - global_reward = beta
71
    - epsilon = avoid rounding errors
72
73
74
    - stop_penalty = 0  # penalty for stopping a moving agent
    - start_penalty = 0  # penalty for starting a stopped agent

75
76
    Stochastic malfunctioning of trains:
    Trains in RailEnv can malfunction if they are halted too often (either by their own choice or because an invalid
spiglerg's avatar
spiglerg committed
77
78
    action or cell is selected.

79
80
81
    Every time an agent stops, an agent has a certain probability of malfunctioning. Malfunctions of trains follow a
    poisson process with a certain rate. Not all trains will be affected by malfunctions during episodes to keep
    complexity managable.
spiglerg's avatar
spiglerg committed
82
83
84
85

    TODO: currently, the parameters that control the stochasticity of the environment are hard-coded in init().
    For Round 2, they will be passed to the constructor as arguments, to allow for more flexibility.
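
    Examples
    --------
    A minimal usage sketch (the generator choices and parameter values below are
    illustrative, not mandated by the environment)::

        env = RailEnv(width=30, height=30,
                      rail_generator=rail_gen.sparse_rail_generator(),
                      line_generator=line_gen.sparse_line_generator(),
                      number_of_agents=2)
        obs, info = env.reset()
        # 2 == RailEnvActions.MOVE_FORWARD
        obs, rewards, dones, info = env.step({h: 2 for h in env.get_agent_handles()})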

    """

    # Epsilon to avoid rounding errors
    epsilon = 0.01

    # NEW : REW: Sparse Reward
    alpha = 0
    beta = 0

    step_penalty = -1 * alpha
    global_reward = 1 * beta
    invalid_action_penalty = 0  # previously -2; GIACOMO: we decided that invalid actions will carry no penalty

    stop_penalty = 0  # penalty for stopping a moving agent
    start_penalty = 0  # penalty for starting a stopped agent

    cancellation_factor = 1
    cancellation_time_buffer = 0

    def __init__(self,
                 width,
                 height,
                 rail_generator=None,
                 line_generator=None,  # : line_gen.LineGenerator = line_gen.random_line_generator(),
                 number_of_agents=2,
                 obs_builder_object: ObservationBuilder = GlobalObsForRailEnv(),
                 malfunction_generator_and_process_data=None,  # mal_gen.no_malfunction_generator(),
                 malfunction_generator=None,
                 remove_agents_at_target=True,
                 random_seed=1,
                 record_steps=False,
                 close_following=True
                 ):
        """
        Environment init.

        Parameters
        ----------
        rail_generator : function
            The rail_generator function is a function that takes the width,
            height and agent handles of a rail environment, along with the number of times
            the env has been reset, and returns a GridTransitionMap object and a list of
            starting positions, targets, and initial orientations for each agent handle.
            The rail_generator can pass a distance map in the hints or information for specific line_generators.
            Implementations can be found in flatland/envs/rail_generators.py
        line_generator : function
            The line_generator function is a function that takes the grid, the number of agents and optional hints
            and returns a list of starting positions, targets, initial orientations and speeds for all agent handles.
            Implementations can be found in flatland/envs/line_generators.py
        width : int
            The width of the rail map. Potentially in the future,
            a range of widths to sample from.
        height : int
            The height of the rail map. Potentially in the future,
            a range of heights to sample from.
        number_of_agents : int
            Number of agents to spawn on the map. Potentially in the future,
            a range of number of agents to sample from.
        obs_builder_object: ObservationBuilder object
            ObservationBuilder-derived object that builds observation
            vectors for each agent.
        remove_agents_at_target : bool
            If remove_agents_at_target is set to true, then agents will be removed by being placed at
            RailEnv.DEPOT_POSITION once they have reached their target position.
        random_seed : int or None
            If None, it is ignored; otherwise the random generators are seeded with this number to ensure
            that stochastic operations are replicable across multiple runs.
        """
        super().__init__()

        if malfunction_generator_and_process_data is not None:
            print("DEPRECATED - RailEnv arg: malfunction_and_process_data - use malfunction_generator")
            self.malfunction_generator, self.malfunction_process_data = malfunction_generator_and_process_data
        elif malfunction_generator is not None:
            self.malfunction_generator = malfunction_generator
            # malfunction_process_data is not used
            # self.malfunction_generator, self.malfunction_process_data = malfunction_generator_and_process_data
            self.malfunction_process_data = self.malfunction_generator.get_process_data()
        # replace default values here because we can't use default args values because of cyclic imports
        else:
            self.malfunction_generator = mal_gen.NoMalfunctionGen()
            self.malfunction_process_data = self.malfunction_generator.get_process_data()

        self.number_of_agents = number_of_agents

        # self.rail_generator: RailGenerator = rail_generator
        if rail_generator is None:
            rail_generator = rail_gen.sparse_rail_generator()
        self.rail_generator = rail_generator
        if line_generator is None:
            line_generator = line_gen.sparse_line_generator()
        self.line_generator = line_generator

        self.rail: Optional[GridTransitionMap] = None
        self.width = width
        self.height = height

        self.remove_agents_at_target = remove_agents_at_target

        self.rewards = [0] * number_of_agents
        self.done = False
        self.obs_builder = obs_builder_object
        self.obs_builder.set_env(self)

        self._max_episode_steps: Optional[int] = None
        self._elapsed_steps = 0

        self.dones = dict.fromkeys(list(range(number_of_agents)) + ["__all__"], False)

        self.obs_dict = {}
        self.rewards_dict = {}
        self.dev_obs_dict = {}
        self.dev_pred_dict = {}

        self.agents: List[EnvAgent] = []
        self.num_resets = 0
        self.distance_map = DistanceMap(self.agents, self.height, self.width)

        self.action_space = [5]

        self._seed()
        self.random_seed = random_seed
        if self.random_seed:
            self._seed(seed=random_seed)

        self.valid_positions = None

        # global numpy array of agents position, True means that there is an agent at that cell
        self.agent_positions: np.ndarray = np.full((height, width), False)

        # save episode timesteps ie agent positions, orientations.  (not yet actions / observations)
        self.record_steps = record_steps  # whether to save timesteps
        # save timesteps in here: [[[row, col, dir, malfunction],...nAgents], ...nSteps]
        self.cur_episode = []
        self.list_actions = []  # save actions in here

        self.close_following = close_following  # use close following logic
        self.motionCheck = ac.MotionCheck()

        self.agent_helpers = {}

    def _seed(self, seed=None):
        self.np_random, seed = seeding.np_random(seed)
        random.seed(seed)
        return [seed]

    # no more agent_handles
    def get_agent_handles(self):
        return range(self.get_num_agents())

    def get_num_agents(self) -> int:
        return len(self.agents)

    def add_agent(self, agent):
        """ Add static info for a single agent.
            Returns the index of the new agent.
        """
        self.agents.append(agent)
        return len(self.agents) - 1

    def reset_agents(self):
        """ Reset the agents to their starting positions
        """
        for agent in self.agents:
            agent.reset()
        self.active_agents = [i for i in range(len(self.agents))]

    def action_required(self, agent):
        """
        Check if an agent needs to provide an action

        Parameters
        ----------
        agent: RailEnvAgent
        Agent we want to check

        Returns
        -------
        True: Agent needs to provide an action
        False: Agent does not need to provide an action
        """
        return agent.state == TrainState.READY_TO_DEPART or \
               (agent.state.is_on_map_state() and agent.speed_counter.is_cell_entry)

    def reset(self, regenerate_rail: bool = True, regenerate_schedule: bool = True, *,
              random_seed: Optional[int] = None) -> Tuple[Dict, Dict]:
        """
        reset(regenerate_rail, regenerate_schedule, random_seed)

        The method resets the rail environment

        Parameters
        ----------
        regenerate_rail : bool, optional
            regenerate the rails
        regenerate_schedule : bool, optional
            regenerate the schedule and the static agents
        random_seed : int, optional
            random seed for environment

        Returns
        -------
        observation_dict: Dict
            Dictionary with an observation for each agent
        info_dict: Dict with agent specific information

        """

        if random_seed:
            self._seed(random_seed)

        optionals = {}
        if regenerate_rail or self.rail is None:

            if "__call__" in dir(self.rail_generator):
                rail, optionals = self.rail_generator(
                    self.width, self.height, self.number_of_agents, self.num_resets, self.np_random)
            elif "generate" in dir(self.rail_generator):
                rail, optionals = self.rail_generator.generate(
                    self.width, self.height, self.number_of_agents, self.num_resets, self.np_random)
            else:
                raise ValueError("Could not invoke __call__ or generate on rail_generator")

            self.rail = rail
            self.height, self.width = self.rail.grid.shape

            # Do a new set_env call on the obs_builder to ensure
            # that obs_builder specific instantiations are made according to the
            # specifications of the current environment : like width, height, etc
            self.obs_builder.set_env(self)

        if optionals and 'distance_map' in optionals:
            self.distance_map.set(optionals['distance_map'])

        if regenerate_schedule or regenerate_rail or self.get_num_agents() == 0:
            agents_hints = None
            if optionals and 'agents_hints' in optionals:
                agents_hints = optionals['agents_hints']

            line = self.line_generator(self.rail, self.number_of_agents, agents_hints,
                                       self.num_resets, self.np_random)
            self.agents = EnvAgent.from_line(line)

            # Reset distance map - basically initializing
            self.distance_map.reset(self.agents, self.rail)

            # NEW : Time Schedule Generation
            timetable = timetable_generator(self.agents, self.distance_map,
                                            agents_hints, self.np_random)

            self._max_episode_steps = timetable.max_episode_steps

            for agent_i, agent in enumerate(self.agents):
                agent.earliest_departure = timetable.earliest_departures[agent_i]
                agent.latest_arrival = timetable.latest_arrivals[agent_i]
        else:
            self.distance_map.reset(self.agents, self.rail)

        # Agent Positions Map
        self.agent_positions = np.zeros((self.height, self.width), dtype=int) - 1

        # Reset agents to initial states
        self.reset_agents()

        self.num_resets += 1
        self._elapsed_steps = 0

        # TODO perhaps dones should be part of each agent.
        self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)

        # Reset the state of the observation builder with the new environment
        self.obs_builder.reset()

        # Empty the episode store of agent positions
        self.cur_episode = []

        info_dict = self.get_info_dict()
        # Return the new observation vectors for each agent
        observation_dict: Dict = self._get_observations()
        return observation_dict, info_dict

    def apply_action_independent(self, action, rail, position, direction):
        """ Apply the action, ignoring other agents, and return the resulting position and direction. """
        if action.is_moving_action():
            new_direction, _ = transition_utils.check_action(action, position, direction, rail)
            new_position = get_new_position(position, new_direction)
        else:
            new_position, new_direction = position, direction
        return new_position, new_direction

    def generate_state_transition_signals(self, agent, preprocessed_action, movement_allowed):
        """ Generate State Transition Signals used in the state machine """
        st_signals = StateTransitionSignals()

        # Malfunction starts when in_malfunction is set to true
        st_signals.in_malfunction = agent.malfunction_handler.in_malfunction

        # Malfunction counter complete - Malfunction ends next timestep
        st_signals.malfunction_counter_complete = agent.malfunction_handler.malfunction_counter_complete

        # Earliest departure reached - Train is allowed to move now
        st_signals.earliest_departure_reached = self._elapsed_steps >= agent.earliest_departure

        # Stop Action Given
        st_signals.stop_action_given = (preprocessed_action == RailEnvActions.STOP_MOVING)

        # Valid Movement action Given
        st_signals.valid_movement_action_given = preprocessed_action.is_moving_action() and movement_allowed

        # Target Reached
        st_signals.target_reached = fast_position_equal(agent.position, agent.target)

        # Movement conflict - Multiple trains trying to move into same cell
        # If speed counter is not in cell exit, the train can enter the cell
        st_signals.movement_conflict = (not movement_allowed) and agent.speed_counter.is_cell_exit

        return st_signals

    def _handle_end_reward(self, agent: EnvAgent) -> int:
        """
        Handles end-of-episode reward for a particular agent.

        Parameters
        ----------
        agent : EnvAgent
        """
        reward = None
        # agent done? (arrival_time is not None)
        if agent.state == TrainState.DONE:
            # if agent arrived earlier or on time = 0
            # if agent arrived later = negative reward based on how late
            reward = min(agent.latest_arrival - agent.arrival_time, 0)

        # Agents not done (arrival_time is None)
        else:
            # CANCELLED check (never departed)
            if (agent.state.is_off_map_state()):
                reward = -1 * self.cancellation_factor * \
                    (agent.get_travel_time_on_shortest_path(self.distance_map) + self.cancellation_time_buffer)

            # Departed but never reached
            if (agent.state.is_on_map_state()):
                reward = agent.get_current_delay(self._elapsed_steps, self.distance_map)

        return reward
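
    # Worked example (illustrative numbers, not fixed by the library): with
    # latest_arrival=20, an agent arriving at step 25 receives min(20 - 25, 0) = -5.
    # A never-departed agent whose shortest path takes 12 steps, with
    # cancellation_factor=1 and cancellation_time_buffer=0, receives -1 * 1 * (12 + 0) = -12.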

    def preprocess_action(self, action, agent):
        """
        Preprocess the provided action
            * Change to DO_NOTHING if illegal action
            * Block all actions when in waiting state
            * Check MOVE_LEFT/MOVE_RIGHT actions on current position else try MOVE_FORWARD
        """
        action = action_preprocessing.preprocess_raw_action(action, agent.state, agent.action_saver.saved_action)
        action = action_preprocessing.preprocess_action_when_waiting(action, agent.state)

        # Try moving actions on current position
        current_position, current_direction = agent.position, agent.direction
        if current_position is None:  # Agent not added on map yet
            current_position, current_direction = agent.initial_position, agent.initial_direction

        action = action_preprocessing.preprocess_moving_action(action, self.rail, current_position, current_direction)
        return action
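
    # Example (hypothetical scenario): a MOVE_LEFT sent on a cell with no left
    # transition is downgraded to MOVE_FORWARD by preprocess_moving_action, and
    # any action sent while the agent is still in the WAITING state is blocked.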

    def clear_rewards_dict(self):
        """ Reset the rewards dictionary """
        self.rewards_dict = {i_agent: 0 for i_agent in range(len(self.agents))}

    def get_info_dict(self):  # TODO Important : Update this
        info_dict = {
            'action_required': {i: self.action_required(agent) for i, agent in enumerate(self.agents)},
            'malfunction': {
                i: agent.malfunction_data['malfunction'] for i, agent in enumerate(self.agents)
            },
            'speed': {i: agent.speed_counter.speed for i, agent in enumerate(self.agents)},
            'state': {i: agent.state for i, agent in enumerate(self.agents)}
        }
        return info_dict
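
    # Example shape (illustrative, for a 2-agent env):
    # {'action_required': {0: True, 1: False}, 'malfunction': {0: 0, 1: 2},
    #  'speed': {0: 1.0, 1: 0.5}, 'state': {0: TrainState.MOVING, 1: TrainState.MALFUNCTION}}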

    def update_step_rewards(self, i_agent):
        pass

    def end_of_episode_update(self, have_all_agents_ended):
        if have_all_agents_ended or \
           ((self._max_episode_steps is not None) and (self._elapsed_steps >= self._max_episode_steps)):

            for i_agent, agent in enumerate(self.agents):
                reward = self._handle_end_reward(agent)
                self.rewards_dict[i_agent] += reward
                self.dones[i_agent] = True

            self.dones["__all__"] = True

    def handle_done_state(self, agent):
        if agent.state == TrainState.DONE:
            agent.arrival_time = self._elapsed_steps
            if self.remove_agents_at_target:
                agent.position = None

    def step(self, action_dict_: Dict[int, RailEnvActions]):
        """
        Perform one time step of the environment: process the given actions,
        move the agents, update their states, and compute rewards.
        """
        self._elapsed_steps += 1

        # Not allowed to step further once done
        if self.dones["__all__"]:
            raise Exception("Episode is done, cannot call step()")

        self.clear_rewards_dict()

        have_all_agents_ended = True  # Boolean flag to check if all agents are done

        self.motionCheck = ac.MotionCheck()  # reset the motion check

        temp_transition_data = {}

        for agent in self.agents:
            i_agent = agent.handle
            agent.old_position = agent.position
            agent.old_direction = agent.direction
            # Generate malfunction
            agent.malfunction_handler.generate_malfunction(self.malfunction_generator, self.np_random)

            # Get action for the agent
            action = action_dict_.get(i_agent, RailEnvActions.DO_NOTHING)

            preprocessed_action = self.preprocess_action(action, agent)

            # Save moving actions if not already saved
            agent.action_saver.save_action_if_allowed(preprocessed_action, agent.state)

            # Train's next position can change if it is currently stopped with a fractional speed or is at the cell's exit
            position_update_allowed = (agent.speed_counter.is_cell_exit or agent.state == TrainState.STOPPED)

            # Calculate new position
            # Add agent to the map if not on it yet
            if agent.position is None and agent.action_saver.is_action_saved:
                new_position = agent.initial_position
                new_direction = agent.initial_direction

            # If movement is allowed apply saved action independent of other agents
            elif agent.action_saver.is_action_saved and position_update_allowed:
                saved_action = agent.action_saver.saved_action
                # Apply action independent of other agents and get temporary new position and direction
                new_position, new_direction = self.apply_action_independent(saved_action,
                                                                            self.rail,
                                                                            agent.position,
                                                                            agent.direction)
                preprocessed_action = saved_action
            else:
                new_position, new_direction = agent.position, agent.direction

            temp_transition_data[i_agent] = AgentTransitionData(position=new_position,
                                                                direction=new_direction,
                                                                preprocessed_action=preprocessed_action)

            # This is for storing and later checking for conflicts of agents trying to occupy the same cell
            self.motionCheck.addAgent(i_agent, agent.position, new_position)

        # Find conflicts between trains trying to occupy the same cell
        self.motionCheck.find_conflicts()

        for agent in self.agents:
            i_agent = agent.handle
            agent_transition_data = temp_transition_data[i_agent]

            ## Update positions
            if agent.malfunction_handler.in_malfunction:
                movement_allowed = False
            else:
                movement_allowed = self.motionCheck.check_motion(i_agent, agent.position)

            # Position can be changed only if the other cell is empty
            # and either the speed counter completes or the agent is being added to the map
            if movement_allowed and \
               (agent.speed_counter.is_cell_exit or agent.position is None):
                agent.position = agent_transition_data.position
                agent.direction = agent_transition_data.direction

            preprocessed_action = agent_transition_data.preprocessed_action

            ## Update states
            state_transition_signals = self.generate_state_transition_signals(agent, preprocessed_action, movement_allowed)
            agent.state_machine.set_transition_signals(state_transition_signals)
            agent.state_machine.step()

            # Off map or on map state and position should match
            state_position_sync_check(agent.state, agent.position, agent.handle)

            # Handle done state actions, optionally remove agents
            self.handle_done_state(agent)

            have_all_agents_ended &= (agent.state == TrainState.DONE)

            ## Update rewards
            self.update_step_rewards(i_agent)

            ## Update counters (malfunction and speed)
            agent.speed_counter.update_counter(agent.state, agent.old_position)
            #                                  agent.state_machine.previous_state)
            agent.malfunction_handler.update_counter()

            # Clear old action when starting in a new cell
            if agent.speed_counter.is_cell_entry and agent.position is not None:
                agent.action_saver.clear_saved_action()

        # Check if episode has ended and update rewards and dones
        self.end_of_episode_update(have_all_agents_ended)

        return self._get_observations(), self.rewards_dict, self.dones, self.get_info_dict()
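
    # Example (illustrative): stepping a 2-agent env with both agents told to move forward
    #   obs, rewards, dones, info = env.step({0: RailEnvActions.MOVE_FORWARD,
    #                                         1: RailEnvActions.MOVE_FORWARD})
    # rewards is keyed by agent handle, e.g. {0: 0, 1: 0}, and dones carries a "__all__" key.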

    def record_timestep(self, dActions):
        """ Record the positions and orientations of all agents in memory, in the cur_episode
        """
        list_agents_state = []
        for i_agent in range(self.get_num_agents()):
            agent = self.agents[i_agent]
            # the int cast is to avoid numpy types which may cause problems with msgpack
            # in env v2, agents may have position None, before starting
            if agent.position is None:
                pos = (0, 0)
            else:
                pos = (int(agent.position[0]), int(agent.position[1]))
            # print("pos:", pos, type(pos[0]))
            list_agents_state.append([
                    *pos, int(agent.direction),
                    agent.malfunction_data["malfunction"],
                    int(agent.status),
                    int(agent.position in self.motionCheck.svDeadlocked)
                    ])

        self.cur_episode.append(list_agents_state)
        self.list_actions.append(dActions)

    def _get_observations(self):
        """
        Utility which returns the observations for all agents in the environment

        Returns
        -------
        Dict object
        """
        # print(f"_get_obs - num agents: {self.get_num_agents()} {list(range(self.get_num_agents()))}")
        self.obs_dict = self.obs_builder.get_many(list(range(self.get_num_agents())))
        return self.obs_dict

    def get_valid_directions_on_grid(self, row: int, col: int) -> List[int]:
        """
        Returns directions in which the agent can move

        Parameters
        ----------
        row : int
        col : int

        Returns
        -------
        List[int]
        """
        return Grid4Transitions.get_entry_directions(self.rail.get_full_transitions(row, col))

    def _exp_distirbution_synced(self, rate: float) -> float:
        """
        Generates a sample from an exponential distribution.
        We need this to guarantee synchronicity between different instances with the same seed.
        :param rate:
        :return:
        """
        u = self.np_random.rand()
        x = - np.log(1 - u) * rate
        return x
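
    # Note: this is inverse-transform sampling; if U ~ Uniform(0, 1) then -ln(1 - U)
    # is a unit-rate exponential sample. Since `rate` multiplies the sample, it acts
    # as the mean (scale) of the distribution rather than the rate parameter.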

    def _is_agent_ok(self, agent: EnvAgent) -> bool:
        """
        Check if an agent is ok, meaning it can move and is not malfunctioning
        Parameters
        ----------
        agent

        Returns
        -------
        True if agent is ok, False otherwise

        """
        return not agent.malfunction_handler.in_malfunction

    def save(self, filename):
        print("deprecated call to env.save() - pls call RailEnvPersister.save()")
        persistence.RailEnvPersister.save(self, filename)


@dataclass(repr=True)
class AgentTransitionData:
    """ Class for keeping track of temporary agent data for position update """
    position: Tuple[int, int]
    direction: Grid4Transitions
    preprocessed_action: RailEnvActions


# Adrian Egli performance fix (the fast methods brings more than 50%)
def fast_isclose(a, b, rtol):
    return (a < (b + rtol)) and (a > (b - rtol))


def fast_position_equal(pos_1: Tuple[int, int], pos_2: Tuple[int, int]) -> bool:
    if pos_1 is None:  # TODO: Dipam - Consider making default of agent.position as (-1, -1) instead of None
        return False
    else:
        return pos_1[0] == pos_2[0] and pos_1[1] == pos_2[1]


def state_position_sync_check(state, position, i_agent):
    """ Check that the agent's state and its position are in sync: an on-map state
        requires a position, and an off-map state requires no position. """
    if state.is_on_map_state() and position is None:
        raise ValueError("Agent ID {} has on-map state {} but its position {} is off the map".format(
                        i_agent, str(state), str(position)))
    elif state.is_off_map_state() and position is not None:
        raise ValueError("Agent ID {} has off-map state {} but its position {} is on the map".format(
                        i_agent, str(state), str(position)))