From 9edc0ff85133c21d448d0c1a7e34c9c06c2fd121 Mon Sep 17 00:00:00 2001 From: u229589 <christian.baumberger@sbb.ch> Date: Tue, 29 Oct 2019 08:27:11 +0100 Subject: [PATCH] Merge branch 'master' of gitlab.aicrowd.com:flatland/flatland into 259-reduce-computational-complexity-of-env-step # Conflicts: # flatland/envs/rail_env.py --- flatland/envs/rail_env.py | 116 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 116 insertions(+) diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 153f668b..32dd70e5 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -84,6 +84,7 @@ class RailEnv(Environment): - invalid_action_penalty = 0 - step_penalty = -alpha - global_reward = beta + - epsilon = avoid rounding errors - stop_penalty = 0 # penalty for stopping a moving agent - start_penalty = 0 # penalty for starting a stopped agent @@ -217,6 +218,8 @@ class RailEnv(Environment): # Reset environment self.valid_positions = None + + # global numpy array of agents position, True means that there is an agent at that cell self.agent_positions: np.ndarray = np.full((height, width), False) def _seed(self, seed=None): @@ -437,7 +440,14 @@ class RailEnv(Environment): return False def step(self, action_dict_: Dict[int, RailEnvActions]): + """ + Updates rewards for the agents at a step. + Parameters + ---------- + action_dict_ : Dict[int,RailEnvActions] + + """ self._elapsed_steps += 1 # If we're done, set reward and info_dict and step() is done. @@ -631,15 +641,41 @@ class RailEnv(Environment): self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] def _set_agent_to_initial_position(self, agent: EnvAgent, new_position: IntVector2D): + """ + Sets the agent to its initial position. 
Updates the agent object and the position + of the agent inside the global agent_positions numpy array + + Parameters + ------- + agent: EnvAgent object + new_position: IntVector2D + """ agent.position = new_position self.agent_positions[agent.position] = True def _move_agent_to_new_position(self, agent: EnvAgent, new_position: IntVector2D): + """ + Move the agent to a new position. Updates the agent object and the position + of the agent inside the global agent_positions numpy array + + Parameters + ------- + agent: EnvAgent object + new_position: IntVector2D + """ agent.position = new_position self.agent_positions[agent.old_position] = False self.agent_positions[agent.position] = True def _remove_agent_from_scene(self, agent: EnvAgent): + """ + Remove the agent from the scene. Updates the agent object and the position + of the agent inside the global agent_positions numpy array + + Parameters + ------- + agent: EnvAgent object + """ self.agent_positions[agent.position] = False if self.remove_agents_at_target: agent.position = None @@ -687,6 +723,19 @@ class RailEnv(Environment): return cell_free, new_cell_valid, new_direction, new_position, transition_valid def cell_free(self, position: IntVector2D) -> bool: + """ + Utility to check if a cell is free + + Parameters: + -------- + position : Tuple[int, int] + + Returns + ------- + bool + is the cell free or not? 
+ + """ try: return not self.agent_positions[position] except IndexError as error: @@ -734,13 +783,35 @@ class RailEnv(Environment): return new_direction, transition_valid def _get_observations(self): + """ + Utility which returns the observations for all agents with respect to environment + + Returns + ------ + Dict object + """ self.obs_dict = self.obs_builder.get_many(list(range(self.get_num_agents()))) return self.obs_dict def get_valid_directions_on_grid(self, row: int, col: int) -> List[int]: + """ + Returns directions in which the agent can move + + Parameters: + --------- + row : int + col : int + + Returns: + ------- + List[int] + """ return Grid4Transitions.get_entry_directions(self.rail.get_full_transitions(row, col)) def get_full_state_msg(self): + """ + Returns state of environment in msgpack object + """ grid_data = self.rail.grid.tolist() agent_static_data = [agent.to_list() for agent in self.agents_static] agent_data = [agent.to_list() for agent in self.agents] @@ -754,12 +825,22 @@ class RailEnv(Environment): return msgpack.packb(msg_data, use_bin_type=True) def get_agent_state_msg(self): + """ + Returns agents information in msgpack object + """ agent_data = [agent.to_list() for agent in self.agents] msg_data = { "agents": agent_data} return msgpack.packb(msg_data, use_bin_type=True) def set_full_state_msg(self, msg_data): + """ + Sets environment state with msgdata object passed as argument + + Parameters + ------- + msg_data: msgpack object + """ data = msgpack.unpackb(msg_data, use_list=False, encoding='utf-8') self.rail.grid = np.array(data["grid"]) # agents are always reset as not moving @@ -772,6 +853,13 @@ class RailEnv(Environment): self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False) def set_full_state_dist_msg(self, msg_data): + """ + Sets environment grid state and distance map with msgdata object passed as argument + + Parameters + ------- + msg_data: msgpack object + """ data = msgpack.unpackb(msg_data, 
use_list=False, encoding='utf-8') self.rail.grid = np.array(data["grid"]) # agents are always reset as not moving @@ -786,6 +874,9 @@ class RailEnv(Environment): self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False) def get_full_state_dist_msg(self): + """ + Returns environment information with distance map information as msgpack object + """ grid_data = self.rail.grid.tolist() agent_static_data = [agent.to_list() for agent in self.agents_static] agent_data = [agent.to_list() for agent in self.agents] @@ -803,6 +894,14 @@ class RailEnv(Environment): return msgpack.packb(msg_data, use_bin_type=True) def save(self, filename, save_distance_maps=False): + """ + Saves environment and distance map information in a file + + Parameters: + --------- + filename: string + save_distance_maps: bool + """ if save_distance_maps is True: if self.distance_map.get() is not None: if len(self.distance_map.get()) > 0: @@ -819,14 +918,31 @@ class RailEnv(Environment): file_out.write(self.get_full_state_msg()) def load(self, filename): + """ + Load environment with distance map from a file + + Parameters: + ------- + filename: string + """ with open(filename, "rb") as file_in: load_data = file_in.read() self.set_full_state_dist_msg(load_data) def load_pkl(self, pkl_data): + """ + Load environment state from serialized data via set_full_state_msg (not the distance-map variant) + + Parameters: + ------- + pkl_data: msgpack object + """ self.set_full_state_msg(pkl_data) def load_resource(self, package, resource): + """ + Load environment state from a binary package resource via set_full_state_msg (not the distance-map variant) + """ from importlib_resources import read_binary load_data = read_binary(package, resource) self.set_full_state_msg(load_data) -- GitLab