Commit 7effaf44 authored by nilabha's avatar nilabha

Added Typing and some documentation

parent 7ae8390f
from typing import Optional, List, Dict
from typing import Optional, List, Dict, Union, Tuple
import gym
import numpy as np
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.utils.ordered_set import OrderedSet
from envs.flatland.observations import Observation, register_obs # noqa
from itertools import combinations
......@@ -14,27 +15,6 @@ from flatland.core.env_prediction_builder import PredictionBuilder
from flatland.envs.agent_utils import RailAgentStatus
from flatland.core.grid.grid4_utils import get_new_position
# from flatland.envs.rail_env import action_required
def action_required(agent):
"""
Check if an agent needs to provide an action
Parameters
----------
agent: RailEnvAgent
Agent we want to check
Returns
-------
True: Agent needs to provide an action
False: Agent cannot provide an action
"""
return (agent.status == RailAgentStatus.READY_TO_DEPART or (
agent.status == RailAgentStatus.ACTIVE and
np.isclose(agent.speed_data['position_fraction'], 0.0,
rtol=1e-03)))
@register_obs("localConflict")
......@@ -46,7 +26,8 @@ class LocalConflictObservation(Observation):
LocalConflictObsForRailEnv(
max_depth=config['max_depth'],
predictor=ShortestPathPredictorForRailEnv(
config['shortest_path_max_depth']))
config['shortest_path_max_depth']),
n_local=config['n_local'])
)
def builder(self) -> ObservationBuilder:
......@@ -59,10 +40,21 @@ class LocalConflictObservation(Observation):
class LocalConflictObsForRailEnvRLLibWrapper(ObservationBuilder):
"""
The information is for each agent but uses the full set of
observations for all agents to come up with set of local
(Default: 5) most conflicting agents.
The observation set is based on the current agent and these local
identified agents. We also information about conflicts.
"""
def __init__(self, local_conflict_obs_builder: TreeObsForRailEnv):
super().__init__()
self._builder = local_conflict_obs_builder
self.agent_states = None
# To cache calculated agent states
# This is only computed once and reused for all other agents
self.agent_states: Optional[Dict] = None
@property
def observation_dim(self):
......@@ -87,24 +79,16 @@ class LocalConflictObsForRailEnvRLLibWrapper(ObservationBuilder):
def get_many(self, handles: Optional[List[int]] = None):
all_agent_observations = self._builder.get_many(handles)
o = dict()
obs = dict()
if handles is None:
handles = []
for k in handles:
if not self.agent_states:
self.agent_states = create_agent_states(
all_agent_observations, self._builder.predictor.max_depth)
o[k] = self.agent_states[k]
return o
# return {k: create_agent_states(o, self._builder.max_depth)
# for k, o in self._builder.get_many(handles).items()
# if o is not None}
obs[k] = self.agent_states[k]
# def util_print_obs_subtree(self, tree):
# self._builder.util_print_obs_subtree(tree)
# def print_subtree(self, node, label, indent):
# self._builder.print_subtree(node, label, indent)
return obs
def set_env(self, env):
self._builder.set_env(env)
......@@ -112,7 +96,14 @@ class LocalConflictObsForRailEnvRLLibWrapper(ObservationBuilder):
class LocalConflictObsForRailEnv(TreeObsForRailEnv):
"""
LocalConflict object made from TreeObsForRailEnv object.
This object returns observation vectors for agents in the RailEnv.
For details about the features in the observation
see the get() function.
We normalise all observations based on the grid size
"""
Node = collections.namedtuple('Node', 'distance_target '
'observation_shortest '
......@@ -136,7 +127,7 @@ class LocalConflictObsForRailEnv(TreeObsForRailEnv):
'predicted_pos')
def __init__(self, max_depth: int, predictor: PredictionBuilder = None,
n_local=5):
n_local: int = 5):
super().__init__(max_depth, predictor)
self.observation_dim = 1 + 3 * (n_local - 1) + 22 * n_local
......@@ -145,14 +136,7 @@ class LocalConflictObsForRailEnv(TreeObsForRailEnv):
def get_many(self, handles: Optional[List[int]] = None):
# observations = {}
# if handles is None:
# handles = []
# for h in handles:
# observations[h] = self.get(h)
# return observations
observations = super().get_many(handles)
# observations = list(observations.values())
return observations
def get(self, handle: int = 0):
......@@ -186,7 +170,7 @@ class LocalConflictObsForRailEnv(TreeObsForRailEnv):
max_distance = self.env.width + self.env.height
# max_steps = int(4 * 2 * (20 + self.env.height + self.env.width))
visited = set()
visited = OrderedSet()
for _idx in range(10):
# Check if any of the other prediction overlap
# with agents own predictions
......@@ -200,6 +184,8 @@ class LocalConflictObsForRailEnv(TreeObsForRailEnv):
# visualize the observation
self.env.dev_obs_dict[handle] = visited
# min_distance stores the distance to target in shortest path
# and any alternate path if exists
min_distances = []
for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
if possible_transitions[direction]:
......@@ -266,16 +252,18 @@ class LocalConflictObsForRailEnv(TreeObsForRailEnv):
return self.env.get_num_agents()
def create_agent_states(obs,
max_depth: int, n_local: int = 5):
def create_agent_states(obs: Union[Dict, List],
max_depth: int, n_local: int = 5) -> Dict:
"""
Identifies local agent conflicts and adds information from
conflict prediction matrix.
"""
n_agents = len(obs)
x_dim = 0
y_dim = 0
print(" N Agents:", n_agents)
for i in range(n_agents):
if obs[i] is not None:
custom_observations = obs[i]
# n_agents, x_dim, y_dim = custom_observations.n_agents,
x_dim = custom_observations.width
y_dim = custom_observations.height
break
......@@ -315,22 +303,21 @@ def create_agent_states(obs,
info_action_required[i] = int(custom_observations.action_required)
predicted_pos = custom_observations.predicted_pos
agent_conflicts_count_path, agent_conflicts_step_path, agent_total_step_conflicts = get_agent_conflict_prediction_matrix(
n_agents, max_depth, predicted_pos)
agent_conflicts_count_path, agent_conflicts_step_path,\
agent_total_step_conflicts = get_agent_conflict_prediction_matrix(
n_agents, max_depth, predicted_pos)
# Normalise based on average grid dimensions
avg_dim = (x_dim * y_dim) ** 0.5
depth = int(n_local * avg_dim / n_agents)
agent_conflict_steps = min(max_depth - 1, depth)
agent_conflicts = agent_conflicts_step_path[agent_conflict_steps]
# agent_counts = agent_conflicts_count_path[agent_conflict_steps]
agent_conflicts_avg_step_count = np.average(
agent_total_step_conflicts) / n_agents
for i in range(n_agents):
# if obs is None or obs[i] is None:
# # action_dict.update({i: 2})
if obs[i] is not None:
n_upd_local = min(n_local, n_agents - 1)
if n_upd_local < n_local:
......@@ -393,7 +380,8 @@ def create_agent_states(obs,
return local_agent_states_all
def get_agent_conflict_prediction_matrix(n_agents, max_depth, predicted_pos):
def get_agent_conflict_prediction_matrix(n_agents, max_depth, predicted_pos
) -> Tuple[List, List, List]:
agent_total_step_conflicts = []
agent_conflicts_step_path = []
agent_conflicts_count_path = []
......@@ -439,4 +427,25 @@ def get_agent_conflict_prediction_matrix(n_agents, max_depth, predicted_pos):
agent_total_step_conflicts.append(
sum(agent_conflicts_step_current[i, :]))
return agent_conflicts_count_path, agent_conflicts_step_path, agent_total_step_conflicts
return agent_conflicts_count_path, agent_conflicts_step_path,\
agent_total_step_conflicts
def action_required(agent):
"""
Check if an agent needs to provide an action
Parameters
----------
agent: RailEnvAgent
Agent we want to check
Returns
-------
True: Agent needs to provide an action
False: Agent cannot provide an action
"""
return (agent.status == RailAgentStatus.READY_TO_DEPART or (
agent.status == RailAgentStatus.ACTIVE and
np.isclose(agent.speed_data['position_fraction'], 0.0,
rtol=1e-03)))
......@@ -37,6 +37,7 @@ flatland-random-sparse-small-local-conflict-fc-ppo:
observation_config:
max_depth: 2
shortest_path_max_depth: 30
n_local: 5
regenerate_rail_on_reset: True
regenerate_schedule_on_reset: True
render: False
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment