Commit beac6e15 authored by MasterScrat

Trying to optimize observations, trying with tree depth of 1

parent e61019c4
@@ -2,16 +2,12 @@
Collection of environment-specific ObservationBuilder.
"""
import collections
from typing import Optional, List, Dict, Tuple
from typing import Optional, List, Dict
import numpy as np
from flatland.core.env import Environment
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.env_prediction_builder import PredictionBuilder
from flatland.core.grid.grid4_utils import get_new_position
from flatland.core.grid.grid_utils import coordinate_to_position
from flatland.envs.agent_utils import RailAgentStatus, EnvAgent
from flatland.envs.agent_utils import RailAgentStatus
from flatland.utils.ordered_set import OrderedSet
Node = collections.namedtuple('Node', 'dist_own_target_encountered '
@@ -29,7 +25,7 @@ Node = collections.namedtuple('Node', 'dist_own_target_encountered '
'childs')
class TreeObsForRailEnv(ObservationBuilder):
class TreeObsForRailEnv():
"""
TreeObsForRailEnv object.
@@ -42,7 +38,7 @@ class TreeObsForRailEnv(ObservationBuilder):
tree_explored_actions_char = ['L', 'F', 'R', 'B']
def __init__(self, max_depth: int, predictor: PredictionBuilder = None):
def __init__(self, max_depth: int, predictor=None):
super().__init__()
self.max_depth = max_depth
self.observation_dim = 11
@@ -59,6 +55,18 @@ class TreeObsForRailEnv(ObservationBuilder):
Called whenever an observation has to be computed for the `env` environment, for each agent with handle
in the `handles` list.
"""
observations = {}
if handles is None:
handles = []
for h in handles:
observations[h] = self.get(h)
return observations
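# Intended two-phase use (an assumption based on the evaluation script
# further down in this commit): prepare_get_many() performs the shared
# per-step precomputation, after which get_many()/get() are cheap
# per-agent calls:
#
#   builder.prepare_get_many(handles)
#   observations = builder.get_many(handles)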
def prepare_get_many(self, handles: Optional[List[int]] = None) -> Dict[int, Node]:
"""
Precomputes the per-step data shared by all agents (e.g. the agent
location maps) before `get()` is called for each handle in `handles`.
"""
if handles is None:
handles = []
@@ -99,14 +107,10 @@ class TreeObsForRailEnv(ObservationBuilder):
'malfunction']
if _agent.status in [RailAgentStatus.READY_TO_DEPART] and \
_agent.initial_position:
self.location_has_agent_ready_to_depart[tuple(_agent.initial_position)] = \
self.location_has_agent_ready_to_depart.get(tuple(_agent.initial_position), 0) + 1
observations = super().get_many(handles)
return observations
def get(self, handle: int = 0) -> Node:
"""
Computes the current observation for agent `handle` in env
@@ -340,8 +344,8 @@ class TreeObsForRailEnv(ObservationBuilder):
conflicting_agent = np.where(self.predicted_pos[pre_step] == int_position)
for ca in conflicting_agent[0]:
if direction != self.predicted_dir[pre_step][ca] \
and cell_transitions[self._reverse_dir(self.predicted_dir[pre_step][ca])] == 1 \
and tot_dist < potential_conflict:  # noqa: E125
potential_conflict = tot_dist
if self.env.agents[ca].status == RailAgentStatus.DONE and tot_dist < potential_conflict:
potential_conflict = tot_dist
@@ -351,8 +355,8 @@ class TreeObsForRailEnv(ObservationBuilder):
conflicting_agent = np.where(self.predicted_pos[post_step] == int_position)
for ca in conflicting_agent[0]:
if direction != self.predicted_dir[post_step][ca] and cell_transitions[self._reverse_dir(
self.predicted_dir[post_step][ca])] == 1 \
and tot_dist < potential_conflict:  # noqa: E125
potential_conflict = tot_dist
if self.env.agents[ca].status == RailAgentStatus.DONE and tot_dist < potential_conflict:
potential_conflict = tot_dist
@@ -513,231 +517,10 @@ class TreeObsForRailEnv(ObservationBuilder):
for direction in self.tree_explored_actions_char:
self.print_subtree(node.childs[direction], direction, indent + "\t")
def set_env(self, env: Environment):
super().set_env(env)
def set_env(self, env):
self.env = env
if self.predictor:
self.predictor.set_env(self.env)
def _reverse_dir(self, direction):
return int((direction + 2) % 4)
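# Under the Grid4 direction convention (0=North, 1=East, 2=South, 3=West),
# _reverse_dir maps each heading to its opposite, e.g. _reverse_dir(0) == 2
# and _reverse_dir(1) == 3.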
class GlobalObsForRailEnv(ObservationBuilder):
"""
Gives a global observation of the entire rail environment.
The observation is composed of the following elements:
- transition map array with dimensions (env.height, env.width, 16),\
assuming 16 bits encoding of transitions.
- obs_agents_state: A 3D array (map_height, map_width, 5) with
- first channel containing the agent's position and direction
- second channel containing the other agents' positions and directions
- third channel containing agent/other agent malfunctions
- fourth channel containing agent/other agent fractional speeds
- fifth channel containing the number of other agents ready to depart
- obs_targets: Two 2D arrays (map_height, map_width, 2) containing respectively the position of the given agent's\
target and the positions of the other agents' targets (flag only, no counter!).
"""
def __init__(self):
super(GlobalObsForRailEnv, self).__init__()
def set_env(self, env: Environment):
super().set_env(env)
def reset(self):
self.rail_obs = np.zeros((self.env.height, self.env.width, 16))
for i in range(self.rail_obs.shape[0]):
for j in range(self.rail_obs.shape[1]):
bitlist = [int(digit) for digit in bin(self.env.rail.get_full_transitions(i, j))[2:]]
bitlist = [0] * (16 - len(bitlist)) + bitlist
self.rail_obs[i, j] = np.array(bitlist)
def get(self, handle: int = 0) -> (np.ndarray, np.ndarray, np.ndarray):
agent = self.env.agents[handle]
if agent.status == RailAgentStatus.READY_TO_DEPART:
agent_virtual_position = agent.initial_position
elif agent.status == RailAgentStatus.ACTIVE:
agent_virtual_position = agent.position
elif agent.status == RailAgentStatus.DONE:
agent_virtual_position = agent.target
else:
return None
obs_targets = np.zeros((self.env.height, self.env.width, 2))
obs_agents_state = np.zeros((self.env.height, self.env.width, 5)) - 1
# TODO can we do this more elegantly?
# for r in range(self.env.height):
# for c in range(self.env.width):
# obs_agents_state[(r, c)][4] = 0
obs_agents_state[:, :, 4] = 0
obs_agents_state[agent_virtual_position][0] = agent.direction
obs_targets[agent.target][0] = 1
for i in range(len(self.env.agents)):
other_agent: EnvAgent = self.env.agents[i]
# ignore other agents not in the grid any more
if other_agent.status == RailAgentStatus.DONE_REMOVED:
continue
obs_targets[other_agent.target][1] = 1
# second to fourth channel only if in the grid
if other_agent.position is not None:
# second channel only for other agents
if i != handle:
obs_agents_state[other_agent.position][1] = other_agent.direction
obs_agents_state[other_agent.position][2] = other_agent.malfunction_data['malfunction']
obs_agents_state[other_agent.position][3] = other_agent.speed_data['speed']
# fifth channel: all ready to depart on this position
if other_agent.status == RailAgentStatus.READY_TO_DEPART:
obs_agents_state[other_agent.initial_position][4] += 1
return self.rail_obs, obs_agents_state, obs_targets
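# Shape summary of the returned tuple (matching the docstring above):
#   rail_obs         (height, width, 16)  transition bits per cell
#   obs_agents_state (height, width, 5)   channels 0-4 as listed above
#   obs_targets      (height, width, 2)   own target flag / other targets flag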
class LocalObsForRailEnv(ObservationBuilder):
"""
!!!!!!WARNING!!! THIS IS DEPRECATED AND NOT UPDATED TO FLATLAND 2.0!!!!!
Gives a local observation of the rail environment around the agent.
The observation is composed of the following elements:
- transition map array of the local environment around the given agent, \
with dimensions (view_height, 2*view_width+1, 16), \
assuming 16 bits encoding of transitions.
- Two 2D arrays (view_height, 2*view_width+1, 2) containing respectively, \
if they are in the agent's vision range, its target position and the positions of the other agents' targets.
- A 2D array (view_height, 2*view_width+1, 4) containing the one-hot encoding of the directions \
of the other agents at their position coordinates, if they are in the agent's vision range.
- A 4-element array with the one-hot encoding of the agent's direction.
Use the parameters view_width and view_height to define the rectangular view of the agent.
The center parameter moves the agent along the height axis of this rectangle. If it is 0, the agent only has
observations in front of it.
.. deprecated:: 2.0.0
"""
def __init__(self, view_width, view_height, center):
super(LocalObsForRailEnv, self).__init__()
self.view_width = view_width
self.view_height = view_height
self.center = center
self.max_padding = max(self.view_width, self.view_height - self.center)
def reset(self):
# We build the transition map with an expansion of view_radius empty cells on each side.
# This helps to collect the local transition map view when the agent is close to a border.
self.max_padding = max(self.view_width, self.view_height)
self.rail_obs = np.zeros((self.env.height,
self.env.width, 16))
for i in range(self.env.height):
for j in range(self.env.width):
bitlist = [int(digit) for digit in bin(self.env.rail.get_full_transitions(i, j))[2:]]
bitlist = [0] * (16 - len(bitlist)) + bitlist
self.rail_obs[i, j] = np.array(bitlist)
def get(self, handle: int = 0) -> (np.ndarray, np.ndarray, np.ndarray, np.ndarray):
agents = self.env.agents
agent = agents[handle]
# Correct the agent's position for padding
# agent_rel_pos[0] = agent.position[0] + self.max_padding
# agent_rel_pos[1] = agent.position[1] + self.max_padding
# Collect visible cells as a set to be plotted
visited, rel_coords = self.field_of_view(agent.position, agent.direction)
local_rail_obs = None
# Add the visible cells to the observed cells
self.env.dev_obs_dict[handle] = set(visited)
# Locate observed agents and their corresponding targets
local_rail_obs = np.zeros((self.view_height, 2 * self.view_width + 1, 16))
obs_map_state = np.zeros((self.view_height, 2 * self.view_width + 1, 2))
obs_other_agents_state = np.zeros((self.view_height, 2 * self.view_width + 1, 4))
_idx = 0
for pos in visited:
curr_rel_coord = rel_coords[_idx]
local_rail_obs[curr_rel_coord[0], curr_rel_coord[1], :] = self.rail_obs[pos[0], pos[1], :]
if pos == agent.target:
obs_map_state[curr_rel_coord[0], curr_rel_coord[1], 0] = 1
else:
for tmp_agent in agents:
if pos == tmp_agent.target:
obs_map_state[curr_rel_coord[0], curr_rel_coord[1], 1] = 1
if pos != agent.position:
for tmp_agent in agents:
if pos == tmp_agent.position:
obs_other_agents_state[curr_rel_coord[0], curr_rel_coord[1], :] = np.identity(4)[
tmp_agent.direction]
_idx += 1
direction = np.identity(4)[agent.direction]
return local_rail_obs, obs_map_state, obs_other_agents_state, direction
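# np.identity(4)[d] is a compact one-hot encoding: for direction 1 (East)
# it yields array([0., 1., 0., 0.]).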
def get_many(self, handles: Optional[List[int]] = None) -> Dict[
int, Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]]:
"""
Called whenever an observation has to be computed for the `env` environment, for each agent with handle
in the `handles` list.
"""
return super().get_many(handles)
def field_of_view(self, position, direction, state=None):
# Compute the local field of view for an agent in the environment
data_collection = False
if state is not None:
temp_visible_data = np.zeros(shape=(self.view_height, 2 * self.view_width + 1, 16))
data_collection = True
if direction == 0:
origin = (position[0] + self.center, position[1] - self.view_width)
elif direction == 1:
origin = (position[0] - self.view_width, position[1] - self.center)
elif direction == 2:
origin = (position[0] - self.center, position[1] + self.view_width)
else:
origin = (position[0] + self.view_width, position[1] + self.center)
visible = list()
rel_coords = list()
for h in range(self.view_height):
for w in range(2 * self.view_width + 1):
if direction == 0:
if 0 <= origin[0] - h < self.env.height and 0 <= origin[1] + w < self.env.width:
visible.append((origin[0] - h, origin[1] + w))
rel_coords.append((h, w))
# if data_collection:
# temp_visible_data[h, w, :] = state[origin[0] - h, origin[1] + w, :]
elif direction == 1:
if 0 <= origin[0] + w < self.env.height and 0 <= origin[1] + h < self.env.width:
visible.append((origin[0] + w, origin[1] + h))
rel_coords.append((h, w))
# if data_collection:
# temp_visible_data[h, w, :] = state[origin[0] + w, origin[1] + h, :]
elif direction == 2:
if 0 <= origin[0] + h < self.env.height and 0 <= origin[1] - w < self.env.width:
visible.append((origin[0] + h, origin[1] - w))
rel_coords.append((h, w))
# if data_collection:
# temp_visible_data[h, w, :] = state[origin[0] + h, origin[1] - w, :]
else:
if 0 <= origin[0] - w < self.env.height and 0 <= origin[1] - h < self.env.width:
visible.append((origin[0] - w, origin[1] - h))
rel_coords.append((h, w))
# if data_collection:
# temp_visible_data[h, w, :] = state[origin[0] - w, origin[1] - h, :]
if data_collection:
return temp_visible_data
else:
return visible, rel_coords
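# Worked example with hypothetical values: an agent at position (10, 10)
# with direction 0, view_width=2, view_height=5 and center=1 gives
# origin = (11, 8), so the window spans rows 11..7 and columns 8..12,
# clipped at the grid borders by the bounds checks above.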
@@ -4,8 +4,9 @@ from multiprocessing.pool import Pool
from pathlib import Path
import numpy as np
import torch
import time
from flatland.core.env_observation_builder import DummyObservationBuilder
from flatland.evaluators.client import FlatlandRemoteClient
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
@@ -49,7 +50,7 @@ if __name__ == "__main__":
}
# Observation parameters
observation_tree_depth = 2
observation_tree_depth = 1
observation_radius = 10
observation_max_path_depth = 30
@@ -58,15 +59,14 @@ if __name__ == "__main__":
tree_observation = TreeObsForRailEnv(max_depth=observation_tree_depth, predictor=predictor)
num_features_per_node = tree_observation.observation_dim
tree_depth = 2
nr_nodes = 0
for i in range(tree_depth + 1):
for i in range(observation_tree_depth + 1):
nr_nodes += np.power(4, i)
state_size = num_features_per_node * nr_nodes
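# With observation_dim = 11 features per node, the node count
# 1 + 4 + ... + 4^depth gives:
#   depth 1:  5 nodes -> state_size = 55   (this commit)
#   depth 2: 21 nodes -> state_size = 231  (previous setting)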
action_size = 5
policy = DDDQNPolicy(state_size, action_size, Namespace(**parameters), evaluation_mode=True)
#policy.qnetwork_local = torch.load(checkpoint)
# policy.qnetwork_local = torch.load(checkpoint)
# Controller
pool = Pool()
@@ -83,10 +83,18 @@ if __name__ == "__main__":
time_start = time.time()
observation, info = remote_client.env_create(
obs_builder_object=tree_observation
obs_builder_object=DummyObservationBuilder()
)
env_creation_time = time.time() - time_start
local_env = remote_client.env
number_of_agents = len(local_env.agents)
tree_observation.set_env(local_env)
tree_observation.reset()
tree_observation.prepare_get_many(list(range(number_of_agents)))
observation = tree_observation.get_many(list(range(number_of_agents)))
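# env_create above now receives a DummyObservationBuilder, so the evaluation
# service only returns a trivial placeholder observation; the tree
# observation is rebuilt locally here, keeping the expensive computation on
# the client side where it can be timed.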
if not observation:
#
# If the remote_client returns False on an `env_create` call,
@@ -97,9 +105,6 @@ if __name__ == "__main__":
print("Evaluation Number : {}".format(evaluation_number))
local_env = remote_client.env
number_of_agents = len(local_env.agents)
# Now we enter into another infinite loop where we
# compute the actions for all the individual steps in this episode
# until the episode is `done`
@@ -117,15 +122,26 @@ if __name__ == "__main__":
time_taken_by_controller.append(agent_time)
time_start = time.time()
observation, all_rewards, done, info = remote_client.env_step(action)
_, all_rewards, done, info = remote_client.env_step(action)
steps += 1
step_time = time.time() - time_start
time_taken_per_step.append(step_time)
print("Step {}\t Agent time {:.3f}\t Step time {:.3f}".format(str(steps).zfill(3), agent_time, step_time))
time_start = time.time()
tree_observation.prepare_get_many(list(range(number_of_agents)))
prepare_time = time.time() - time_start
time_start = time.time()
observation_list = []
for h in range(number_of_agents):
observation_list.append(tree_observation.get(h))
observation = dict(zip(range(number_of_agents), observation_list))
obs_time = time.time() - time_start
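# Splitting prepare_get_many() from the per-agent get() calls lets the shared
# preparation cost and the per-agent observation cost be measured separately
# (see the commented-out print below).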
# print("Step {}\t Prepare time {:.3f}\t Obs time {:.3f}\t Inference time {:.3f}\t Step time {:.3f}".format(str(steps).zfill(3), prepare_time, obs_time, agent_time, step_time))
if check_if_all_blocked(local_env):
print("DEADLOCKED!!")
# if check_if_all_blocked(local_env):
# print("DEADLOCKED!!")
if done['__all__']:
print("Reward : ", sum(list(all_rewards.values())))