Skip to content
Snippets Groups Projects
Commit 9edc0ff8 authored by u229589's avatar u229589
Browse files

Merge branch 'master' of gitlab.aicrowd.com:flatland/flatland into...

Merge branch 'master' of gitlab.aicrowd.com:flatland/flatland into 259-reduce-computational-complexity-of-env-step

# Conflicts:
#	flatland/envs/rail_env.py
parent 1e8a2421
No related branches found
No related tags found
No related merge requests found
...@@ -84,6 +84,7 @@ class RailEnv(Environment): ...@@ -84,6 +84,7 @@ class RailEnv(Environment):
- invalid_action_penalty = 0 - invalid_action_penalty = 0
- step_penalty = -alpha - step_penalty = -alpha
- global_reward = beta - global_reward = beta
- epsilon = small value used to avoid rounding errors
- stop_penalty = 0 # penalty for stopping a moving agent - stop_penalty = 0 # penalty for stopping a moving agent
- start_penalty = 0 # penalty for starting a stopped agent - start_penalty = 0 # penalty for starting a stopped agent
...@@ -217,6 +218,8 @@ class RailEnv(Environment): ...@@ -217,6 +218,8 @@ class RailEnv(Environment):
# Reset environment # Reset environment
self.valid_positions = None self.valid_positions = None
# global numpy array of agent positions; True means that there is an agent in that cell
self.agent_positions: np.ndarray = np.full((height, width), False) self.agent_positions: np.ndarray = np.full((height, width), False)
def _seed(self, seed=None): def _seed(self, seed=None):
...@@ -437,7 +440,14 @@ class RailEnv(Environment): ...@@ -437,7 +440,14 @@ class RailEnv(Environment):
return False return False
def step(self, action_dict_: Dict[int, RailEnvActions]): def step(self, action_dict_: Dict[int, RailEnvActions]):
"""
Updates rewards for the agents at a step.
Parameters
----------
action_dict_ : Dict[int,RailEnvActions]
"""
self._elapsed_steps += 1 self._elapsed_steps += 1
# If we're done, set reward and info_dict and step() is done. # If we're done, set reward and info_dict and step() is done.
...@@ -631,15 +641,41 @@ class RailEnv(Environment): ...@@ -631,15 +641,41 @@ class RailEnv(Environment):
self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
def _set_agent_to_initial_position(self, agent: EnvAgent, new_position: IntVector2D): def _set_agent_to_initial_position(self, agent: EnvAgent, new_position: IntVector2D):
"""
Sets the agent to its initial position. Updates the agent object and the position
of the agent inside the global agent_positions numpy array
Parameters
-------
agent: EnvAgent object
new_position: IntVector2D
"""
agent.position = new_position agent.position = new_position
self.agent_positions[agent.position] = True self.agent_positions[agent.position] = True
def _move_agent_to_new_position(self, agent: EnvAgent, new_position: IntVector2D): def _move_agent_to_new_position(self, agent: EnvAgent, new_position: IntVector2D):
"""
Move the agent to a new position. Updates the agent object and the position
of the agent inside the global agent_positions numpy array
Parameters
-------
agent: EnvAgent object
new_position: IntVector2D
"""
agent.position = new_position agent.position = new_position
self.agent_positions[agent.old_position] = False self.agent_positions[agent.old_position] = False
self.agent_positions[agent.position] = True self.agent_positions[agent.position] = True
def _remove_agent_from_scene(self, agent: EnvAgent): def _remove_agent_from_scene(self, agent: EnvAgent):
"""
Remove the agent from the scene. Updates the agent object and the position
of the agent inside the global agent_positions numpy array
Parameters
-------
agent: EnvAgent object
"""
self.agent_positions[agent.position] = False self.agent_positions[agent.position] = False
if self.remove_agents_at_target: if self.remove_agents_at_target:
agent.position = None agent.position = None
...@@ -687,6 +723,19 @@ class RailEnv(Environment): ...@@ -687,6 +723,19 @@ class RailEnv(Environment):
return cell_free, new_cell_valid, new_direction, new_position, transition_valid return cell_free, new_cell_valid, new_direction, new_position, transition_valid
def cell_free(self, position: IntVector2D) -> bool: def cell_free(self, position: IntVector2D) -> bool:
"""
Utility to check if a cell is free
Parameters:
--------
position : Tuple[int, int]
Returns
-------
bool
is the cell free or not?
"""
try: try:
return not self.agent_positions[position] return not self.agent_positions[position]
except IndexError as error: except IndexError as error:
...@@ -734,13 +783,35 @@ class RailEnv(Environment): ...@@ -734,13 +783,35 @@ class RailEnv(Environment):
return new_direction, transition_valid return new_direction, transition_valid
def _get_observations(self): def _get_observations(self):
"""
Utility which returns the observations for all agents with respect to the environment
Returns
------
Dict object
"""
self.obs_dict = self.obs_builder.get_many(list(range(self.get_num_agents()))) self.obs_dict = self.obs_builder.get_many(list(range(self.get_num_agents())))
return self.obs_dict return self.obs_dict
def get_valid_directions_on_grid(self, row: int, col: int) -> List[int]: def get_valid_directions_on_grid(self, row: int, col: int) -> List[int]:
"""
Returns directions in which the agent can move
Parameters:
---------
row : int
col : int
Returns:
-------
List[int]
"""
return Grid4Transitions.get_entry_directions(self.rail.get_full_transitions(row, col)) return Grid4Transitions.get_entry_directions(self.rail.get_full_transitions(row, col))
def get_full_state_msg(self): def get_full_state_msg(self):
"""
Returns the state of the environment as a msgpack object
"""
grid_data = self.rail.grid.tolist() grid_data = self.rail.grid.tolist()
agent_static_data = [agent.to_list() for agent in self.agents_static] agent_static_data = [agent.to_list() for agent in self.agents_static]
agent_data = [agent.to_list() for agent in self.agents] agent_data = [agent.to_list() for agent in self.agents]
...@@ -754,12 +825,22 @@ class RailEnv(Environment): ...@@ -754,12 +825,22 @@ class RailEnv(Environment):
return msgpack.packb(msg_data, use_bin_type=True) return msgpack.packb(msg_data, use_bin_type=True)
def get_agent_state_msg(self): def get_agent_state_msg(self):
"""
Returns the agents' information as a msgpack object
"""
agent_data = [agent.to_list() for agent in self.agents] agent_data = [agent.to_list() for agent in self.agents]
msg_data = { msg_data = {
"agents": agent_data} "agents": agent_data}
return msgpack.packb(msg_data, use_bin_type=True) return msgpack.packb(msg_data, use_bin_type=True)
def set_full_state_msg(self, msg_data): def set_full_state_msg(self, msg_data):
"""
Sets the environment state from the msg_data object passed as argument
Parameters
-------
msg_data: msgpack object
"""
data = msgpack.unpackb(msg_data, use_list=False, encoding='utf-8') data = msgpack.unpackb(msg_data, use_list=False, encoding='utf-8')
self.rail.grid = np.array(data["grid"]) self.rail.grid = np.array(data["grid"])
# agents are always reset as not moving # agents are always reset as not moving
...@@ -772,6 +853,13 @@ class RailEnv(Environment): ...@@ -772,6 +853,13 @@ class RailEnv(Environment):
self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False) self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)
def set_full_state_dist_msg(self, msg_data): def set_full_state_dist_msg(self, msg_data):
"""
Sets the environment grid state and distance map from the msg_data object passed as argument
Parameters
-------
msg_data: msgpack object
"""
data = msgpack.unpackb(msg_data, use_list=False, encoding='utf-8') data = msgpack.unpackb(msg_data, use_list=False, encoding='utf-8')
self.rail.grid = np.array(data["grid"]) self.rail.grid = np.array(data["grid"])
# agents are always reset as not moving # agents are always reset as not moving
...@@ -786,6 +874,9 @@ class RailEnv(Environment): ...@@ -786,6 +874,9 @@ class RailEnv(Environment):
self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False) self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)
def get_full_state_dist_msg(self): def get_full_state_dist_msg(self):
"""
Returns environment information with distance map information as msgpack object
"""
grid_data = self.rail.grid.tolist() grid_data = self.rail.grid.tolist()
agent_static_data = [agent.to_list() for agent in self.agents_static] agent_static_data = [agent.to_list() for agent in self.agents_static]
agent_data = [agent.to_list() for agent in self.agents] agent_data = [agent.to_list() for agent in self.agents]
...@@ -803,6 +894,14 @@ class RailEnv(Environment): ...@@ -803,6 +894,14 @@ class RailEnv(Environment):
return msgpack.packb(msg_data, use_bin_type=True) return msgpack.packb(msg_data, use_bin_type=True)
def save(self, filename, save_distance_maps=False): def save(self, filename, save_distance_maps=False):
"""
Saves environment and distance map information in a file
Parameters:
---------
filename: string
save_distance_maps: bool
"""
if save_distance_maps is True: if save_distance_maps is True:
if self.distance_map.get() is not None: if self.distance_map.get() is not None:
if len(self.distance_map.get()) > 0: if len(self.distance_map.get()) > 0:
...@@ -819,14 +918,31 @@ class RailEnv(Environment): ...@@ -819,14 +918,31 @@ class RailEnv(Environment):
file_out.write(self.get_full_state_msg()) file_out.write(self.get_full_state_msg())
def load(self, filename): def load(self, filename):
"""
Load environment with distance map from a file
Parameters:
-------
filename: string
"""
with open(filename, "rb") as file_in: with open(filename, "rb") as file_in:
load_data = file_in.read() load_data = file_in.read()
self.set_full_state_dist_msg(load_data) self.set_full_state_dist_msg(load_data)
def load_pkl(self, pkl_data): def load_pkl(self, pkl_data):
"""
Load environment state from msgpack-encoded data (the distance map is not loaded)
Parameters:
-------
pkl_data: msgpack-encoded data (despite the name, it is unpacked with msgpack, not pickle)
"""
self.set_full_state_msg(pkl_data) self.set_full_state_msg(pkl_data)
def load_resource(self, package, resource): def load_resource(self, package, resource):
"""
Load environment state from a binary package resource (the distance map is not loaded)
"""
from importlib_resources import read_binary from importlib_resources import read_binary
load_data = read_binary(package, resource) load_data = read_binary(package, resource)
self.set_full_state_msg(load_data) self.set_full_state_msg(load_data)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment