diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index cb4da95a81758625f1f45ecb535e47d834fd5005..ad160e4e8770b9994052563c4584a7211e8da3bc 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -83,6 +83,7 @@ class RailEnv(Environment): - invalid_action_penalty = 0 - step_penalty = -alpha - global_reward = beta + - epsilon = avoid rounding errors - stop_penalty = 0 # penalty for stopping a moving agent - start_penalty = 0 # penalty for starting a stopped agent @@ -433,7 +434,14 @@ class RailEnv(Environment): return False def step(self, action_dict_: Dict[int, RailEnvActions]): + """ + Updates rewards for the agents at a step. + Parameters + ---------- + action_dict_ : Dict[int,RailEnvActions] + + """ self._elapsed_steps += 1 # If we're done, set reward and info_dict and step() is done. @@ -671,7 +679,19 @@ class RailEnv(Environment): return cell_free, new_cell_valid, new_direction, new_position, transition_valid def cell_free(self, position): + """ + Utility to check if a cell is free + Parameters: + -------- + position : Tuple[int, int] + + Returns + ------- + bool + is the cell free or not? + + """ agent_positions = [agent.position for agent in self.agents if agent.position is not None] ret = len(agent_positions) == 0 or not np.any(np.equal(position, agent_positions).all(1)) return ret @@ -717,13 +737,35 @@ class RailEnv(Environment): return new_direction, transition_valid def _get_observations(self): + """ + Utility which returns the observations for an agent with respect to environment + + Returns + ------ + Dict object + """ self.obs_dict = self.obs_builder.get_many(list(range(self.get_num_agents()))) return self.obs_dict def get_valid_directions_on_grid(self, row: int, col: int) -> List[int]: + """ + Returns directions in which the agent can move + + Parameters: + --------- + row : int + col : int + + Returns: + ------- + List[int] + """ return Grid4Transitions.get_entry_directions(self.rail.get_full_transitions(row, col)) def get_full_state_msg(self): + """ + Returns state of environment in msgpack object + """ grid_data = self.rail.grid.tolist() agent_static_data = [agent.to_list() for agent in self.agents_static] agent_data = [agent.to_list() for agent in self.agents] @@ -737,12 +779,22 @@ class RailEnv(Environment): return msgpack.packb(msg_data, use_bin_type=True) def get_agent_state_msg(self): + """ + Returns agents information in msgpack object + """ agent_data = [agent.to_list() for agent in self.agents] msg_data = { "agents": agent_data} return msgpack.packb(msg_data, use_bin_type=True) def set_full_state_msg(self, msg_data): + """ + Sets environment state with msgdata object passed as argument + + Parameters + ------- + msg_data: msgpack object + """ data = msgpack.unpackb(msg_data, use_list=False, encoding='utf-8') self.rail.grid = np.array(data["grid"]) # agents are always reset as not moving @@ -755,6 +807,13 @@ class RailEnv(Environment): self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False) def set_full_state_dist_msg(self, msg_data): + """ + Sets environment grid state and distance map with msgdata object passed as argument + + Parameters + ------- + msg_data: msgpack object + """ data = msgpack.unpackb(msg_data, use_list=False, encoding='utf-8') self.rail.grid = np.array(data["grid"]) # agents are always reset as not moving @@ -769,6 +828,9 @@ class RailEnv(Environment): self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False) def get_full_state_dist_msg(self): + """ + Returns environment information with distance map information as msgpack object + """ grid_data = self.rail.grid.tolist() agent_static_data = [agent.to_list() for agent in self.agents_static] agent_data = [agent.to_list() for agent in self.agents] @@ -786,6 +848,14 @@ class RailEnv(Environment): return msgpack.packb(msg_data, use_bin_type=True) def save(self, filename, save_distance_maps=False): + """ + Saves environment and distance map information in a file + + Parameters: + --------- + filename: string + save_distance_maps: bool + """ if save_distance_maps is True: if self.distance_map.get() is not None: if len(self.distance_map.get()) > 0: @@ -802,14 +872,31 @@ class RailEnv(Environment): file_out.write(self.get_full_state_msg()) def load(self, filename): + """ + Load environment with distance map from a file + + Parameters: + ------- + filename: string + """ with open(filename, "rb") as file_in: load_data = file_in.read() self.set_full_state_dist_msg(load_data) def load_pkl(self, pkl_data): + """ + Load environment with distance map from a pickle file + + Parameters: + ------- + pkl_data: pickle file + """ self.set_full_state_msg(pkl_data) def load_resource(self, package, resource): + """ + Load environment with distance map from a binary + """ from importlib_resources import read_binary load_data = read_binary(package, resource) self.set_full_state_msg(load_data) diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index 028928aeec2c08c7b9035d875a874987dce79896..3e90128c74bda860d8a4c75d71652071660482a3 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -343,6 +343,17 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11, seed=1) -> R template = [template[-1]] + template[:-1] def get_matching_templates(template): + """ + Returns a list of possible transition maps for a given template + + Parameters: + ------ + template:List[int] + + Returns: + ------ + List[int] + """ ret = [] for i in range(len(transitions_templates_)): is_match = True diff --git a/notebooks/simple_example1_env_from_tuple.ipynb b/notebooks/simple_example1_env_from_tuple.ipynb index 2d13585156b447fe4f9eed07834b3dc04ccf84d6..317a5b63ef171691c0324d38d82ad543e295155c 100644 --- a/notebooks/simple_example1_env_from_tuple.ipynb +++ b/notebooks/simple_example1_env_from_tuple.ipynb @@ -10,9 +10,21 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "metadata": {}, - "outputs": [], + "outputs": [ + { + "ename": "SystemError", + "evalue": "Parent module '' not loaded, cannot perform relative import", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mSystemError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m<ipython-input-2-b6a25a9cfbbb>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mflatland\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menvs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrail_generators\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mrail_from_manual_specifications_generator\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mflatland\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menvs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mobservations\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mTreeObsForRailEnv\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mflatland\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menvs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrail_env\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mRailEnv\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mflatland\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrendertools\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mRenderTool\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mPIL\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mImage\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mSystemError\u001b[0m: Parent module '' not loaded, cannot perform relative import" + ] + } + ], "source": [ "from flatland.envs.rail_generators import rail_from_manual_specifications_generator\n", "from flatland.envs.observations import TreeObsForRailEnv\n", @@ -24,7 +36,9 @@ { "cell_type": "code", "execution_count": 2, - "metadata": {}, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [ "specs = [[(0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0)],\n", @@ -83,7 +97,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.3" + "version": "3.5.2" }, "latex_envs": { "LaTeX_envs_menu_present": true, diff --git a/notebooks/simple_example2_generate_random_rail.ipynb b/notebooks/simple_example2_generate_random_rail.ipynb index 19b854ee15d8dd1e19361f58a552eba617b19b67..bfa2c877ef0ecec8c596592c13c08db287661c78 100644 --- a/notebooks/simple_example2_generate_random_rail.ipynb +++ b/notebooks/simple_example2_generate_random_rail.ipynb @@ -88,7 +88,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.5" + "version": "3.5.2" }, "latex_envs": { "LaTeX_envs_menu_present": true,