"""
Definition of the RailEnv environment.
"""
import random

from typing import List, Optional, Dict, Tuple

import numpy as np
from gym.utils import seeding

from flatland.utils.rendertools import RenderTool, AgentRenderVariant
from flatland.core.env import Environment
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.grid.grid4 import Grid4Transitions
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.distance_map import DistanceMap
from flatland.envs.rail_env_action import RailEnvActions

from flatland.envs import malfunction_generators as mal_gen
from flatland.envs import rail_generators as rail_gen
from flatland.envs import line_generators as line_gen
from flatland.envs.timetable_generators import timetable_generator
from flatland.envs import persistence
from flatland.envs import agent_chains as ac

from flatland.envs.observations import GlobalObsForRailEnv

from flatland.envs.step_utils.states import TrainState, StateTransitionSignals
from flatland.envs.step_utils import action_preprocessing
from flatland.envs.step_utils import env_utils

class RailEnv(Environment):
    """
    RailEnv environment class.

    RailEnv is an environment inspired by a (simplified version of) a rail
    network, in which agents (trains) have to navigate to their target
    locations in the shortest time possible, while at the same time cooperating
    to avoid bottlenecks.

    The valid actions in the environment are:

     -   0: do nothing (continue moving or stay still)
     -   1: turn left at switch and move to the next cell; if the agent was not moving, movement is started
     -   2: move to the next cell in front of the agent; if the agent was not moving, movement is started
     -   3: turn right at switch and move to the next cell; if the agent was not moving, movement is started
     -   4: stop moving

    Moving forward in a dead-end cell makes the agent turn 180 degrees and step
    to the cell it came from.


    The actions of the agents are executed in order of their handle to prevent
    deadlocks and to allow them to learn relative priorities.

    Reward Function:

    Each agent incurs a step_penalty for every time step taken in the environment, independent of whether
    the agent moves. Currently all other penalties, such as the penalties for stopping, starting and invalid
    actions, are set to 0.

    Reward function parameters:

    - alpha = 1
    - beta = 1
    - invalid_action_penalty = 0
    - step_penalty = -alpha
    - global_reward = beta
    - epsilon = small value used to avoid rounding errors
    - stop_penalty = 0  # penalty for stopping a moving agent
    - start_penalty = 0  # penalty for starting a stopped agent

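    For example, with the sparse end-of-episode reward (see _handle_end_reward), an agent that reaches
    its target 3 steps after its latest_arrival receives min(latest_arrival - arrival_time, 0) = -3.
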
    Stochastic malfunctioning of trains:
    Trains in RailEnv can malfunction if they are halted too often (either by their own choice or because an
    invalid action or cell is selected).

    Every time an agent stops, it has a certain probability of malfunctioning. Malfunctions of trains follow a
    Poisson process with a certain rate. Not all trains will be affected by malfunctions during episodes, to keep
    complexity manageable.

    TODO: currently, the parameters that control the stochasticity of the environment are hard-coded in init().
    For Round 2, they will be passed to the constructor as arguments, to allow for more flexibility.

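    Examples
    --------
    A minimal interaction loop (a sketch; it assumes the default sparse rail/line
    generators that are used when no generators are passed):

    >>> env = RailEnv(width=30, height=30, number_of_agents=2)
    >>> obs, info = env.reset(random_seed=42)
    >>> while not env.dones["__all__"]:
    ...     actions = {h: RailEnvActions.MOVE_FORWARD for h in env.get_agent_handles()}
    ...     obs, rewards, dones, info = env.step(actions)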
    """
    # Epsilon to avoid rounding errors
    epsilon = 0.01
    # NEW : REW: Sparse Reward
    alpha = 0
    beta = 0
    step_penalty = -1 * alpha
    global_reward = 1 * beta
    invalid_action_penalty = 0  # previously -2; GIACOMO: we decided that invalid actions will carry no penalty
    stop_penalty = 0  # penalty for stopping a moving agent
    start_penalty = 0  # penalty for starting a stopped agent
    cancellation_factor = 1
    cancellation_time_buffer = 0

    def __init__(self,
                 width,
                 height,
                 rail_generator=None,
                 line_generator=None,  # : line_gen.LineGenerator = line_gen.random_line_generator(),
                 number_of_agents=2,
                 obs_builder_object: ObservationBuilder = GlobalObsForRailEnv(),
                 malfunction_generator_and_process_data=None,  # mal_gen.no_malfunction_generator(),
                 malfunction_generator=None,
                 remove_agents_at_target=True,
                 random_seed=None,
                 record_steps=False,
                 ):
        """
        Environment init.

        Parameters
        ----------
        rail_generator : function
            The rail_generator function is a function that takes the width,
            height and agent handles of a rail environment, along with the number of times
            the env has been reset, and returns a GridTransitionMap object and a list of
            starting positions, targets, and initial orientations for each agent handle.
            The rail_generator can pass a distance map in the hints or information for specific line_generators.
            Implementations can be found in flatland/envs/rail_generators.py
        line_generator : function
            The line_generator function is a function that takes the grid, the number of agents and optional hints
            and returns a list of starting positions, targets, initial orientations and speed for all agent handles.
            Implementations can be found in flatland/envs/line_generators.py
        width : int
            The width of the rail map. Potentially in the future,
            a range of widths to sample from.
        height : int
            The height of the rail map. Potentially in the future,
            a range of heights to sample from.
        number_of_agents : int
            Number of agents to spawn on the map. Potentially in the future,
            a range of number of agents to sample from.
        obs_builder_object: ObservationBuilder object
            ObservationBuilder-derived object that builds observation
            vectors for each agent.
        remove_agents_at_target : bool
            If remove_agents_at_target is set to true then the agents will be removed by placing them at
            RailEnv.DEPOT_POSITION once they have reached their target position.
        random_seed : int or None
            If None, it is ignored; otherwise the random generators are seeded with this number to ensure
            that stochastic operations are reproducible across multiple runs.
        """
        super().__init__()

        if malfunction_generator_and_process_data is not None:
            print("DEPRECATED - RailEnv arg: malfunction_and_process_data - use malfunction_generator")
            self.malfunction_generator, self.malfunction_process_data = malfunction_generator_and_process_data
        elif malfunction_generator is not None:
            self.malfunction_generator = malfunction_generator
            # malfunction_process_data is not used
            # self.malfunction_generator, self.malfunction_process_data = malfunction_generator_and_process_data
            self.malfunction_process_data = self.malfunction_generator.get_process_data()
        # replace default values here because we can't use default args values because of cyclic imports
        else:
            self.malfunction_generator = mal_gen.NoMalfunctionGen()
            self.malfunction_process_data = self.malfunction_generator.get_process_data()
        
        self.number_of_agents = number_of_agents

        if rail_generator is None:
            rail_generator = rail_gen.sparse_rail_generator()
        self.rail_generator = rail_generator
        if line_generator is None:
            line_generator = line_gen.sparse_line_generator()
        self.line_generator = line_generator

        self.rail: Optional[GridTransitionMap] = None
        self.width = width
        self.height = height

        self.remove_agents_at_target = remove_agents_at_target

        self.obs_builder = obs_builder_object
        self.obs_builder.set_env(self)

        self._max_episode_steps: Optional[int] = None
        self._elapsed_steps = 0

        self.obs_dict = {}
        self.rewards_dict = {}
        self.dev_obs_dict = {}
        self.dev_pred_dict = {}

        self.agents: List[EnvAgent] = []
        self.num_resets = 0
        self.distance_map = DistanceMap(self.agents, self.height, self.width)

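        # 5 discrete actions: DO_NOTHING, MOVE_LEFT, MOVE_FORWARD, MOVE_RIGHT, STOP_MOVING (see RailEnvActions)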
        self.action_space = [5]

        self._seed()
        if random_seed:
            self._seed(seed=random_seed)

        self.agent_positions = None

        # save episode timesteps, i.e. agent positions and orientations (not yet actions / observations)
        self.record_steps = record_steps  # whether to save timesteps
        # save timesteps in here: [[[row, col, dir, malfunction],...nAgents], ...nSteps]
        self.cur_episode = []
        self.list_actions = []  # save actions in here

        self.motionCheck = ac.MotionCheck()

    def _seed(self, seed=None):
        self.np_random, seed = seeding.np_random(seed)
        random.seed(seed)
        self.random_seed = seed

        # Keep track of all the seeds in order
        if not hasattr(self, 'seed_history'):
            self.seed_history = [seed]
        if self.seed_history[-1] != seed:
            self.seed_history.append(seed)

        return [seed]

    # no more agent_handles
    def get_agent_handles(self):
        return range(self.get_num_agents())
    
    def get_num_agents(self) -> int:
        return len(self.agents)

    def add_agent(self, agent):
        """ Add static info for a single agent.
            Returns the index of the new agent.
        """
        self.agents.append(agent)
        return len(self.agents) - 1

    def reset_agents(self):
        """ Reset the agents to their starting positions
        """
        for agent in self.agents:
            agent.reset()
        self.active_agents = [i for i in range(len(self.agents))]

    def action_required(self, agent):
        """
        Check if an agent needs to provide an action

        Parameters
        ----------
        agent: RailEnvAgent
        Agent we want to check

        Returns
        -------
        True: Agent needs to provide an action
        False: An action from the agent would be ignored this step
        """
        return agent.state == TrainState.READY_TO_DEPART or \
               (agent.state.is_on_map_state() and agent.speed_counter.is_cell_entry)

    def reset(self, regenerate_rail: bool = True, regenerate_schedule: bool = True, *,
              random_seed: Optional[int] = None) -> Tuple[Dict, Dict]:
        """
        reset(regenerate_rail, regenerate_schedule, *, random_seed)

        The method resets the rail environment

        Parameters
        ----------
        regenerate_rail : bool, optional
            regenerate the rails
        regenerate_schedule : bool, optional
            regenerate the schedule and the static agents
        random_seed : int, optional
            random seed for environment

        Returns
        -------
        observation_dict: Dict
            Dictionary with an observation for each agent
        info_dict: Dict
            Dictionary with agent-specific information

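        Examples
        --------
        A seeded, fully regenerating reset (a sketch):

        >>> obs, info = env.reset(regenerate_rail=True, regenerate_schedule=True, random_seed=1)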
        """

        if random_seed:
            self._seed(random_seed)

        optionals = {}
        if regenerate_rail or self.rail is None:

            if "__call__" in dir(self.rail_generator):
                rail, optionals = self.rail_generator(
                    self.width, self.height, self.number_of_agents, self.num_resets, self.np_random)
            elif "generate" in dir(self.rail_generator):
                rail, optionals = self.rail_generator.generate(
                    self.width, self.height, self.number_of_agents, self.num_resets, self.np_random)
            else:
                raise ValueError("Could not invoke __call__ or generate on rail_generator")

            self.rail = rail
            self.height, self.width = self.rail.grid.shape

            # Do a new set_env call on the obs_builder to ensure
            # that obs_builder specific instantiations are made according to the
            # specifications of the current environment (width, height, etc.)
            self.obs_builder.set_env(self)

        if optionals and 'distance_map' in optionals:
            self.distance_map.set(optionals['distance_map'])

        if regenerate_schedule or regenerate_rail or self.get_num_agents() == 0:
            agents_hints = None
            if optionals and 'agents_hints' in optionals:
                agents_hints = optionals['agents_hints']

            line = self.line_generator(self.rail, self.number_of_agents, agents_hints, 
                                               self.num_resets, self.np_random)
            self.agents = EnvAgent.from_line(line)

            # Reset distance map - basically initializing
            self.distance_map.reset(self.agents, self.rail)

            # NEW : Time Schedule Generation
            timetable = timetable_generator(self.agents, self.distance_map, 
                                               agents_hints, self.np_random)

            self._max_episode_steps = timetable.max_episode_steps

            for agent_i, agent in enumerate(self.agents):
                agent.earliest_departure = timetable.earliest_departures[agent_i]         
                agent.latest_arrival = timetable.latest_arrivals[agent_i]
        else:
            self.distance_map.reset(self.agents, self.rail)
        
        # Reset agents to initial states
        self.reset_agents()

        self.num_resets += 1
        self._elapsed_steps = 0

        # Agent positions map
        self.agent_positions = np.zeros((self.height, self.width), dtype=int) - 1
        self._update_agent_positions_map(ignore_old_positions=False)

        self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)

        # Reset the state of the observation builder with the new environment
        self.obs_builder.reset()

        # Empty the episode store of agent positions
        self.cur_episode = []

        info_dict = self.get_info_dict()
        # Return the new observation vectors for each agent
        observation_dict: Dict = self._get_observations()
        if hasattr(self, "renderer") and self.renderer is not None:
            self.renderer = None
        return observation_dict, info_dict


    def _update_agent_positions_map(self, ignore_old_positions=True):
        """ Update the agent_positions array for agents that changed positions """
        for agent in self.agents:
            if not ignore_old_positions or agent.old_position != agent.position:
                if agent.position is not None:
                    self.agent_positions[agent.position] = agent.handle
                if agent.old_position is not None:
                    self.agent_positions[agent.old_position] = -1
    
    def generate_state_transition_signals(self, agent, preprocessed_action, movement_allowed):
        """ Generate State Transitions Signals used in the state machine """
        st_signals = StateTransitionSignals()
        
        # Malfunction starts when in_malfunction is set to true
        st_signals.in_malfunction = agent.malfunction_handler.in_malfunction

        # Malfunction counter complete - Malfunction ends next timestep
        st_signals.malfunction_counter_complete = agent.malfunction_handler.malfunction_counter_complete

        # Earliest departure reached - Train is allowed to move now
        st_signals.earliest_departure_reached = self._elapsed_steps >= agent.earliest_departure

        # Stop Action Given
        st_signals.stop_action_given = (preprocessed_action == RailEnvActions.STOP_MOVING)

        # Valid Movement action Given
        st_signals.valid_movement_action_given = preprocessed_action.is_moving_action() and movement_allowed

        # Target Reached
        st_signals.target_reached = env_utils.fast_position_equal(agent.position, agent.target)

        # Movement conflict - Multiple trains trying to move into same cell
        # Only flag a conflict if the agent was at a cell exit (ready to move) but its movement was blocked
        st_signals.movement_conflict = (not movement_allowed) and agent.speed_counter.is_cell_exit

        return st_signals

    def _handle_end_reward(self, agent: EnvAgent) -> int:
        '''
        Handles end-of-episode reward for a particular agent.

        Parameters
        ----------
        agent : EnvAgent
        '''
        reward = None
        # agent done? (arrival_time is not None)
        if agent.state == TrainState.DONE:
            # if agent arrived earlier or on time = 0
            # if agent arrived later = -ve reward based on how late
            reward = min(agent.latest_arrival - agent.arrival_time, 0)

        # Agents not done (arrival_time is None)
        else:
            # CANCELLED check (never departed)
            if (agent.state.is_off_map_state()):
                reward = -1 * self.cancellation_factor * \
                    (agent.get_travel_time_on_shortest_path(self.distance_map) + self.cancellation_time_buffer)

            # Departed but never reached
            if (agent.state.is_on_map_state()):
                reward = agent.get_current_delay(self._elapsed_steps, self.distance_map)
        
        return reward

    def preprocess_action(self, action, agent):
        """
        Preprocess the provided action
            * Change to DO_NOTHING if the action is illegal
            * Block all actions when in waiting state
            * Check that MOVE_LEFT/MOVE_RIGHT are valid on the current position, else fall back to MOVE_FORWARD
        """
        action = action_preprocessing.preprocess_raw_action(action, agent.state, agent.action_saver.saved_action)
        action = action_preprocessing.preprocess_action_when_waiting(action, agent.state)

        # Try moving actions on current position
        current_position, current_direction = agent.position, agent.direction
        if current_position is None: # Agent not added on map yet
            current_position, current_direction = agent.initial_position, agent.initial_direction
        
        action = action_preprocessing.preprocess_moving_action(action, self.rail, current_position, current_direction)
        return action
    
    def clear_rewards_dict(self):
        """ Reset the rewards dictionary """
        self.rewards_dict = {i_agent: 0 for i_agent in range(len(self.agents))}

    def get_info_dict(self):
        """ 
        Returns dictionary of infos for all agents 
        dict_keys : action_required - 
                    malfunction - Counter value for malfunction > 0 means train is in malfunction
                    speed - Speed of the train
                    state - State from the trains's state machine
        """
        info_dict = {
            'action_required': {i: self.action_required(agent) for i, agent in enumerate(self.agents)},
            'malfunction': {
                i: agent.malfunction_handler.malfunction_down_counter for i, agent in enumerate(self.agents)
            },
            'speed': {i: agent.speed_counter.speed for i, agent in enumerate(self.agents)},
            'state': {i: agent.state for i, agent in enumerate(self.agents)}
        }
        return info_dict
    
    def update_step_rewards(self, i_agent):
        """
        Update the rewards dict for agent id i_agent for every timestep
        """
        pass

    def end_of_episode_update(self, have_all_agents_ended):
        """ 
        Updates made when episode ends
        Parameters: have_all_agents_ended - Indicates if all agents have reached done state
        """
        if have_all_agents_ended or \
           ( (self._max_episode_steps is not None) and (self._elapsed_steps >= self._max_episode_steps)):

            for i_agent, agent in enumerate(self.agents):
                
                reward = self._handle_end_reward(agent)
                self.rewards_dict[i_agent] += reward
                
                self.dones[i_agent] = True

            self.dones["__all__"] = True

    def handle_done_state(self, agent):
        """ Any updates to agent to be made in Done state """
        if agent.state == TrainState.DONE:
            agent.arrival_time = self._elapsed_steps
            if self.remove_agents_at_target:
                agent.position = None

    def step(self, action_dict_: Dict[int, RailEnvActions]):
        """
        Updates rewards for the agents at a step.
        """
        self._elapsed_steps += 1

        # Not allowed to step further once done
        if self.dones["__all__"]:
            raise Exception("Episode is done, cannot call step()")

        self.clear_rewards_dict()

        have_all_agents_ended = True # Boolean flag to check if all agents are done

        self.motionCheck = ac.MotionCheck()  # reset the motion check

        temp_transition_data = {}
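
        # Phase 1: compute each agent's intended transition independently and register it
        # with motionCheck, so that conflicting moves can be resolved before being applied.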
        
        for agent in self.agents:
            i_agent = agent.handle
            agent.old_position = agent.position
            agent.old_direction = agent.direction
            # Generate malfunction
            agent.malfunction_handler.generate_malfunction(self.malfunction_generator, self.np_random)

            # Get action for the agent
            action = action_dict_.get(i_agent, RailEnvActions.DO_NOTHING)

            preprocessed_action = self.preprocess_action(action, agent)

            # Save moving actions if not already saved
            agent.action_saver.save_action_if_allowed(preprocessed_action, agent.state)

            # The train's position can be updated if it is at a cell exit, or if it is currently stopped (e.g. mid-cell at fractional speed)
            position_update_allowed = (agent.speed_counter.is_cell_exit or agent.state == TrainState.STOPPED)

            # Calculate new position
            # Add agent to the map if not on it yet
            if agent.position is None and agent.action_saver.is_action_saved:
                new_position = agent.initial_position
                new_direction = agent.initial_direction
                
            # If movement is allowed apply saved action independent of other agents
            elif agent.action_saver.is_action_saved and position_update_allowed:
                saved_action = agent.action_saver.saved_action
                # Apply action independent of other agents and get temporary new position and direction
                new_position, new_direction  = env_utils.apply_action_independent(saved_action, 
                                                                             self.rail, 
                                                                             agent.position, 
                                                                             agent.direction)
                preprocessed_action = saved_action
            else:
                new_position, new_direction = agent.position, agent.direction

            temp_transition_data[i_agent] = env_utils.AgentTransitionData(position=new_position,
                                                                direction=new_direction,
                                                                preprocessed_action=preprocessed_action)
            
            # This is for storing and later checking for conflicts of agents trying to occupy same cell                                                    
            self.motionCheck.addAgent(i_agent, agent.position, new_position)

        # Find conflicts between trains trying to occupy same cell
        self.motionCheck.find_conflicts()
        
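        # Phase 2: apply the conflict-resolved movements, step each agent's state machine,
        # and update rewards and counters.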
        for agent in self.agents:
            i_agent = agent.handle
            agent_transition_data = temp_transition_data[i_agent]

            ## Update positions
            if agent.malfunction_handler.in_malfunction:
                movement_allowed = False
            else:
                movement_allowed = self.motionCheck.check_motion(i_agent, agent.position)

            # Position can be changed only if other cell is empty
            # And either the speed counter completes or agent is being added to map
            if movement_allowed and \
               (agent.speed_counter.is_cell_exit or agent.position is None):
                agent.position = agent_transition_data.position
                agent.direction = agent_transition_data.direction

            preprocessed_action = agent_transition_data.preprocessed_action

            ## Update states
            state_transition_signals = self.generate_state_transition_signals(agent, preprocessed_action, movement_allowed)
            agent.state_machine.set_transition_signals(state_transition_signals)
            agent.state_machine.step()

            # Off map or on map state and position should match
            env_utils.state_position_sync_check(agent.state, agent.position, agent.handle)

            # Handle done state actions, optionally remove agents
            self.handle_done_state(agent)
            
            have_all_agents_ended &= (agent.state == TrainState.DONE)

            ## Update rewards
            self.update_step_rewards(i_agent)

            ## Update counters (malfunction and speed)
            agent.speed_counter.update_counter(agent.state, agent.old_position)
                                            #    agent.state_machine.previous_state)
            agent.malfunction_handler.update_counter()

            # Clear old action when starting in new cell
            if agent.speed_counter.is_cell_entry and agent.position is not None:
                agent.action_saver.clear_saved_action()
        
        # Check if episode has ended and update rewards and dones
        self.end_of_episode_update(have_all_agents_ended)

        self._update_agent_positions_map()

        return self._get_observations(), self.rewards_dict, self.dones, self.get_info_dict() 

    def record_timestep(self, dActions):
        """ 
        Record the positions and orientations of all agents in memory, in the cur_episode
        """
        list_agents_state = []
        for i_agent in range(self.get_num_agents()):
            agent = self.agents[i_agent]
            # the int cast is to avoid numpy types which may cause problems with msgpack
            # in env v2, agents may have position None, before starting
            if agent.position is None:
                pos = (0, 0)
            else:
                pos = (int(agent.position[0]), int(agent.position[1]))
            # print("pos:", pos, type(pos[0]))
            list_agents_state.append([
                    *pos, int(agent.direction), 
                    agent.malfunction_handler.malfunction_down_counter,  
                    int(agent.state),
                    int(agent.position in self.motionCheck.svDeadlocked)
                    ])

        self.cur_episode.append(list_agents_state)
        self.list_actions.append(dActions)

    def _get_observations(self):
        """
        Utility which returns the dictionary of observations for all agents with respect to the environment
        """
        # print(f"_get_obs - num agents: {self.get_num_agents()} {list(range(self.get_num_agents()))}")
        self.obs_dict = self.obs_builder.get_many(list(range(self.get_num_agents())))
        return self.obs_dict

    def get_valid_directions_on_grid(self, row: int, col: int) -> List[int]:
        """
        Returns directions in which the agent can move
        """
        return Grid4Transitions.get_entry_directions(self.rail.get_full_transitions(row, col))

    def _exp_distirbution_synced(self, rate: float) -> float:
        """
        Generates sample from exponential distribution
        We need this to guarantee synchronity between different instances with same seed.
        :param rate:
        :return:
        """
        u = self.np_random.rand()
        x = - np.log(1 - u) * rate
        return x

    def _is_agent_ok(self, agent: EnvAgent) -> bool:
        """
        Check if an agent is ok, meaning it can move and is not malfunctioning
        Parameters
        ----------
        agent

        Returns
        -------
        True if agent is ok, False otherwise

        """
        # An agent is ok when it is not currently malfunctioning (True = ok, per the docstring)
        return not agent.malfunction_handler.in_malfunction
        

    def save(self, filename):
        print("DEPRECATED call to env.save() - pls call RailEnvPersister.save()")
        persistence.RailEnvPersister.save(self, filename)

    def render(self, mode="rgb_array", gl="PGL", agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
            show_debug=False, clear_debug_text=True, show=False,
            screen_height=600, screen_width=800,
            show_observations=False, show_predictions=False,
            show_rowcols=False, return_image=True):
        """
        This method provides the option to render the
        environment's behavior as an image or to a window.
        Parameters
        ----------
        mode

        Returns
        -------
        Image if mode is rgb_array, opens a window otherwise
        """
        if not hasattr(self, "renderer") or self.renderer is None:
            self.initialize_renderer(mode=mode, gl=gl,  # gl="TKPILSVG",
                                    agent_render_variant=agent_render_variant,
                                    show_debug=show_debug,
                                    clear_debug_text=clear_debug_text,
                                    show=show,
                                    screen_height=screen_height,  # Adjust these parameters to fit your resolution
                                    screen_width=screen_width)
        return self.update_renderer(mode=mode, show=show, show_observations=show_observations,
                                    show_predictions=show_predictions,
                                    show_rowcols=show_rowcols, return_image=return_image)

    def initialize_renderer(self, mode, gl,
                agent_render_variant,
                show_debug,
                clear_debug_text,
                show,
                screen_height,
                screen_width):
        # Initiate the renderer
        self.renderer = RenderTool(self, gl=gl,  # gl="TKPILSVG",
                                agent_render_variant=agent_render_variant,
                                show_debug=show_debug,
                                clear_debug_text=clear_debug_text,
                                screen_height=screen_height,  # Adjust these parameters to fit your resolution
                                screen_width=screen_width)  # Adjust these parameters to fit your resolution
        self.renderer.show = show
        self.renderer.reset()

    def update_renderer(self, mode, show, show_observations, show_predictions,
                    show_rowcols, return_image):
        """
        This method updates the renderer.
        Parameters
        ----------
        mode

        Returns
        -------
        Image if mode is rgb_array, None otherwise
        """
        image = self.renderer.render_env(show=show, show_observations=show_observations,
                                show_predictions=show_predictions,
                                show_rowcols=show_rowcols, return_image=return_image)
        if mode == 'rgb_array':
            return image[:, :, :3]

    def close(self):
        """
        This method closes any renderer window.
        """
        if hasattr(self, "renderer") and self.renderer is not None:
            try:
                if self.renderer.show:
                    self.renderer.close_window()
            except Exception as e:
                print("Could Not close window due to:",e)
            self.renderer = None