diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index a56ea043bc3cb9883cfc183fe4bf34fdce79f1e3..22815f33b2ab0df490cddce489c572b905a5e555 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -7,13 +7,11 @@ from typing import List, Optional, Dict, Tuple
 import numpy as np
 from gym.utils import seeding
-from dataclasses import dataclass
 
 from flatland.utils.rendertools import RenderTool, AgentRenderVariant
 from flatland.core.env import Environment
 from flatland.core.env_observation_builder import ObservationBuilder
 from flatland.core.grid.grid4 import Grid4Transitions
-from flatland.core.grid.grid4_utils import get_new_position
 from flatland.core.transition_map import GridTransitionMap
 from flatland.envs.agent_utils import EnvAgent
 from flatland.envs.distance_map import DistanceMap
@@ -30,8 +28,8 @@ from flatland.envs.observations import GlobalObsForRailEnv
 from flatland.envs.timetable_generators import timetable_generator
 from flatland.envs.step_utils.states import TrainState, StateTransitionSignals
-from flatland.envs.step_utils import transition_utils
 from flatland.envs.step_utils import action_preprocessing
+from flatland.envs.step_utils import env_utils
 
 
 class RailEnv(Environment):
     """
@@ -110,7 +108,6 @@ class RailEnv(Environment):
                  remove_agents_at_target=True,
                  random_seed=1,
                  record_steps=False,
-                 close_following=True
                  ):
         """
         Environment init.
@@ -178,16 +175,12 @@
 
         self.remove_agents_at_target = remove_agents_at_target
 
-        self.rewards = [0] * number_of_agents
-        self.done = False
         self.obs_builder = obs_builder_object
         self.obs_builder.set_env(self)
 
         self._max_episode_steps: Optional[int] = None
         self._elapsed_steps = 0
 
-        self.dones = dict.fromkeys(list(range(number_of_agents)) + ["__all__"], False)
-
         self.obs_dict = {}
         self.rewards_dict = {}
         self.dev_obs_dict = {}
@@ -205,10 +198,7 @@
         if self.random_seed:
             self._seed(seed=random_seed)
 
-        self.valid_positions = None
-
-        # global numpy array of agents position, True means that there is an agent at that cell
-        self.agent_positions: np.ndarray = np.full((height, width), False)
+        self.agent_positions = None
 
         # save episode timesteps ie agent positions, orientations. (not yet actions / observations)
         self.record_steps = record_steps  # whether to save timesteps
@@ -216,11 +206,8 @@
         self.cur_episode = []
         self.list_actions = []  # save actions in here
 
-        self.close_following = close_following  # use close following logic
         self.motionCheck = ac.MotionCheck()
 
-        self.agent_helpers = {}
-
     def _seed(self, seed=None):
         self.np_random, seed = seeding.np_random(seed)
         random.seed(seed)
@@ -229,7 +216,7 @@
 
     # no more agent_handles
     def get_agent_handles(self):
         return range(self.get_num_agents())
-    
+
     def get_num_agents(self) -> int:
         return len(self.agents)
@@ -337,9 +324,6 @@
                 agent.latest_arrival = timetable.latest_arrivals[agent_i]
         else:
             self.distance_map.reset(self.agents, self.rail)
-
-        # Agent Positions Map
-        self.agent_positions = np.zeros((self.height, self.width), dtype=int) - 1
 
         # Reset agents to initial states
         self.reset_agents()
@@ -347,7 +331,10 @@
         self.num_resets += 1
         self._elapsed_steps = 0
 
-        # TODO perhaps dones should be part of each agent.
+        # Agent positions map
+        self.agent_positions = np.zeros((self.height, self.width), dtype=int) - 1
+        self._update_agent_positions_map(ignore_old_positions=False)
+
         self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)
 
         # Reset the state of the observation builder with the new environment
@@ -362,14 +349,17 @@
 
         if hasattr(self, "renderer") and self.renderer is not None:
             self.renderer = None
         return observation_dict, info_dict
-
-    def apply_action_independent(self, action, rail, position, direction):
-        if action.is_moving_action():
-            new_direction, _ = transition_utils.check_action(action, position, direction, rail)
-            new_position = get_new_position(position, new_direction)
-        else:
-            new_position, new_direction = position, direction
-        return new_position, new_direction
+
+
+    def _update_agent_positions_map(self, ignore_old_positions=True):
+        """ Update the agent_positions array for agents that changed positions """
+        for agent in self.agents:
+            if not ignore_old_positions or agent.old_position != agent.position:
+                if agent.position is not None:
+                    self.agent_positions[agent.position] = agent.handle
+                if agent.old_position is not None:
+                    self.agent_positions[agent.old_position] = -1
+
     def generate_state_transition_signals(self, agent, preprocessed_action, movement_allowed):
         """ Generate State Transitions Signals used in the state machine """
@@ -391,7 +381,7 @@
         st_signals.valid_movement_action_given = preprocessed_action.is_moving_action() and movement_allowed
 
         # Target Reached
-        st_signals.target_reached = fast_position_equal(agent.position, agent.target)
+        st_signals.target_reached = env_utils.fast_position_equal(agent.position, agent.target)
 
         # Movement conflict - Multiple trains trying to move into same cell
         # If speed counter is not in cell exit, the train can enter the cell
@@ -449,11 +439,18 @@
         """ Reset the rewards dictionary """
         self.rewards_dict = {i_agent: 0 for i_agent in range(len(self.agents))}
 
-    def get_info_dict(self): # TODO Important : Update this
+    def get_info_dict(self):
+        """
+        Returns a dictionary of info for all agents
+        dict_keys : action_required - True if the train needs a new action in this timestep
+                    malfunction - malfunction counter; a value > 0 means the train is malfunctioning
+                    speed - speed of the train
+                    state - state from the train's state machine
+        """
         info_dict = {
             'action_required': {i: self.action_required(agent) for i, agent in enumerate(self.agents)},
             'malfunction': {
-                i: agent.malfunction_data['malfunction'] for i, agent in enumerate(self.agents)
+                i: agent.malfunction_handler.malfunction_down_counter for i, agent in enumerate(self.agents)
             },
             'speed': {i: agent.speed_counter.speed for i, agent in enumerate(self.agents)},
             'state': {i: agent.state for i, agent in enumerate(self.agents)}
@@ -461,9 +458,16 @@
         }
         return info_dict
 
     def update_step_rewards(self, i_agent):
+        """
+        Update the rewards dict for the agent with handle i_agent at each timestep
+        """
         pass
 
     def end_of_episode_update(self, have_all_agents_ended):
+        """
+        Updates made when the episode ends
+        Parameters: have_all_agents_ended - Indicates if all agents have reached the done state
+        """
         if have_all_agents_ended or \
            ( (self._max_episode_steps is not None) and (self._elapsed_steps >= self._max_episode_steps)):
@@ -477,6 +481,7 @@
             self.dones["__all__"] = True
 
     def handle_done_state(self, agent):
+        """ Updates to be made to an agent in the DONE state """
         if agent.state == TrainState.DONE:
             agent.arrival_time = self._elapsed_steps
             if self.remove_agents_at_target:
@@ -528,7 +533,7 @@
             elif agent.action_saver.is_action_saved and position_update_allowed:
                 saved_action = agent.action_saver.saved_action
                 # Apply action independent of other agents and get temporary new position and direction
-                new_position, new_direction = self.apply_action_independent(saved_action,
+                new_position, new_direction = env_utils.apply_action_independent(saved_action,
                                                                             self.rail,
                                                                             agent.position,
                                                                             agent.direction)
@@ -536,7 +541,7 @@
             else:
                 new_position, new_direction = agent.position, agent.direction
 
-            temp_transition_data[i_agent] = AgentTransitionData(position=new_position,
+            temp_transition_data[i_agent] = env_utils.AgentTransitionData(position=new_position,
                                                                 direction=new_direction,
                                                                 preprocessed_action=preprocessed_action)
 
@@ -571,7 +576,7 @@
             agent.state_machine.step()
 
             # Off map or on map state and position should match
-            state_position_sync_check(agent.state, agent.position, agent.handle)
+            env_utils.state_position_sync_check(agent.state, agent.position, agent.handle)
 
             # Handle done state actions, optionally remove agents
             self.handle_done_state(agent)
@@ -593,11 +598,14 @@
         # Check if episode has ended and update rewards and dones
         self.end_of_episode_update(have_all_agents_ended)
 
+        self._update_agent_positions_map()
+
         return self._get_observations(), self.rewards_dict, self.dones, self.get_info_dict()
 
     def record_timestep(self, dActions):
-        ''' Record the positions and orientations of all agents in memory, in the cur_episode
-        '''
+        """
+        Record the positions and orientations of all agents in memory, in the cur_episode
+        """
         list_agents_state = []
         for i_agent in range(self.get_num_agents()):
             agent = self.agents[i_agent]
@@ -610,7 +618,7 @@
             # print("pos:", pos, type(pos[0]))
             list_agents_state.append([
                     *pos, int(agent.direction),
-                    agent.malfunction_data["malfunction"],
+                    agent.malfunction_handler.malfunction_down_counter,
                     int(agent.status),
                     int(agent.position in self.motionCheck.svDeadlocked)
                 ])
@@ -620,11 +628,7 @@
 
     def _get_observations(self):
         """
-        Utility which returns the observations for an agent with respect to environment
-
-        Returns
-        ------
-        Dict object
+        Utility which returns the dictionary of observations for all agents with respect to the environment
         """
         # print(f"_get_obs - num agents: {self.get_num_agents()} {list(range(self.get_num_agents()))}")
         self.obs_dict = self.obs_builder.get_many(list(range(self.get_num_agents())))
@@ -633,15 +637,6 @@
 
     def get_valid_directions_on_grid(self, row: int, col: int) -> List[int]:
         """
         Returns directions in which the agent can move
-
-        Parameters:
-        ---------
-        row : int
-        col : int
-
-        Returns:
-        -------
-        List[int]
         """
         return Grid4Transitions.get_entry_directions(self.rail.get_full_transitions(row, col))
@@ -669,9 +664,10 @@
 
         """
         return agent.malfunction_handler.in_malfunction
 
+
     def save(self, filename):
-        print("deprecated call to env.save() - pls call RailEnvPersister.save()")
+        print("DEPRECATED call to env.save() - please call RailEnvPersister.save()")
         persistence.RailEnvPersister.save(self, filename)
 
     def render(self, mode="rgb_array", gl="PGL", agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
@@ -747,31 +743,4 @@
                 self.renderer.close_window()
             except Exception as e:
                 print("Could Not close window due to:",e)
-            self.renderer = None
-
-
-@dataclass(repr=True)
-class AgentTransitionData:
-    """ Class for keeping track of temporary agent data for position update """
-    position : Tuple[int, int]
-    direction : Grid4Transitions
-    preprocessed_action : RailEnvActions
-
-
-# Adrian Egli performance fix (the fast methods brings more than 50%)
-def fast_isclose(a, b, rtol):
-    return (a < (b + rtol)) or (a < (b - rtol))
-
-def fast_position_equal(pos_1: (int, int), pos_2: (int, int)) -> bool:
-    if pos_1 is None: # TODO: Dipam - Consider making default of agent.position as (-1, -1) instead of None
-        return False
-    else:
-        return pos_1[0] == pos_2[0] and pos_1[1] == pos_2[1]
-
-def state_position_sync_check(state, position, i_agent):
-    if state.is_on_map_state() and position is None:
-        raise ValueError("Agent ID {} Agent State {} is on map Agent Position {} if off map ".format(
-            i_agent, str(state), str(position) ))
-    elif state.is_off_map_state() and position is not None:
-        raise ValueError("Agent ID {} Agent State {} is off map Agent Position {} if on map ".format(
-            i_agent, str(state), str(position) ))
+            self.renderer = None
\ No newline at end of file
diff --git a/flatland/envs/step_utils/env_utils.py b/flatland/envs/step_utils/env_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6eec2129059dd168dc4f9ad742c3ad7e7095e164
--- /dev/null
+++ b/flatland/envs/step_utils/env_utils.py
@@ -0,0 +1,52 @@
+from dataclasses import dataclass
+from typing import Tuple
+
+from flatland.core.grid.grid4_utils import get_new_position
+from flatland.envs.step_utils import transition_utils
+from flatland.envs.rail_env_action import RailEnvActions
+from flatland.core.grid.grid4 import Grid4Transitions
+
+@dataclass(repr=True)
+class AgentTransitionData:
+    """ Class for keeping track of temporary agent data for position update """
+    position : Tuple[int, int]
+    direction : Grid4Transitions
+    preprocessed_action : RailEnvActions
+
+# Adrian Egli performance fix (the fast methods bring more than a 50% speed-up)
+def fast_isclose(a, b, rtol):
+    return (a < (b + rtol)) or (a < (b - rtol))
+
+def fast_position_equal(pos_1: Tuple[int, int], pos_2: Tuple[int, int]) -> bool:
+    if pos_1 is None:
+        return False
+    else:
+        return pos_1[0] == pos_2[0] and pos_1[1] == pos_2[1]
+
+def apply_action_independent(action, rail, position, direction):
+    """ Apply the action on the train regardless of the locations of other trains
+        Checks for valid cells to move to and valid rail transitions
+        ---------------------------------------------------------------------
+        Parameters: action - Action to execute
+                    rail - Flatland env.rail object
+                    position - current position of the train
+                    direction - current direction of the train
+        ---------------------------------------------------------------------
+        Returns: new_position - New position after applying the action
+                 new_direction - New direction after applying the action
+    """
+    if action.is_moving_action():
+        new_direction, _ = transition_utils.check_action(action, position, direction, rail)
+        new_position = get_new_position(position, new_direction)
+    else:
+        new_position, new_direction = position, direction
+    return new_position, new_direction
+
+def state_position_sync_check(state, position, i_agent):
+    """ Check that an agent's on map / off map state matches its position """
+    if state.is_on_map_state() and position is None:
+        raise ValueError("Agent ID {} Agent State {} is on map but Agent Position {} is off map".format(
+            i_agent, str(state), str(position) ))
+    elif state.is_off_map_state() and position is not None:
+        raise ValueError("Agent ID {} Agent State {} is off map but Agent Position {} is on map".format(
+            i_agent, str(state), str(position) ))
\ No newline at end of file
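
The relocated env_utils.apply_action_independent helper can be probed on its own to ask where a train would end up if no other train were on the rails. A minimal sketch, assuming an already constructed RailEnv bound to a variable `env` whose first agent is currently placed on the map (the variable names are illustrative, not part of this patch):

    from flatland.envs.rail_env_action import RailEnvActions
    from flatland.envs.step_utils import env_utils

    agent = env.agents[0]  # assumption: this agent has a valid on-map position
    new_position, new_direction = env_utils.apply_action_independent(
        RailEnvActions.MOVE_FORWARD,  # action to probe
        env.rail,                     # rail transition map of the environment
        agent.position,               # current (row, col) cell
        agent.direction,              # current heading
    )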
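
The agent_positions grid rebuilt in reset() stores -1 for an empty cell and the occupying train's handle otherwise; _update_agent_positions_map keeps it in sync by writing the new cell and clearing the vacated one. A self-contained toy mirroring that update rule (plain numpy, no flatland required; the values are made up):

    import numpy as np

    agent_positions = np.zeros((3, 3), dtype=int) - 1  # -1 marks an empty cell
    handle, old_position, position = 7, (0, 0), (0, 1)
    agent_positions[old_position] = handle             # occupancy before the step

    # Mirror of the update rule: write the new cell, then clear the old one.
    if old_position != position:
        if position is not None:
            agent_positions[position] = handle
        if old_position is not None:
            agent_positions[old_position] = -1

    assert agent_positions[0, 1] == 7   # new cell now holds the handle
    assert agent_positions[0, 0] == -1  # vacated cell is empty again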
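
The keys documented in get_info_dict are what step() hands back as its fourth return value. A hedged consumption sketch, again assuming a constructed RailEnv bound to `env`:

    from flatland.envs.rail_env_action import RailEnvActions

    actions = {handle: RailEnvActions.MOVE_FORWARD for handle in env.get_agent_handles()}
    obs, rewards, dones, info = env.step(actions)

    for handle in env.get_agent_handles():
        if info['malfunction'][handle] > 0:
            continue  # counter > 0: this train is malfunctioning on this step
        if info['action_required'][handle]:
            pass  # this train expects a fresh action at the next step()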