"""
Definition of the RailEnv environment.
"""
import random

from typing import List, Optional, Dict, Tuple

import numpy as np
from gym.utils import seeding

from flatland.utils.rendertools import RenderTool, AgentRenderVariant
from flatland.core.env import Environment
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.grid.grid4 import Grid4Transitions
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.distance_map import DistanceMap
from flatland.envs.rail_env_action import RailEnvActions

from flatland.envs import malfunction_generators as mal_gen
from flatland.envs import rail_generators as rail_gen
from flatland.envs import line_generators as line_gen
from flatland.envs.timetable_generators import timetable_generator
from flatland.envs import persistence
from flatland.envs import agent_chains as ac

from flatland.envs.observations import GlobalObsForRailEnv

from flatland.envs.step_utils.states import TrainState, StateTransitionSignals
from flatland.envs.step_utils.transition_utils import check_valid_action
from flatland.envs.step_utils import action_preprocessing
from flatland.envs.step_utils import env_utils


class RailEnv(Environment):
    """
    RailEnv environment class.

    RailEnv is an environment inspired by a (simplified version of) a rail
    network, in which agents (trains) have to navigate to their target
    locations in the shortest time possible, while at the same time cooperating
    to avoid bottlenecks.

    The valid actions in the environment are:

     -   0: do nothing (continue moving or stay still)
     -   1: turn left at switch and move to the next cell; if the agent was not moving, movement is started
     -   2: move to the next cell in front of the agent; if the agent was not moving, movement is started
     -   3: turn right at switch and move to the next cell; if the agent was not moving, movement is started
     -   4: stop moving

    Moving forward in a dead-end cell makes the agent turn 180 degrees and step
    to the cell it came from.

    The actions of the agents are executed in order of their handle to prevent
    deadlocks and to allow them to learn relative priorities.
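
    A minimal interaction sketch (hedged; relies on the generator defaults set up
    in __init__ below, and the argument values are illustrative only):

        env = RailEnv(width=30, height=30, number_of_agents=2)
        obs, info = env.reset()
        actions = {handle: RailEnvActions.MOVE_FORWARD for handle in env.get_agent_handles()}
        obs, rewards, dones, info = env.step(actions)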

59
60
61
62
63
64
65
66
67
68
69
70
    Reward Function:

    It costs each agent a step_penalty for every time-step taken in the environment. Independent of the movement
    of the agent. Currently all other penalties such as penalty for stopping, starting and invalid actions are set to 0.

    alpha = 1
    beta = 1
    Reward function parameters:

    - invalid_action_penalty = 0
    - step_penalty = -alpha
    - global_reward = beta
71
    - epsilon = avoid rounding errors
72
73
74
    - stop_penalty = 0  # penalty for stopping a moving agent
    - start_penalty = 0  # penalty for starting a stopped agent

75
76
    Stochastic malfunctioning of trains:
    Trains in RailEnv can malfunction if they are halted too often (either by their own choice or because an invalid
spiglerg's avatar
spiglerg committed
77
78
    action or cell is selected.

79
80
81
    Every time an agent stops, an agent has a certain probability of malfunctioning. Malfunctions of trains follow a
    poisson process with a certain rate. Not all trains will be affected by malfunctions during episodes to keep
    complexity managable.
spiglerg's avatar
spiglerg committed
82
83
84
85

    TODO: currently, the parameters that control the stochasticity of the environment are hard-coded in init().
    For Round 2, they will be passed to the constructor as arguments, to allow for more flexibility.

86
    """
    # Epsilon to avoid rounding errors
    epsilon = 0.01

    # NEW : REW: Sparse Reward
    alpha = 0
    beta = 0
    step_penalty = -1 * alpha
    global_reward = 1 * beta
    invalid_action_penalty = 0  # previously -2; GIACOMO: we decided that invalid actions will carry no penalty
    stop_penalty = 0  # penalty for stopping a moving agent
    start_penalty = 0  # penalty for starting a stopped agent

    cancellation_factor = 1
    cancellation_time_buffer = 0

    def __init__(self,
                 width,
                 height,
                 rail_generator=None,
                 line_generator=None,  # : line_gen.LineGenerator = line_gen.random_line_generator(),
                 number_of_agents=2,
                 obs_builder_object: ObservationBuilder = GlobalObsForRailEnv(),
                 malfunction_generator_and_process_data=None,  # mal_gen.no_malfunction_generator(),
                 malfunction_generator=None,
                 remove_agents_at_target=True,
                 random_seed=None,
                 record_steps=False,
                 ):
        """
        Environment init.

        Parameters
        ----------
        rail_generator : function
            The rail_generator function is a function that takes the width,
            height and agent handles of a rail environment, along with the number of times
            the env has been reset, and returns a GridTransitionMap object and a list of
            starting positions, targets, and initial orientations for each agent handle.
            The rail_generator can pass a distance map in the hints or information for specific line_generators.
            Implementations can be found in flatland/envs/rail_generators.py
        line_generator : function
            The line_generator function is a function that takes the grid, the number of agents and optional hints
            and returns a list of starting positions, targets, initial orientations and speeds for all agent handles.
            Implementations can be found in flatland/envs/line_generators.py
        width : int
            The width of the rail map. Potentially in the future,
            a range of widths to sample from.
        height : int
            The height of the rail map. Potentially in the future,
            a range of heights to sample from.
        number_of_agents : int
            Number of agents to spawn on the map. Potentially in the future,
            a range of number of agents to sample from.
        obs_builder_object: ObservationBuilder object
            ObservationBuilder-derived object that builds observation
            vectors for each agent.
        remove_agents_at_target : bool
            If remove_agents_at_target is set to true then the agents will be removed by placing them at
            RailEnv.DEPOT_POSITION when they have reached their target position.
        random_seed : int or None
            If None, it is ignored; otherwise the random generators are seeded with this number to ensure
            that stochastic operations are reproducible across multiple runs.
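
        A hedged construction sketch (values are illustrative; the generators shown
        are the same defaults this constructor falls back to):

            env = RailEnv(
                width=30,
                height=30,
                rail_generator=rail_gen.sparse_rail_generator(),
                line_generator=line_gen.sparse_line_generator(),
                number_of_agents=4,
                malfunction_generator=mal_gen.NoMalfunctionGen(),
            )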
        """
        super().__init__()

        if malfunction_generator_and_process_data is not None:
            print("DEPRECATED - RailEnv arg: malfunction_and_process_data - use malfunction_generator")
            self.malfunction_generator, self.malfunction_process_data = malfunction_generator_and_process_data
        elif malfunction_generator is not None:
            self.malfunction_generator = malfunction_generator
            # malfunction_process_data is not used
            # self.malfunction_generator, self.malfunction_process_data = malfunction_generator_and_process_data
            self.malfunction_process_data = self.malfunction_generator.get_process_data()
        # replace default values here because we can't use default args values because of cyclic imports
        else:
            self.malfunction_generator = mal_gen.NoMalfunctionGen()
            self.malfunction_process_data = self.malfunction_generator.get_process_data()

        self.number_of_agents = number_of_agents

        if rail_generator is None:
            rail_generator = rail_gen.sparse_rail_generator()
        self.rail_generator = rail_generator
        if line_generator is None:
            line_generator = line_gen.sparse_line_generator()
        self.line_generator = line_generator

        self.rail: Optional[GridTransitionMap] = None
        self.width = width
        self.height = height

        self.remove_agents_at_target = remove_agents_at_target

        self.obs_builder = obs_builder_object
        self.obs_builder.set_env(self)

        self._max_episode_steps: Optional[int] = None
        self._elapsed_steps = 0

        self.obs_dict = {}
        self.rewards_dict = {}
        self.dev_obs_dict = {}
        self.dev_pred_dict = {}

        self.agents: List[EnvAgent] = []
        self.num_resets = 0
        self.distance_map = DistanceMap(self.agents, self.height, self.width)

        self.action_space = [5]

        self._seed()
        if random_seed:
            self._seed(seed=random_seed)

        self.agent_positions = None

        # save episode timesteps, i.e. agent positions, orientations  (not yet actions / observations)
        self.record_steps = record_steps  # whether to save timesteps
        # save timesteps in here: [[[row, col, dir, malfunction],...nAgents], ...nSteps]
        self.cur_episode = []
        self.list_actions = []  # save actions in here

        self.motionCheck = ac.MotionCheck()

    def _seed(self, seed=None):
        self.np_random, seed = seeding.np_random(seed)
        random.seed(seed)
        self.random_seed = seed

        # Keep track of all the seeds in order
        if not hasattr(self, 'seed_history'):
            self.seed_history = [seed]
        if self.seed_history[-1] != seed:
            self.seed_history.append(seed)

        return [seed]

    # no more agent_handles
    def get_agent_handles(self):
        return range(self.get_num_agents())

    def get_num_agents(self) -> int:
        return len(self.agents)

    def add_agent(self, agent):
        """ Add static info for a single agent.
            Returns the index of the new agent.
        """
        self.agents.append(agent)
        return len(self.agents) - 1

    def reset_agents(self):
        """ Reset the agents to their starting positions
        """
        for agent in self.agents:
            agent.reset()
        self.active_agents = [i for i in range(len(self.agents))]

    def action_required(self, agent):
        """
        Check if an agent needs to provide an action

        Parameters
        ----------
        agent: RailEnvAgent
            Agent we want to check

        Returns
        -------
        True: Agent needs to provide an action
        False: Agent cannot provide an action
        """
        return agent.state == TrainState.READY_TO_DEPART or \
               (agent.state.is_on_map_state() and agent.speed_counter.is_cell_entry)

    def reset(self, regenerate_rail: bool = True, regenerate_schedule: bool = True, *,
              random_seed: int = None) -> Tuple[Dict, Dict]:
        """
        reset(regenerate_rail, regenerate_schedule, random_seed)

        The method resets the rail environment

        Parameters
        ----------
        regenerate_rail : bool, optional
            regenerate the rails
        regenerate_schedule : bool, optional
            regenerate the schedule and the static agents
        random_seed : int, optional
            random seed for environment

        Returns
        -------
        observation_dict: Dict
            Dictionary with an observation for each agent
        info_dict: Dict with agent specific information

        """

        if random_seed:
            self._seed(random_seed)

        optionals = {}
        if regenerate_rail or self.rail is None:

            if "__call__" in dir(self.rail_generator):
                rail, optionals = self.rail_generator(
                    self.width, self.height, self.number_of_agents, self.num_resets, self.np_random)
            elif "generate" in dir(self.rail_generator):
                rail, optionals = self.rail_generator.generate(
                    self.width, self.height, self.number_of_agents, self.num_resets, self.np_random)
            else:
                raise ValueError("Could not invoke __call__ or generate on rail_generator")

            self.rail = rail
            self.height, self.width = self.rail.grid.shape

            # Do a new set_env call on the obs_builder to ensure
            # that obs_builder specific instantiations are made according to the
            # specifications of the current environment : like width, height, etc
            self.obs_builder.set_env(self)

        if optionals and 'distance_map' in optionals:
            self.distance_map.set(optionals['distance_map'])

        if regenerate_schedule or regenerate_rail or self.get_num_agents() == 0:
            agents_hints = None
            if optionals and 'agents_hints' in optionals:
                agents_hints = optionals['agents_hints']

            line = self.line_generator(self.rail, self.number_of_agents, agents_hints,
                                       self.num_resets, self.np_random)
            self.agents = EnvAgent.from_line(line)

            # Reset distance map - basically initializing
            self.distance_map.reset(self.agents, self.rail)

            # NEW : Time Schedule Generation
            timetable = timetable_generator(self.agents, self.distance_map,
                                            agents_hints, self.np_random)

            self._max_episode_steps = timetable.max_episode_steps

            for agent_i, agent in enumerate(self.agents):
                agent.earliest_departure = timetable.earliest_departures[agent_i]
                agent.latest_arrival = timetable.latest_arrivals[agent_i]
        else:
            self.distance_map.reset(self.agents, self.rail)

        # Reset agents to initial states
        self.reset_agents()

        self.num_resets += 1
        self._elapsed_steps = 0

        # Agent positions map
        self.agent_positions = np.zeros((self.height, self.width), dtype=int) - 1
        self._update_agent_positions_map(ignore_old_positions=False)

        self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)

        # Reset the state of the observation builder with the new environment
        self.obs_builder.reset()

        # Empty the episode store of agent positions
        self.cur_episode = []

        info_dict = self.get_info_dict()
        # Return the new observation vectors for each agent
        observation_dict: Dict = self._get_observations()

        if hasattr(self, "renderer") and self.renderer is not None:
            self.renderer = None
        return observation_dict, info_dict

    def _update_agent_positions_map(self, ignore_old_positions=True):
        """ Update the agent_positions array for agents that changed positions """
        for agent in self.agents:
            if not ignore_old_positions or agent.old_position != agent.position:
                if agent.position is not None:
                    self.agent_positions[agent.position] = agent.handle
                if agent.old_position is not None:
                    self.agent_positions[agent.old_position] = -1

    def generate_state_transition_signals(self, agent, preprocessed_action, movement_allowed):
        """ Generate state transition signals used in the state machine """
        st_signals = StateTransitionSignals()

        # Malfunction starts when in_malfunction is set to true
        st_signals.in_malfunction = agent.malfunction_handler.in_malfunction

        # Malfunction counter complete - malfunction ends next timestep
        st_signals.malfunction_counter_complete = agent.malfunction_handler.malfunction_counter_complete

        # Earliest departure reached - train is allowed to move now
        st_signals.earliest_departure_reached = self._elapsed_steps >= agent.earliest_departure

        # Stop action given
        st_signals.stop_action_given = (preprocessed_action == RailEnvActions.STOP_MOVING)

        # Valid movement action given
        st_signals.valid_movement_action_given = preprocessed_action.is_moving_action() and movement_allowed

        # Target reached
        st_signals.target_reached = env_utils.fast_position_equal(agent.position, agent.target)

        # Movement conflict - multiple trains trying to move into the same cell
        # Only flag a conflict when the train is at a cell exit, i.e. it actually tried to move
        st_signals.movement_conflict = (not movement_allowed) and agent.speed_counter.is_cell_exit

        return st_signals

    def _handle_end_reward(self, agent: EnvAgent) -> int:
        '''
        Handles end-of-episode reward for a particular agent.

        Parameters
        ----------
        agent : EnvAgent
        '''
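        # Worked example (illustrative numbers only): with latest_arrival=40, an agent
        # arriving at step 43 gets min(40 - 43, 0) = -3, while an agent arriving at
        # step 38 gets min(40 - 38, 0) = 0, matching the DONE branch below.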
        reward = None
        # agent done? (arrival_time is not None)
        if agent.state == TrainState.DONE:
            # if agent arrived earlier or on time = 0
            # if agent arrived later = -ve reward based on how late
            reward = min(agent.latest_arrival - agent.arrival_time, 0)

        # Agents not done (arrival_time is None)
        else:
            # CANCELLED check (never departed)
            if (agent.state.is_off_map_state()):
                reward = -1 * self.cancellation_factor * \
                    (agent.get_travel_time_on_shortest_path(self.distance_map) + self.cancellation_time_buffer)

            # Departed but never reached
            if (agent.state.is_on_map_state()):
                reward = agent.get_current_delay(self._elapsed_steps, self.distance_map)

        return reward

    def preprocess_action(self, action, agent):
        """
        Preprocess the provided action
            * Change to DO_NOTHING if illegal action
            * Block all actions when in waiting state
            * Check MOVE_LEFT/MOVE_RIGHT actions on current position else try MOVE_FORWARD
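
        Illustrative sketch (hedged; assumes an env and agent already exist):

            agent = env.agents[0]
            safe_action = env.preprocess_action(RailEnvActions.MOVE_LEFT, agent)
            # MOVE_LEFT survives only if a left transition exists at the agent's cell;
            # otherwise it falls back to MOVE_FORWARD, and to STOP_MOVING if even that
            # is not a valid transition.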
        """
        action = action_preprocessing.preprocess_raw_action(action, agent.state, agent.action_saver.saved_action)
        action = action_preprocessing.preprocess_action_when_waiting(action, agent.state)

        # Try moving actions on current position
        current_position, current_direction = agent.position, agent.direction
        if current_position is None:  # Agent not added on map yet
            current_position, current_direction = agent.initial_position, agent.initial_direction

        action = action_preprocessing.preprocess_moving_action(action, self.rail, current_position, current_direction)

        # Check transitions and bounds for executing the action in the given position and direction
        if not check_valid_action(action, self.rail, current_position, current_direction):
            action = RailEnvActions.STOP_MOVING

        return action

    
    def clear_rewards_dict(self):
        """ Reset the rewards dictionary """
        self.rewards_dict = {i_agent: 0 for i_agent in range(len(self.agents))}

    def get_info_dict(self):
        """
        Returns a dictionary of infos for all agents
        dict_keys : action_required - whether the agent needs to provide an action this step
                    malfunction - counter value; > 0 means the train is malfunctioning
                    speed - speed of the train
                    state - state from the train's state machine
        """
        info_dict = {
            'action_required': {i: self.action_required(agent) for i, agent in enumerate(self.agents)},
            'malfunction': {
                i: agent.malfunction_handler.malfunction_down_counter for i, agent in enumerate(self.agents)
            },
            'speed': {i: agent.speed_counter.speed for i, agent in enumerate(self.agents)},
            'state': {i: agent.state for i, agent in enumerate(self.agents)}
        }
        return info_dict
    
    def update_step_rewards(self, i_agent):
        """
        Update the rewards dict for agent id i_agent for every timestep
        """
        pass

    def end_of_episode_update(self, have_all_agents_ended):
        """
        Updates made when the episode ends
        Parameters: have_all_agents_ended - Indicates if all agents have reached done state
        """
        if have_all_agents_ended or \
           ((self._max_episode_steps is not None) and (self._elapsed_steps >= self._max_episode_steps)):

            for i_agent, agent in enumerate(self.agents):
                reward = self._handle_end_reward(agent)
                self.rewards_dict[i_agent] += reward

                self.dones[i_agent] = True

            self.dones["__all__"] = True

    def handle_done_state(self, agent):
        """ Any updates to agent to be made in Done state """
        if agent.state == TrainState.DONE:
            agent.arrival_time = self._elapsed_steps
            if self.remove_agents_at_target:
                agent.position = None

    def step(self, action_dict_: Dict[int, RailEnvActions]):
        """
        Advances the environment by one time step: preprocesses the provided actions,
        resolves movement conflicts, and updates agent states, rewards and dones.
        """
        self._elapsed_steps += 1

        # Not allowed to step further once done
        if self.dones["__all__"]:
            raise Exception("Episode is done, cannot call step()")

        self.clear_rewards_dict()

        have_all_agents_ended = True  # Boolean flag to check if all agents are done

        self.motionCheck = ac.MotionCheck()  # reset the motion check

        temp_transition_data = {}
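
        # Two-phase update (design note): the first loop below computes every agent's
        # desired next position independently and registers it with motionCheck; only
        # after find_conflicts() does the second loop apply the movements that are
        # actually allowed.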
        
        for agent in self.agents:
            i_agent = agent.handle

            agent.old_position = agent.position
            agent.old_direction = agent.direction

            # Generate malfunction
            agent.malfunction_handler.generate_malfunction(self.malfunction_generator, self.np_random)

            # Get action for the agent
            action = action_dict_.get(i_agent, RailEnvActions.DO_NOTHING)

            preprocessed_action = self.preprocess_action(action, agent)

            # Save moving actions if not already saved
            agent.action_saver.save_action_if_allowed(preprocessed_action, agent.state)

            # Train's next position can change if it is currently stopped with a fractional speed or is at the cell's exit
            position_update_allowed = (agent.speed_counter.is_cell_exit or agent.state == TrainState.STOPPED)

            # Calculate new position
            # Add agent to the map if not on it yet
            if agent.position is None and agent.action_saver.is_action_saved:
                new_position = agent.initial_position
                new_direction = agent.initial_direction

            # If movement is allowed apply saved action independent of other agents
            elif agent.action_saver.is_action_saved and position_update_allowed:
                saved_action = agent.action_saver.saved_action
                # Apply action independent of other agents and get temporary new position and direction
                new_position, new_direction = env_utils.apply_action_independent(saved_action,
                                                                                 self.rail,
                                                                                 agent.position,
                                                                                 agent.direction)
                preprocessed_action = saved_action
            else:
                new_position, new_direction = agent.position, agent.direction

            temp_transition_data[i_agent] = env_utils.AgentTransitionData(position=new_position,
                                                                          direction=new_direction,
                                                                          preprocessed_action=preprocessed_action)

            # Store the desired transition for later checking of conflicts between agents trying to occupy the same cell
            self.motionCheck.addAgent(i_agent, agent.position, new_position)

        # Find conflicts between trains trying to occupy the same cell
        self.motionCheck.find_conflicts()

        for agent in self.agents:
            i_agent = agent.handle

            ## Update positions
            if agent.malfunction_handler.in_malfunction:
                movement_allowed = False
            else:
                movement_allowed = self.motionCheck.check_motion(i_agent, agent.position)

            # Fetch the saved transition data
            agent_transition_data = temp_transition_data[i_agent]
            preprocessed_action = agent_transition_data.preprocessed_action

            ## Update states
            state_transition_signals = self.generate_state_transition_signals(agent, preprocessed_action, movement_allowed)
            agent.state_machine.set_transition_signals(state_transition_signals)
            agent.state_machine.step()

            # Needed when not removing agents at target
            movement_allowed = movement_allowed and agent.state != TrainState.DONE

            # Agent is being added to map
            if agent.state.is_on_map_state():
                if agent.state_machine.previous_state.is_off_map_state():
                    agent.position = agent.initial_position
                    agent.direction = agent.initial_direction
                # Speed counter completes
                elif movement_allowed and (agent.speed_counter.is_cell_exit):
                    agent.position = agent_transition_data.position
                    agent.direction = agent_transition_data.direction
                    agent.state_machine.update_if_reached(agent.position, agent.target)

            # Off map or on map state and position should match
            env_utils.state_position_sync_check(agent.state, agent.position, agent.handle)

            # Handle done state actions, optionally remove agents
            self.handle_done_state(agent)

            have_all_agents_ended &= (agent.state == TrainState.DONE)

            ## Update rewards
            self.update_step_rewards(i_agent)

            ## Update counters (malfunction and speed)
            agent.speed_counter.update_counter(agent.state, agent.old_position)
            agent.malfunction_handler.update_counter()

            # Clear old action when starting in new cell
            if agent.speed_counter.is_cell_entry and agent.position is not None:
                agent.action_saver.clear_saved_action()

        # Check if episode has ended and update rewards and dones
        self.end_of_episode_update(have_all_agents_ended)

        self._update_agent_positions_map()

        return self._get_observations(), self.rewards_dict, self.dones, self.get_info_dict()

    def record_timestep(self, dActions):
        """
        Record the positions and orientations of all agents in memory, in the cur_episode
        """
        list_agents_state = []
        for i_agent in range(self.get_num_agents()):
            agent = self.agents[i_agent]
            # the int cast is to avoid numpy types which may cause problems with msgpack
            # in env v2, agents may have position None, before starting
            if agent.position is None:
                pos = (0, 0)
            else:
                pos = (int(agent.position[0]), int(agent.position[1]))
            list_agents_state.append([
                    *pos, int(agent.direction),
                    agent.malfunction_handler.malfunction_down_counter,
                    int(agent.state),
                    int(agent.position in self.motionCheck.svDeadlocked)
                    ])

        self.cur_episode.append(list_agents_state)
        self.list_actions.append(dActions)

    def _get_observations(self):
        """
        Utility which returns the dictionary of observations for each agent with respect to the environment
        """
        self.obs_dict = self.obs_builder.get_many(list(range(self.get_num_agents())))
        return self.obs_dict

    def get_valid_directions_on_grid(self, row: int, col: int) -> List[int]:
        """
        Returns directions in which the agent can move
        """
        return Grid4Transitions.get_entry_directions(self.rail.get_full_transitions(row, col))

    def _exp_distirbution_synced(self, rate: float) -> float:
        """
        Generates a sample from an exponential distribution.
        We need this to guarantee synchronization between different instances with the same seed.
        :param rate:
        :return:
        """
        u = self.np_random.rand()
        x = - np.log(1 - u) * rate
        return x

    def _is_agent_ok(self, agent: EnvAgent) -> bool:
        """
        Check if an agent is ok, meaning it can move and is not malfunctioning
        Parameters
        ----------
        agent

        Returns
        -------
        True if agent is ok, False otherwise

        """
        # Return True when the agent is not malfunctioning, matching the documented contract
        return not agent.malfunction_handler.in_malfunction


    def save(self, filename):
        print("DEPRECATED call to env.save() - please call RailEnvPersister.save()")
        persistence.RailEnvPersister.save(self, filename)

    def render(self, mode="rgb_array", gl="PGL", agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
            show_debug=False, clear_debug_text=True, show=False,
            screen_height=600, screen_width=800,
            show_observations=False, show_predictions=False,
            show_rowcols=False, return_image=True):
        """
        This methods provides the option to render the
        environment's behavior as an image or to a window.
        Parameters
        ----------
        mode

        Returns
        -------
        Image if mode is rgb_array, opens a window otherwise
        """
        if not hasattr(self, "renderer") or self.renderer is None:
            self.initialize_renderer(mode=mode, gl=gl,  # gl="TKPILSVG",
                                    agent_render_variant=agent_render_variant,
                                    show_debug=show_debug,
                                    clear_debug_text=clear_debug_text,
                                    show=show,
                                    screen_height=screen_height,  # Adjust these parameters to fit your resolution
                                    screen_width=screen_width)
        return self.update_renderer(mode=mode, show=show, show_observations=show_observations,
                                    show_predictions=show_predictions,
                                    show_rowcols=show_rowcols, return_image=return_image)

    def initialize_renderer(self, mode, gl,
                agent_render_variant,
                show_debug,
                clear_debug_text,
                show,
                screen_height,
                screen_width):
        # Initiate the renderer
        self.renderer = RenderTool(self, gl=gl,  # gl="TKPILSVG",
                                agent_render_variant=agent_render_variant,
                                show_debug=show_debug,
                                clear_debug_text=clear_debug_text,
                                screen_height=screen_height,  # Adjust these parameters to fit your resolution
                                screen_width=screen_width)  # Adjust these parameters to fit your resolution
        self.renderer.show = show
        self.renderer.reset()

    def update_renderer(self, mode, show, show_observations, show_predictions,
                    show_rowcols, return_image):
        """
        This method updates the renderer.
        Parameters
        ----------
        mode

        Returns
        -------
        Image if mode is rgb_array, None otherwise
        """
        image = self.renderer.render_env(show=show, show_observations=show_observations,
                                show_predictions=show_predictions,
                                show_rowcols=show_rowcols, return_image=return_image)
        if mode == 'rgb_array':
            return image[:, :, :3]

    def close(self):
        """
        This method closes any renderer window.
        """
        if hasattr(self, "renderer") and self.renderer is not None:
            try:
                if self.renderer.show:
                    self.renderer.close_window()
            except Exception as e:
                print("Could not close window due to:", e)
            self.renderer = None