Commit 9fa72a6e authored by u214892
Browse files

observations only for active agents

parent 61b75d2f
Pipeline #2332 passed with stages
in 60 minutes and 13 seconds
......@@ -91,17 +91,3 @@ class Environment:
function.
"""
raise NotImplementedError()
def is_active_handle(self,h):
"""
Is the agent active and thus observable?
Parameters
----------
h: int agent handle
Returns
-------
"""
return True
......@@ -52,8 +52,7 @@ class ObservationBuilder:
if handles is None:
handles = []
for h in handles:
if self.env.is_active_handle(h):
observations[h] = self.get(h)
observations[h] = self.get(h)
return observations
def get(self, handle: int = 0):
......
......@@ -9,10 +9,10 @@ from flatland.core.grid.grid4 import Grid4TransitionsEnum
class RailAgentStatus(IntEnum):
READY_TO_DEPART = 0 # -> observation
ACTIVE = 1 # -> observation
DONE = 2 # -> observation
DONE_REMOVED = 3 # -> no observation
READY_TO_DEPART = 0 # not in grid yet (position is None) -> prediction as if it were at initial position
ACTIVE = 1 # in grid (position is not None), not done -> prediction is remaining path
DONE = 2 # in grid (position is not None), but done -> prediction is stay at target forever
DONE_REMOVED = 3 # removed from grid (position is None) -> prediction is None
@attrs
......
......@@ -67,12 +67,12 @@ class TreeObsForRailEnv(ObservationBuilder):
self.predicted_dir = {}
self.predictions = self.predictor.get()
if self.predictions:
# TODO hacky hacky: `range(len(self.predictions[0]))` does not seem safe!!
for t in range(len(self.predictions[0])):
pos_list = []
dir_list = []
for a in handles:
if self.env.agents[a].status != RailAgentStatus.ACTIVE:
if self.predictions[a] is None:
continue
pos_list.append(self.predictions[a][t][1:3])
dir_list.append(self.predictions[a][t][3])
......@@ -164,21 +164,41 @@ class TreeObsForRailEnv(ObservationBuilder):
"""
# Update local lookup table for all agents' positions
self.location_has_agent = {tuple(agent.position): 1 for agent in self.env.agents}
# ignore other agents not in the grid (only status active and done)
self.location_has_agent = {tuple(agent.position): 1 for agent in self.env.agents if
agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE]}
self.location_has_agent_ready_to_depart = {}
for agent in self.env.agents:
if agent.status == RailAgentStatus.READY_TO_DEPART:
self.location_has_agent_ready_to_depart = \
self.location_has_agent_ready_to_depart[tuple(agent.initial_position)] = \
self.location_has_agent_ready_to_depart.get(tuple(agent.initial_position), 0) + 1
self.location_has_agent_direction = {tuple(agent.position): agent.direction for agent in self.env.agents}
self.location_has_agent_speed = {tuple(agent.position): agent.speed_data['speed'] for agent in self.env.agents}
self.location_has_agent_malfunction = {tuple(agent.position): agent.malfunction_data['malfunction'] for agent in
self.env.agents}
self.location_has_agent_direction = {
tuple(agent.position): agent.direction
for agent in self.env.agents if agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE]
}
self.location_has_agent_speed = {
tuple(agent.position): agent.speed_data['speed']
for agent in self.env.agents if agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE]
}
self.location_has_agent_malfunction = {
tuple(agent.position): agent.malfunction_data['malfunction']
for agent in self.env.agents if agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE]
}
if handle > len(self.env.agents):
print("ERROR: obs _get - handle ", handle, " len(agents)", len(self.env.agents))
agent = self.env.agents[handle] # TODO: handle being treated as index
possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
if agent.status == RailAgentStatus.READY_TO_DEPART:
_agent_initial_position = agent.initial_position
elif agent.status == RailAgentStatus.ACTIVE:
_agent_initial_position = agent.position
elif agent.status == RailAgentStatus.DONE:
_agent_initial_position = agent.target
else:
return None
possible_transitions = self.env.rail.get_transitions(*_agent_initial_position, agent.direction)
num_transitions = np.count_nonzero(possible_transitions)
# Here information about the agent itself is stored
......@@ -187,8 +207,9 @@ class TreeObsForRailEnv(ObservationBuilder):
root_node_observation = TreeObsForRailEnv.Node(dist_own_target_encountered=0, dist_other_target_encountered=0,
dist_other_agent_encountered=0, dist_potential_conflict=0,
dist_unusable_switch=0, dist_to_next_branch=0,
dist_min_to_target=distance_map[(handle, *agent.position,
agent.direction)],
dist_min_to_target=distance_map[
(handle, *_agent_initial_position,
agent.direction)],
num_agents_same_direction=0, num_agents_opposite_direction=0,
num_agents_malfunctioning=agent.malfunction_data['malfunction'],
speed_min_fractional=agent.speed_data['speed'],
......@@ -208,7 +229,7 @@ class TreeObsForRailEnv(ObservationBuilder):
for i, branch_direction in enumerate([(orientation + i) % 4 for i in range(-1, 3)]):
if possible_transitions[branch_direction]:
new_cell = get_new_position(agent.position, branch_direction)
new_cell = get_new_position(_agent_initial_position, branch_direction)
branch_observation, branch_visited = \
self._explore_branch(handle, new_cell, branch_direction, 1, 1)
......@@ -534,15 +555,27 @@ class GlobalObsForRailEnv(ObservationBuilder):
def get(self, handle: int = 0) -> (np.ndarray, np.ndarray, np.ndarray):
agent = self.env.agents[handle]
if agent.status == RailAgentStatus.READY_TO_DEPART:
_agent_initial_position = agent.initial_position
elif agent.status == RailAgentStatus.ACTIVE:
_agent_initial_position = agent.position
elif agent.status == RailAgentStatus.DONE:
_agent_initial_position = agent.target
else:
return None
obs_targets = np.zeros((self.env.height, self.env.width, 2))
obs_agents_state = np.zeros((self.env.height, self.env.width, 5)) - 1
agent = self.env.agents[handle]
obs_agents_state[agent.position][0] = agent.direction
obs_agents_state[_agent_initial_position][0] = agent.direction
obs_targets[agent.target][0] = 1
for i in range(len(self.env.agents)):
other_agent: EnvAgent = self.env.agents[i]
# ignore other_agent if it is not in the grid
if other_agent.position is None:
continue
if i != handle:
obs_agents_state[other_agent.position][1] = other_agent.direction
obs_targets[other_agent.target][1] = 1
......
......@@ -126,9 +126,17 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
prediction_dict = {}
for agent in agents:
if agent.status != RailAgentStatus.ACTIVE:
if agent.status == RailAgentStatus.READY_TO_DEPART:
_agent_initial_position = agent.initial_position
elif agent.status == RailAgentStatus.ACTIVE:
_agent_initial_position = agent.position
elif agent.status == RailAgentStatus.DONE:
_agent_initial_position = agent.target
else:
prediction_dict[agent.handle] = None
continue
_agent_initial_position = agent.position
_agent_initial_direction = agent.direction
agent_speed = agent.speed_data["speed"]
times_per_cell = int(np.reciprocal(agent_speed))
......
......@@ -525,9 +525,6 @@ class RailEnv(Environment):
# step penalty if not moving (stopped now or before)
self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
def is_active_handle(self, h):
return self.agents[h].status == RailAgentStatus.ACTIVE
def _check_action_on_agent(self, action: RailEnvActions, agent: EnvAgent):
"""
......
......@@ -98,12 +98,10 @@ def get_shortest_paths(distance_map: DistanceMap, max_depth: Optional[int] = Non
elif agent.status == RailAgentStatus.ACTIVE:
position = agent.position
elif agent.status == RailAgentStatus.DONE:
if agent.position is not None:
position = agent.target
else:
shortest_paths[agent.handle] = None
return
# todo is this correct? current position?
position = agent.target
else:
shortest_paths[agent.handle] = None
return
direction = agent.direction
shortest_paths[agent.handle] = []
distance = math.inf
......
......@@ -146,6 +146,9 @@ class RenderTool(object):
Plot a simple agent.
Assumes a working graphics layer context (cf a MPL figure).
"""
if position_row_col is None:
return
rt = self.__class__
direction_row_col = rt.transitions_row_col[direction] # agent direction in RC
......@@ -535,7 +538,7 @@ class RenderTool(object):
for agent_idx, agent in enumerate(self.env.agents):
if agent is None:
if agent is None or agent.position is None:
continue
if self.agent_render_variant == AgentRenderVariant.BOX_ONLY:
......
......@@ -30,7 +30,16 @@ class SingleAgentNavigationObs(ObservationBuilder):
def get(self, handle: int = 0) -> List[int]:
agent = self.env.agents[handle]
possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
if agent.status == RailAgentStatus.READY_TO_DEPART:
_agent_initial_position = agent.initial_position
elif agent.status == RailAgentStatus.ACTIVE:
_agent_initial_position = agent.position
elif agent.status == RailAgentStatus.DONE:
_agent_initial_position = agent.target
else:
return None
possible_transitions = self.env.rail.get_transitions(*_agent_initial_position, agent.direction)
num_transitions = np.count_nonzero(possible_transitions)
# Start from the current orientation, and see which transitions are available;
......@@ -42,7 +51,7 @@ class SingleAgentNavigationObs(ObservationBuilder):
min_distances = []
for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
if possible_transitions[direction]:
new_position = get_new_position(agent.position, direction)
new_position = get_new_position(_agent_initial_position, direction)
min_distances.append(
self.env.distance_map.get()[handle, new_position[0], new_position[1], direction])
else:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment