diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 8ab96d3a31ac169457adad9899a466592abec705..8590922131d57cd76ba4a638b10e3e7e169a984a 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -78,6 +78,30 @@ class TreeObsForRailEnv(ObservationBuilder): self.predicted_pos.update({t: coordinate_to_position(self.env.width, pos_list)}) self.predicted_dir.update({t: dir_list}) self.max_prediction_depth = len(self.predicted_pos) + # Update local lookup table for all agents' positions + # ignore other agents not in the grid (only status active and done) + # self.location_has_agent = {tuple(agent.position): 1 for agent in self.env.agents if + # agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE]} + + self.location_has_agent = {} + self.location_has_agent_direction = {} + self.location_has_agent_speed = {} + self.location_has_agent_malfunction = {} + self.location_has_agent_ready_to_depart = {} + + for _agent in self.env.agents: + if _agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE] and \ + _agent.position: + self.location_has_agent[tuple(_agent.position)] = 1 + self.location_has_agent_direction[tuple(_agent.position)] = _agent.direction + self.location_has_agent_speed[tuple(_agent.position)] = _agent.speed_data['speed'] + self.location_has_agent_malfunction[tuple(_agent.position)] = _agent.malfunction_data[ + 'malfunction'] + + if _agent.status in [RailAgentStatus.READY_TO_DEPART] and \ + _agent.initial_position: + self.location_has_agent_ready_to_depart[tuple(_agent.initial_position)] = \ + self.location_has_agent_ready_to_depart.get(tuple(_agent.initial_position), 0) + 1 observations = super().get_many(handles) @@ -162,30 +186,6 @@ class TreeObsForRailEnv(ObservationBuilder): In case the target node is reached, the values are [0, 0, 0, 0, 0]. """ - # Update local lookup table for all agents' positions - # ignore other agents not in the grid (only status active and done) - # self.location_has_agent = {tuple(agent.position): 1 for agent in self.env.agents if - # agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE]} - - self.location_has_agent = {} - self.location_has_agent_direction = {} - self.location_has_agent_speed = {} - self.location_has_agent_malfunction = {} - self.location_has_agent_ready_to_depart = {} - - for _agent in self.env.agents: - if _agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE] and \ - _agent.position: - self.location_has_agent[tuple(_agent.position)] = 1 - self.location_has_agent_direction[tuple(_agent.position)] = _agent.direction - self.location_has_agent_speed[tuple(_agent.position)] = _agent.speed_data['speed'] - self.location_has_agent_malfunction[tuple(_agent.position)] = _agent.malfunction_data['malfunction'] - - if _agent.status in [RailAgentStatus.READY_TO_DEPART] and \ - _agent.initial_position: - self.location_has_agent_ready_to_depart[tuple(_agent.initial_position)] = \ - self.location_has_agent_ready_to_depart.get(tuple(_agent.initial_position), 0) + 1 - if handle > len(self.env.agents): print("ERROR: obs _get - handle ", handle, " len(agents)", len(self.env.agents)) agent = self.env.agents[handle] # TODO: handle being treated as index diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index d9963a2276e352b2adcf1030035c7078aea374e6..22bd21625fd4cc0fb50c592bf4533e60ced00546 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -14,6 +14,7 @@ from flatland.core.env import Environment from flatland.core.env_observation_builder import ObservationBuilder from flatland.core.grid.grid4 import Grid4TransitionsEnum, Grid4Transitions from flatland.core.grid.grid4_utils import get_new_position +from flatland.core.grid.grid_utils import IntVector2D from flatland.core.transition_map import GridTransitionMap from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent, RailAgentStatus from flatland.envs.distance_map import DistanceMap @@ -83,6 +84,7 @@ class RailEnv(Environment): - invalid_action_penalty = 0 - step_penalty = -alpha - global_reward = beta + - epsilon = avoid rounding errors - stop_penalty = 0 # penalty for stopping a moving agent - start_penalty = 0 # penalty for starting a stopped agent @@ -217,6 +219,9 @@ class RailEnv(Environment): self.valid_positions = None + # global numpy array of agents position, True means that there is an agent at that cell + self.agent_positions: np.ndarray = np.full((height, width), False) + def _seed(self, seed=None): self.np_random, seed = seeding.np_random(seed) return [seed] @@ -242,7 +247,7 @@ class RailEnv(Environment): agent = self.agents[handle] if agent.status == RailAgentStatus.READY_TO_DEPART and self.cell_free(agent.initial_position): agent.status = RailAgentStatus.ACTIVE - agent.position = agent.initial_position + self._set_agent_to_initial_position(agent, agent.initial_position) def restart_agents(self): """ Reset the agents to their starting positions defined in agents_static @@ -275,6 +280,24 @@ class RailEnv(Environment): alpha = 2 return int(timedelay_factor * alpha * (width + height + ratio_nr_agents_to_nr_cities)) + def action_required(self, agent): + """ + Check if an agent needs to provide an action + + Parameters + ---------- + agent: RailEnvAgent + Agent we want to check + + Returns + ------- + True: Agent needs to provide an action + False: Agent cannot provide an action + """ + return (agent.status == RailAgentStatus.READY_TO_DEPART or ( + agent.status == RailAgentStatus.ACTIVE and np.isclose(agent.speed_data['position_fraction'], 0.0, + rtol=1e-03))) + def reset(self, regenerate_rail: bool = True, regenerate_schedule: bool = True, activate_agents: bool = False, random_seed: bool = None) -> (Dict, Dict): """ @@ -339,6 +362,8 @@ class RailEnv(Environment): else: self._max_episode_steps = self.compute_max_episode_steps(width=self.width, height=self.height) + self.agent_positions = np.full((self.height, self.width), False) + self.restart_agents() if activate_agents: @@ -370,10 +395,7 @@ class RailEnv(Environment): self.distance_map.reset(self.agents, self.rail) info_dict: Dict = { - 'action_required': { - i: (agent.status == RailAgentStatus.READY_TO_DEPART or ( - agent.status == RailAgentStatus.ACTIVE and agent.speed_data['position_fraction'] == 0.0)) - for i, agent in enumerate(self.agents)}, + 'action_required': {i: self.action_required(agent) for i, agent in enumerate(self.agents)}, 'malfunction': { i: self.agents[i].malfunction_data['malfunction'] for i in range(self.get_num_agents()) }, @@ -390,26 +412,27 @@ class RailEnv(Environment): """ agent = self.agents[i_agent] - # Skip agents that cannot break - # TODO: Make a better malfunction model such that not always the same agents break. - if agent.malfunction_data['malfunction_rate'] < 1: - return False + # Decrease counter for next event only if agent is currently not broken and agent has a malfunction rate + if agent.malfunction_data['malfunction_rate'] >= 1 and agent.malfunction_data['next_malfunction'] > 0 and \ + agent.malfunction_data['malfunction'] < 1: + agent.malfunction_data['next_malfunction'] -= 1 - # If agent is currently working and next malfunction time is reached we set it to malfunctioning - if 1 > agent.malfunction_data['malfunction'] and agent.malfunction_data['next_malfunction'] < 1: + # Only agents that have a positive rate for malfunctions and are not currently broken are considered + # If counter has come to zero --> Agent has malfunction + # set next malfunction time and duration of current malfunction + if agent.malfunction_data['malfunction_rate'] >= 1 and 1 > agent.malfunction_data['malfunction'] and \ + agent.malfunction_data['next_malfunction'] < 1: # Increase number of malfunctions agent.malfunction_data['nr_malfunctions'] += 1 - # Next malfunction in number of steps + # Next malfunction in number of stops next_breakdown = int( self._exp_distirbution_synced(rate=agent.malfunction_data['malfunction_rate'])) agent.malfunction_data['next_malfunction'] = max(next_breakdown, 1) - # Duration of current malfunction num_broken_steps = self.np_random.randint(self.min_number_of_steps_broken, - self.max_number_of_steps_broken + 1) + self.max_number_of_steps_broken + 1) + 1 agent.malfunction_data['malfunction'] = num_broken_steps - # Remember current moving state of the agent agent.malfunction_data['moving_before_malfunction'] = agent.moving return True @@ -429,16 +452,17 @@ class RailEnv(Environment): # Nothing left to do with broken agent return True - - # Decrease counter for next event only if agent is currently not broken and agent has a malfunction rate - if agent.malfunction_data['malfunction_rate'] >= 1 and agent.malfunction_data['next_malfunction'] > 0 and \ - agent.malfunction_data['malfunction'] < 1: - agent.malfunction_data['next_malfunction'] -= 1 - return False def step(self, action_dict_: Dict[int, RailEnvActions]): + """ + Updates rewards for the agents at a step. + Parameters + ---------- + action_dict_ : Dict[int,RailEnvActions] + + """ self._elapsed_steps += 1 # If we're done, set reward and info_dict and step() is done. @@ -479,10 +503,7 @@ class RailEnv(Environment): have_all_agents_ended &= (agent.status in [RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED]) # Build info dict - info_dict["action_required"][i_agent] = \ - (agent.status == RailAgentStatus.READY_TO_DEPART or ( - agent.status == RailAgentStatus.ACTIVE and np.isclose(agent.speed_data['position_fraction'], 0.0, - rtol=1e-03))) + info_dict["action_required"][i_agent] = self.action_required(agent) info_dict["malfunction"][i_agent] = agent.malfunction_data['malfunction'] info_dict["speed"][i_agent] = agent.speed_data['speed'] info_dict["status"][i_agent] = agent.status @@ -520,7 +541,7 @@ class RailEnv(Environment): if action in [RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT, RailEnvActions.MOVE_FORWARD] and self.cell_free(agent.initial_position): agent.status = RailAgentStatus.ACTIVE - agent.position = agent.initial_position + self._set_agent_to_initial_position(agent, agent.initial_position) self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] return else: @@ -615,7 +636,7 @@ class RailEnv(Environment): assert new_cell_valid assert transition_valid if cell_free: - agent.position = new_position + self._move_agent_to_new_position(agent, new_position) agent.direction = new_direction agent.speed_data['position_fraction'] = 0.0 @@ -624,16 +645,54 @@ class RailEnv(Environment): agent.status = RailAgentStatus.DONE self.dones[i_agent] = True agent.moving = False - - if self.remove_agents_at_target: - agent.position = None - agent.status = RailAgentStatus.DONE_REMOVED + self._remove_agent_from_scene(agent) else: self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] else: # step penalty if not moving (stopped now or before) self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] + def _set_agent_to_initial_position(self, agent: EnvAgent, new_position: IntVector2D): + """ + Sets the agent to its initial position. Updates the agent object and the position + of the agent inside the global agent_position numpy array + + Parameters + ------- + agent: EnvAgent object + new_position: IntVector2D + """ + agent.position = new_position + self.agent_positions[agent.position] = True + + def _move_agent_to_new_position(self, agent: EnvAgent, new_position: IntVector2D): + """ + Move the agent to the a new position. Updates the agent object and the position + of the agent inside the global agent_position numpy array + + Parameters + ------- + agent: EnvAgent object + new_position: IntVector2D + """ + agent.position = new_position + self.agent_positions[agent.old_position] = False + self.agent_positions[agent.position] = True + + def _remove_agent_from_scene(self, agent: EnvAgent): + """ + Remove the agent from the scene. Updates the agent object and the position + of the agent inside the global agent_position numpy array + + Parameters + ------- + agent: EnvAgent object + """ + self.agent_positions[agent.position] = False + if self.remove_agents_at_target: + agent.position = None + agent.status = RailAgentStatus.DONE_REMOVED + def _check_action_on_agent(self, action: RailEnvActions, agent: EnvAgent): """ @@ -670,16 +729,32 @@ class RailEnv(Environment): (*agent.position, agent.direction), new_direction) - # Check the new position is not the same as any of the existing agent positions - # (including itself, for simplicity, since it is moving) - cell_free = self.cell_free(new_position) + + # only call cell_free() if new cell is inside the scene + if new_cell_valid: + # Check the new position is not the same as any of the existing agent positions + # (including itself, for simplicity, since it is moving) + cell_free = self.cell_free(new_position) + else: + # if new cell is outside of scene -> cell_free is False + cell_free = False return cell_free, new_cell_valid, new_direction, new_position, transition_valid - def cell_free(self, position): + def cell_free(self, position: IntVector2D) -> bool: + """ + Utility to check if a cell is free + + Parameters: + -------- + position : Tuple[int, int] + + Returns + ------- + bool + is the cell free or not? - agent_positions = [agent.position for agent in self.agents if agent.position is not None] - ret = len(agent_positions) == 0 or not np.any(np.equal(position, agent_positions).all(1)) - return ret + """ + return not self.agent_positions[position] def check_action(self, agent: EnvAgent, action: RailEnvActions): """ @@ -722,13 +797,35 @@ class RailEnv(Environment): return new_direction, transition_valid def _get_observations(self): + """ + Utility which returns the observations for an agent with respect to environment + + Returns + ------ + Dict object + """ self.obs_dict = self.obs_builder.get_many(list(range(self.get_num_agents()))) return self.obs_dict def get_valid_directions_on_grid(self, row: int, col: int) -> List[int]: + """ + Returns directions in which the agent can move + + Parameters: + --------- + row : int + col : int + + Returns: + ------- + List[int] + """ return Grid4Transitions.get_entry_directions(self.rail.get_full_transitions(row, col)) def get_full_state_msg(self): + """ + Returns state of environment in msgpack object + """ grid_data = self.rail.grid.tolist() agent_static_data = [agent.to_list() for agent in self.agents_static] agent_data = [agent.to_list() for agent in self.agents] @@ -742,12 +839,22 @@ class RailEnv(Environment): return msgpack.packb(msg_data, use_bin_type=True) def get_agent_state_msg(self): + """ + Returns agents information in msgpack object + """ agent_data = [agent.to_list() for agent in self.agents] msg_data = { "agents": agent_data} return msgpack.packb(msg_data, use_bin_type=True) def set_full_state_msg(self, msg_data): + """ + Sets environment state with msgdata object passed as argument + + Parameters + ------- + msg_data: msgpack object + """ data = msgpack.unpackb(msg_data, use_list=False, encoding='utf-8') self.rail.grid = np.array(data["grid"]) # agents are always reset as not moving @@ -760,6 +867,13 @@ class RailEnv(Environment): self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False) def set_full_state_dist_msg(self, msg_data): + """ + Sets environment grid state and distance map with msgdata object passed as argument + + Parameters + ------- + msg_data: msgpack object + """ data = msgpack.unpackb(msg_data, use_list=False, encoding='utf-8') self.rail.grid = np.array(data["grid"]) # agents are always reset as not moving @@ -774,6 +888,9 @@ class RailEnv(Environment): self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False) def get_full_state_dist_msg(self): + """ + Returns environment information with distance map information as msgpack object + """ grid_data = self.rail.grid.tolist() agent_static_data = [agent.to_list() for agent in self.agents_static] agent_data = [agent.to_list() for agent in self.agents] @@ -791,6 +908,14 @@ class RailEnv(Environment): return msgpack.packb(msg_data, use_bin_type=True) def save(self, filename, save_distance_maps=False): + """ + Saves environment and distance map information in a file + + Parameters: + --------- + filename: string + save_distance_maps: bool + """ if save_distance_maps is True: if self.distance_map.get() is not None: if len(self.distance_map.get()) > 0: @@ -807,14 +932,31 @@ class RailEnv(Environment): file_out.write(self.get_full_state_msg()) def load(self, filename): + """ + Load environment with distance map from a file + + Parameters: + ------- + filename: string + """ with open(filename, "rb") as file_in: load_data = file_in.read() self.set_full_state_dist_msg(load_data) def load_pkl(self, pkl_data): + """ + Load environment with distance map from a pickle file + + Parameters: + ------- + pkl_data: pickle file + """ self.set_full_state_msg(pkl_data) def load_resource(self, package, resource): + """ + Load environment with distance map from a binary + """ from importlib_resources import read_binary load_data = read_binary(package, resource) self.set_full_state_msg(load_data) diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index 028928aeec2c08c7b9035d875a874987dce79896..3e90128c74bda860d8a4c75d71652071660482a3 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -343,6 +343,17 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11, seed=1) -> R template = [template[-1]] + template[:-1] def get_matching_templates(template): + """ + Returns a list of possible transition maps for a given template + + Parameters: + ------ + template:List[int] + + Returns: + ------ + List[int] + """ ret = [] for i in range(len(transitions_templates_)): is_match = True diff --git a/notebooks/simple_example1_env_from_tuple.ipynb b/notebooks/simple_example1_env_from_tuple.ipynb index 2d13585156b447fe4f9eed07834b3dc04ccf84d6..317a5b63ef171691c0324d38d82ad543e295155c 100644 --- a/notebooks/simple_example1_env_from_tuple.ipynb +++ b/notebooks/simple_example1_env_from_tuple.ipynb @@ -10,9 +10,21 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "metadata": {}, - "outputs": [], + "outputs": [ + { + "ename": "SystemError", + "evalue": "Parent module '' not loaded, cannot perform relative import", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mSystemError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m<ipython-input-2-b6a25a9cfbbb>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mflatland\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menvs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrail_generators\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mrail_from_manual_specifications_generator\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mflatland\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menvs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mobservations\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mTreeObsForRailEnv\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mflatland\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menvs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrail_env\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mRailEnv\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mflatland\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrendertools\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mRenderTool\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mPIL\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mImage\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mSystemError\u001b[0m: Parent module '' not loaded, cannot perform relative import" + ] + } + ], "source": [ "from flatland.envs.rail_generators import rail_from_manual_specifications_generator\n", "from flatland.envs.observations import TreeObsForRailEnv\n", @@ -24,7 +36,9 @@ { "cell_type": "code", "execution_count": 2, - "metadata": {}, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [ "specs = [[(0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0)],\n", @@ -83,7 +97,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.3" + "version": "3.5.2" }, "latex_envs": { "LaTeX_envs_menu_present": true, diff --git a/notebooks/simple_example2_generate_random_rail.ipynb b/notebooks/simple_example2_generate_random_rail.ipynb index 19b854ee15d8dd1e19361f58a552eba617b19b67..bfa2c877ef0ecec8c596592c13c08db287661c78 100644 --- a/notebooks/simple_example2_generate_random_rail.ipynb +++ b/notebooks/simple_example2_generate_random_rail.ipynb @@ -88,7 +88,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.5" + "version": "3.5.2" }, "latex_envs": { "LaTeX_envs_menu_present": true,