Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • flatland/flatland
  • stefan_otte/flatland
  • jiaodaxiaozi/flatland
  • sfwatergit/flatland
  • utozx126/flatland
  • ChenKuanSun/flatland
  • ashivani/flatland
  • minhhoa/flatland
  • pranjal_dhole/flatland
  • darthgera123/flatland
  • rivesunder/flatland
  • thomaslecat/flatland
  • joel_joseph/flatland
  • kchour/flatland
  • alex_zharichenko/flatland
  • yoogottamk/flatland
  • troye_fang/flatland
  • elrichgro/flatland
  • jun_jin/flatland
  • nimishsantosh107/flatland
20 results
Show changes
Showing
with 2033 additions and 1685 deletions
from typing import Tuple
# Adrian Egli / Michel Marti performance fix (the fast methods brings more than 50%)
def fast_isclose(a, b, rtol):
return (a < (b + rtol)) or (a < (b - rtol))
def fast_clip(position: Tuple[int, int], min_value: Tuple[int, int], max_value: Tuple[int, int]) -> bool:
return (
max(min_value[0], min(position[0], max_value[0])),
max(min_value[1], min(position[1], max_value[1]))
)
def fast_argmax(possible_transitions: (int, int, int, int)) -> bool:
if possible_transitions[0] == 1:
return 0
if possible_transitions[1] == 1:
return 1
if possible_transitions[2] == 1:
return 2
return 3
def fast_position_equal(pos_1: (int, int), pos_2: (int, int)) -> bool:
if pos_1 is None and pos_2 is None:
return True
if pos_1 is None or pos_2 is None:
return False
return pos_1[0] == pos_2[0] and pos_1[1] == pos_2[1]
def fast_count_nonzero(possible_transitions: (int, int, int, int)):
return possible_transitions[0] + possible_transitions[1] + possible_transitions[2] + possible_transitions[3]
def fast_delete(lis: list, index) -> list:
new_list = lis.copy()
new_list.pop(index)
return new_list
def fast_where(binary_iterable):
return [index for index, element in enumerate(binary_iterable) if element != 0]
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import rail_from_file
from flatland.envs.schedule_generators import schedule_from_file
def load_flatland_environment_from_file(file_name: str,
load_from_package: str = None,
obs_builder_object: ObservationBuilder = None) -> RailEnv:
"""
Parameters
----------
file_name : str
The pickle file.
load_from_package : str
The python module to import from. Example: 'env_data.tests'
This requires that there are `__init__.py` files in the folder structure we load the file from.
obs_builder_object: ObservationBuilder
The obs builder for the `RailEnv` that is created.
Returns
-------
RailEnv
The environment loaded from the pickle file.
"""
if obs_builder_object is None:
obs_builder_object = TreeObsForRailEnv(
max_depth=2,
predictor=ShortestPathPredictorForRailEnv(max_depth=10))
environment = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name, load_from_package),
schedule_generator=schedule_from_file(file_name, load_from_package), number_of_agents=1,
obs_builder_object=obs_builder_object)
return environment
......@@ -160,6 +160,7 @@ def fix_inner_nodes(grid_map: GridTransitionMap, inner_node_pos: IntVector2D, ra
grid_map.grid[tmp_pos] = transition
return
def align_cell_to_city(city_center, city_orientation, cell):
"""
Alig all cells to face the city center along the city orientation
......@@ -171,4 +172,4 @@ def align_cell_to_city(city_center, city_orientation, cell):
if city_orientation % 2 == 0:
return int(2 * np.clip(cell[0] - city_center[0], 0, 1))
else:
return int(2 * np.clip(city_center[1] - cell[1], 0, 1)) + 1
return int(2 * np.clip(city_center[1] - cell[1], 0, 1)) + 1
"""Line generators (railway undertaking, "EVU")."""
import warnings
from typing import Tuple, List, Callable, Mapping, Optional, Any
import numpy as np
from numpy.random.mtrand import RandomState
from flatland.core.grid.grid4_utils import get_new_position
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.timetable_utils import Line
from flatland.envs import persistence
AgentPosition = Tuple[int, int]
LineGenerator = Callable[[GridTransitionMap, int, Optional[Any], Optional[int]], Line]
def speed_initialization_helper(nb_agents: int, speed_ratio_map: Mapping[float, float] = None,
seed: int = None, np_random: RandomState = None) -> List[float]:
"""
Parameters
----------
nb_agents : int
The number of agents to generate a speed for
speed_ratio_map : Mapping[float,float]
A map of speeds mappint to their ratio of appearance. The ratios must sum up to 1.
Returns
-------
List[float]
A list of size nb_agents of speeds with the corresponding probabilistic ratios.
"""
if speed_ratio_map is None:
return [1.0] * nb_agents
nb_classes = len(speed_ratio_map.keys())
speed_ratio_map_as_list: List[Tuple[float, float]] = list(speed_ratio_map.items())
speed_ratios = list(map(lambda t: t[1], speed_ratio_map_as_list))
speeds = list(map(lambda t: t[0], speed_ratio_map_as_list))
return list(map(lambda index: speeds[index], np_random.choice(nb_classes, nb_agents, p=speed_ratios)))
class BaseLineGen(object):
def __init__(self, speed_ratio_map: Mapping[float, float] = None, seed: int = 1):
self.speed_ratio_map = speed_ratio_map
self.seed = seed
def generate(self, rail: GridTransitionMap, num_agents: int, hints: Any=None, num_resets: int = 0,
np_random: RandomState = None) -> Line:
pass
def __call__(self, *args, **kwargs):
return self.generate(*args, **kwargs)
def sparse_line_generator(speed_ratio_map: Mapping[float, float] = None, seed: int = 1) -> LineGenerator:
return SparseLineGen(speed_ratio_map, seed)
class SparseLineGen(BaseLineGen):
"""
This is the line generator which is used for Round 2 of the Flatland challenge. It produces lines
to railway networks provided by sparse_rail_generator.
:param speed_ratio_map: Speed ratios of all agents. They are probabilities of all different speeds and have to
add up to 1.
:param seed: Initiate random seed generator
"""
def decide_orientation(self, rail, start, target, possible_orientations, np_random: RandomState) -> int:
feasible_orientations = []
for orientation in possible_orientations:
if rail.check_path_exists(start[0], orientation, target[0]):
feasible_orientations.append(orientation)
if len(feasible_orientations) > 0:
return np_random.choice(feasible_orientations)
else:
return 0
def generate(self, rail: GridTransitionMap, num_agents: int, hints: dict, num_resets: int,
np_random: RandomState) -> Line:
"""
The generator that assigns tasks to all the agents
:param rail: Rail infrastructure given by the rail_generator
:param num_agents: Number of agents to include in the line
:param hints: Hints provided by the rail_generator These include positions of start/target positions
:param num_resets: How often the generator has been reset.
:return: Returns the generator to the rail constructor
"""
_runtime_seed = self.seed + num_resets
train_stations = hints['train_stations']
city_positions = hints['city_positions']
city_orientation = hints['city_orientations']
# Place agents and targets within available train stations
agents_position = []
agents_target = []
agents_direction = []
city1, city2 = None, None
city1_num_stations, city2_num_stations = None, None
city1_possible_orientations, city2_possible_orientations = None, None
for agent_idx in range(num_agents):
if (agent_idx % 2 == 0):
# Setlect 2 cities, find their num_stations and possible orientations
city_idx = np_random.choice(len(city_positions), 2, replace=False)
city1 = city_idx[0]
city2 = city_idx[1]
city1_num_stations = len(train_stations[city1])
city2_num_stations = len(train_stations[city2])
city1_possible_orientations = [city_orientation[city1],
(city_orientation[city1] + 2) % 4]
city2_possible_orientations = [city_orientation[city2],
(city_orientation[city2] + 2) % 4]
# Agent 1 : city1 > city2, Agent 2: city2 > city1
agent_start_idx = ((2 * np_random.randint(0, 10))) % city1_num_stations
agent_target_idx = ((2 * np_random.randint(0, 10)) + 1) % city2_num_stations
agent_start = train_stations[city1][agent_start_idx]
agent_target = train_stations[city2][agent_target_idx]
agent_orientation = self.decide_orientation(
rail, agent_start, agent_target, city1_possible_orientations, np_random)
else:
agent_start_idx = ((2 * np_random.randint(0, 10))) % city2_num_stations
agent_target_idx = ((2 * np_random.randint(0, 10)) + 1) % city1_num_stations
agent_start = train_stations[city2][agent_start_idx]
agent_target = train_stations[city1][agent_target_idx]
agent_orientation = self.decide_orientation(
rail, agent_start, agent_target, city2_possible_orientations, np_random)
# agent1 details
agents_position.append((agent_start[0][0], agent_start[0][1]))
agents_target.append((agent_target[0][0], agent_target[0][1]))
agents_direction.append(agent_orientation)
if self.speed_ratio_map:
speeds = speed_initialization_helper(num_agents, self.speed_ratio_map, seed=_runtime_seed, np_random=np_random)
else:
speeds = [1.0] * len(agents_position)
# We add multiply factors to the max number of time steps to simplify task in Flatland challenge.
# These factors might change in the future.
timedelay_factor = 4
alpha = 2
max_episode_steps = int(
timedelay_factor * alpha * (rail.width + rail.height + num_agents / len(city_positions)))
return Line(agent_positions=agents_position, agent_directions=agents_direction,
agent_targets=agents_target, agent_speeds=speeds)
def line_from_file(filename, load_from_package=None) -> LineGenerator:
"""
Utility to load pickle file
Parameters
----------
input_file : Pickle file generated by env.save() or editor
Returns
-------
Tuple[List[Tuple[int,int]], List[Tuple[int,int]], List[Tuple[int,int]], List[float]]
initial positions, directions, targets speeds
"""
def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None, num_resets: int = 0,
np_random: RandomState = None) -> Line:
env_dict = persistence.RailEnvPersister.load_env_dict(filename, load_from_package=load_from_package)
max_episode_steps = env_dict.get("max_episode_steps", 0)
if (max_episode_steps==0):
print("This env file has no max_episode_steps (deprecated) - setting to 100")
max_episode_steps = 100
agents = env_dict["agents"]
# setup with loaded data
agents_position = [a.initial_position for a in agents]
# this logic is wrong - we should really load the initial_direction as the direction.
#agents_direction = [a.direction for a in agents]
agents_direction = [a.initial_direction for a in agents]
agents_target = [a.target for a in agents]
agents_speed = [a.speed_counter.speed for a in agents]
return Line(agent_positions=agents_position, agent_directions=agents_direction,
agent_targets=agents_target, agent_speeds=agents_speed)
return generator
"""Malfunction generators for rail systems"""
from typing import Callable, NamedTuple, Optional, Tuple
import numpy as np
from numpy.random.mtrand import RandomState
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.step_utils.states import TrainState
from flatland.envs import persistence
# why do we have both MalfunctionParameters and MalfunctionProcessData - they are both the same!
MalfunctionParameters = NamedTuple('MalfunctionParameters',
[('malfunction_rate', float), ('min_duration', int), ('max_duration', int)])
MalfunctionProcessData = NamedTuple('MalfunctionProcessData',
[('malfunction_rate', float), ('min_duration', int), ('max_duration', int)])
Malfunction = NamedTuple('Malfunction', [('num_broken_steps', int)])
# Why is the return value Optional? We always return a Malfunction.
MalfunctionGenerator = Callable[[RandomState, bool], Malfunction]
def _malfunction_prob(rate: float) -> float:
"""
Probability of a single agent to break. According to Poisson process with given rate
:param rate:
:return:
"""
if rate <= 0:
return 0.
else:
return 1 - np.exp(-rate)
class ParamMalfunctionGen(object):
""" Preserving old behaviour of using MalfunctionParameters for constructor,
but returning MalfunctionProcessData in get_process_data.
Data structure and content is the same.
"""
def __init__(self, parameters: MalfunctionParameters):
#self.mean_malfunction_rate = parameters.malfunction_rate
#self.min_number_of_steps_broken = parameters.min_duration
#self.max_number_of_steps_broken = parameters.max_duration
self.MFP = parameters
def generate(self, np_random: RandomState) -> Malfunction:
if np_random.rand() < _malfunction_prob(self.MFP.malfunction_rate):
num_broken_steps = np_random.randint(self.MFP.min_duration,
self.MFP.max_duration + 1) + 1
else:
num_broken_steps = 0
return Malfunction(num_broken_steps)
def get_process_data(self):
return MalfunctionProcessData(*self.MFP)
class NoMalfunctionGen(ParamMalfunctionGen):
def __init__(self):
super().__init__(MalfunctionParameters(0,0,0))
class FileMalfunctionGen(ParamMalfunctionGen):
def __init__(self, env_dict=None, filename=None, load_from_package=None):
""" uses env_dict if populated, otherwise tries to load from file / package.
"""
if env_dict is None:
env_dict = persistence.RailEnvPersister.load_env_dict(filename, load_from_package=load_from_package)
if env_dict.get('malfunction') is not None:
oMFP = MalfunctionParameters(*env_dict["malfunction"])
else:
oMFP = MalfunctionParameters(0,0,0) # no malfunctions
super().__init__(oMFP)
################################################################################################
# OLD / DEPRECATED generator functions below. To be removed.
def no_malfunction_generator() -> Tuple[MalfunctionGenerator, MalfunctionProcessData]:
"""
Malfunction generator which generates no malfunctions
Parameters
----------
Nothing
Returns
-------
generator, Tuple[float, int, int] with mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken
"""
print("DEPRECATED - use NoMalfunctionGen instead of no_malfunction_generator")
# Mean malfunction in number of time steps
mean_malfunction_rate = 0.
# Uniform distribution parameters for malfunction duration
min_number_of_steps_broken = 0
max_number_of_steps_broken = 0
def generator(np_random: RandomState = None) -> Malfunction:
return Malfunction(0)
return generator, MalfunctionProcessData(mean_malfunction_rate, min_number_of_steps_broken,
max_number_of_steps_broken)
def single_malfunction_generator(earlierst_malfunction: int, malfunction_duration: int) -> Tuple[
MalfunctionGenerator, MalfunctionProcessData]:
"""
Malfunction generator which guarantees exactly one malfunction during an episode of an ACTIVE agent.
Parameters
----------
earlierst_malfunction: Earliest possible malfunction onset
malfunction_duration: The duration of the single malfunction
Returns
-------
generator, Tuple[float, int, int] with mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken
"""
# Mean malfunction in number of time steps
mean_malfunction_rate = 0.
# Uniform distribution parameters for malfunction duration
min_number_of_steps_broken = 0
max_number_of_steps_broken = 0
# Keep track of the total number of malfunctions in the env
global_nr_malfunctions = 0
# Malfunction calls per agent
malfunction_calls = dict()
def generator(agent: EnvAgent = None, np_random: RandomState = None, reset=False) -> Optional[Malfunction]:
# We use the global variable to assure only a single malfunction in the env
nonlocal global_nr_malfunctions
nonlocal malfunction_calls
# Reset malfunciton generator
if reset:
nonlocal global_nr_malfunctions
nonlocal malfunction_calls
global_nr_malfunctions = 0
malfunction_calls = dict()
return Malfunction(0)
# No more malfunctions if we already had one, ignore all updates
if global_nr_malfunctions > 0:
return Malfunction(0)
# Update number of calls per agent
if agent.handle in malfunction_calls:
malfunction_calls[agent.handle] += 1
else:
malfunction_calls[agent.handle] = 1
# Break an agent that is active at the time of the malfunction
if (agent.state == TrainState.MOVING or agent.state == TrainState.STOPPED) \
and malfunction_calls[agent.handle] >= earlierst_malfunction: #TODO : Dipam : Is this needed?
global_nr_malfunctions += 1
return Malfunction(malfunction_duration)
else:
return Malfunction(0)
return generator, MalfunctionProcessData(mean_malfunction_rate, min_number_of_steps_broken,
max_number_of_steps_broken)
def malfunction_from_file(filename: str, load_from_package=None) -> Tuple[MalfunctionGenerator, MalfunctionProcessData]:
"""
Utility to load pickle file
Parameters
----------
input_file : Pickle file generated by env.save() or editor
Returns
-------
generator, Tuple[float, int, int] with mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken
"""
print("DEPRECATED - use FileMalfunctionGen instead of malfunction_from_file")
env_dict = persistence.RailEnvPersister.load_env_dict(filename, load_from_package=load_from_package)
# TODO: make this better by using namedtuple in the pickle file. See issue 282
if env_dict.get('malfunction') is not None:
env_dict['malfunction'] = oMPD = MalfunctionProcessData._make(env_dict['malfunction'])
else:
oMPD = None
if oMPD is not None:
# Mean malfunction in number of time steps
mean_malfunction_rate = oMPD.malfunction_rate
# Uniform distribution parameters for malfunction duration
min_number_of_steps_broken = oMPD.min_duration
max_number_of_steps_broken = oMPD.max_duration
else:
# Mean malfunction in number of time steps
mean_malfunction_rate = 0.
# Uniform distribution parameters for malfunction duration
min_number_of_steps_broken = 0
max_number_of_steps_broken = 0
def generator(agent: EnvAgent = None, np_random: RandomState = None, reset=False) -> Optional[Malfunction]:
"""
Generate malfunctions for agents
Parameters
----------
agent
np_random
Returns
-------
int: Number of time steps an agent is broken
"""
# Dummy reset function as we don't implement specific seeding here
if reset:
return Malfunction(0)
if agent.malfunction_handler.malfunction_down_counter < 1:
if np_random.rand() < _malfunction_prob(mean_malfunction_rate):
num_broken_steps = np_random.randint(min_number_of_steps_broken,
max_number_of_steps_broken + 1) + 1
return Malfunction(num_broken_steps)
return Malfunction(0)
return generator, MalfunctionProcessData(mean_malfunction_rate, min_number_of_steps_broken,
max_number_of_steps_broken)
def malfunction_from_params(parameters: MalfunctionParameters) -> Tuple[MalfunctionGenerator, MalfunctionProcessData]:
"""
Utility to load malfunction from parameters
Parameters
----------
parameters : contains all the parameters of the malfunction
malfunction_rate : float rate per timestep at which each agent malfunctions
min_duration : int minimal duration of a failure
max_number_of_steps_broken : int maximal duration of a failure
Returns
-------
generator, Tuple[float, int, int] with mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken
"""
print("DEPRECATED - use ParamMalfunctionGen instead of malfunction_from_params")
mean_malfunction_rate = parameters.malfunction_rate
min_number_of_steps_broken = parameters.min_duration
max_number_of_steps_broken = parameters.max_duration
def generator(np_random: RandomState = None, reset=False) -> Optional[Malfunction]:
"""
Generate malfunctions for agents
Parameters
----------
agent
np_random
Returns
-------
int: Number of time steps an agent is broken
"""
# Dummy reset function as we don't implement specific seeding here
if reset:
return Malfunction(0)
if np_random.rand() < _malfunction_prob(mean_malfunction_rate):
num_broken_steps = np_random.randint(min_number_of_steps_broken,
max_number_of_steps_broken + 1)
return Malfunction(num_broken_steps)
return Malfunction(0)
return generator, MalfunctionProcessData(mean_malfunction_rate, min_number_of_steps_broken,
max_number_of_steps_broken)
......@@ -11,9 +11,25 @@ from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.env_prediction_builder import PredictionBuilder
from flatland.core.grid.grid4_utils import get_new_position
from flatland.core.grid.grid_utils import coordinate_to_position
from flatland.envs.agent_utils import RailAgentStatus, EnvAgent
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.fast_methods import fast_argmax, fast_count_nonzero, fast_position_equal, fast_delete, fast_where
from flatland.envs.step_utils.states import TrainState
from flatland.utils.ordered_set import OrderedSet
Node = collections.namedtuple('Node', 'dist_own_target_encountered '
'dist_other_target_encountered '
'dist_other_agent_encountered '
'dist_potential_conflict '
'dist_unusable_switch '
'dist_to_next_branch '
'dist_min_to_target '
'num_agents_same_direction '
'num_agents_opposite_direction '
'num_agents_malfunctioning '
'speed_min_fractional '
'num_agents_ready_to_depart '
'childs')
class TreeObsForRailEnv(ObservationBuilder):
"""
......@@ -25,19 +41,6 @@ class TreeObsForRailEnv(ObservationBuilder):
For details about the features in the tree observation see the get() function.
"""
Node = collections.namedtuple('Node', 'dist_own_target_encountered '
'dist_other_target_encountered '
'dist_other_agent_encountered '
'dist_potential_conflict '
'dist_unusable_switch '
'dist_to_next_branch '
'dist_min_to_target '
'num_agents_same_direction '
'num_agents_opposite_direction '
'num_agents_malfunctioning '
'speed_min_fractional '
'num_agents_ready_to_depart '
'childs')
tree_explored_actions_char = ['L', 'F', 'R', 'B']
......@@ -90,18 +93,21 @@ class TreeObsForRailEnv(ObservationBuilder):
self.location_has_agent_ready_to_depart = {}
for _agent in self.env.agents:
if _agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE] and \
if not _agent.state.is_off_map_state() and \
_agent.position:
self.location_has_agent[tuple(_agent.position)] = 1
self.location_has_agent_direction[tuple(_agent.position)] = _agent.direction
self.location_has_agent_speed[tuple(_agent.position)] = _agent.speed_data['speed']
self.location_has_agent_malfunction[tuple(_agent.position)] = _agent.malfunction_data[
'malfunction']
self.location_has_agent_speed[tuple(_agent.position)] = _agent.speed_counter.speed
self.location_has_agent_malfunction[tuple(_agent.position)] = \
_agent.malfunction_handler.malfunction_down_counter
if _agent.status in [RailAgentStatus.READY_TO_DEPART] and \
# [NIMISH] WHAT IS THIS
if _agent.state.is_off_map_state() and \
_agent.initial_position:
self.location_has_agent_ready_to_depart[tuple(_agent.initial_position)] = \
self.location_has_agent_ready_to_depart.get(tuple(_agent.initial_position), 0) + 1
self.location_has_agent_ready_to_depart.setdefault(tuple(_agent.initial_position), 0)
self.location_has_agent_ready_to_depart[tuple(_agent.initial_position)] += 1
# self.location_has_agent_ready_to_depart[tuple(_agent.initial_position)] = \
# self.location_has_agent_ready_to_depart.get(tuple(_agent.initial_position), 0) + 1
observations = super().get_many(handles)
......@@ -190,32 +196,34 @@ class TreeObsForRailEnv(ObservationBuilder):
print("ERROR: obs _get - handle ", handle, " len(agents)", len(self.env.agents))
agent = self.env.agents[handle] # TODO: handle being treated as index
if agent.status == RailAgentStatus.READY_TO_DEPART:
if agent.state.is_off_map_state():
agent_virtual_position = agent.initial_position
elif agent.status == RailAgentStatus.ACTIVE:
elif agent.state.is_on_map_state():
agent_virtual_position = agent.position
elif agent.status == RailAgentStatus.DONE:
elif agent.state == TrainState.DONE:
agent_virtual_position = agent.target
else:
return None
possible_transitions = self.env.rail.get_transitions(*agent_virtual_position, agent.direction)
num_transitions = np.count_nonzero(possible_transitions)
num_transitions = fast_count_nonzero(possible_transitions)
# Here information about the agent itself is stored
distance_map = self.env.distance_map.get()
root_node_observation = TreeObsForRailEnv.Node(dist_own_target_encountered=0, dist_other_target_encountered=0,
dist_other_agent_encountered=0, dist_potential_conflict=0,
dist_unusable_switch=0, dist_to_next_branch=0,
dist_min_to_target=distance_map[
(handle, *agent_virtual_position,
agent.direction)],
num_agents_same_direction=0, num_agents_opposite_direction=0,
num_agents_malfunctioning=agent.malfunction_data['malfunction'],
speed_min_fractional=agent.speed_data['speed'],
num_agents_ready_to_depart=0,
childs={})
# was referring to TreeObsForRailEnv.Node
root_node_observation = Node(dist_own_target_encountered=0, dist_other_target_encountered=0,
dist_other_agent_encountered=0, dist_potential_conflict=0,
dist_unusable_switch=0, dist_to_next_branch=0,
dist_min_to_target=distance_map[
(handle, *agent_virtual_position,
agent.direction)],
num_agents_same_direction=0, num_agents_opposite_direction=0,
num_agents_malfunctioning=agent.malfunction_handler.malfunction_down_counter,
speed_min_fractional=agent.speed_counter.speed,
num_agents_ready_to_depart=0,
childs={})
# print("root node type:", type(root_node_observation))
visited = OrderedSet()
......@@ -225,7 +233,7 @@ class TreeObsForRailEnv(ObservationBuilder):
orientation = agent.direction
if num_transitions == 1:
orientation = np.argmax(possible_transitions)
orientation = fast_argmax(possible_transitions)
for i, branch_direction in enumerate([(orientation + i) % 4 for i in range(-1, 3)]):
......@@ -265,8 +273,11 @@ class TreeObsForRailEnv(ObservationBuilder):
last_is_target = False
visited = OrderedSet()
agent = self.env.agents[handle]
time_per_cell = np.reciprocal(agent.speed_data["speed"])
distance_map_handle = self.env.distance_map.get()[handle]
time_per_cell = 1.0 / agent.speed_counter.speed
own_target_encountered = np.inf
other_agent_encountered = np.inf
other_target_encountered = np.inf
......@@ -283,7 +294,7 @@ class TreeObsForRailEnv(ObservationBuilder):
# #############################
# Modify here to compute any useful data required to build the end node's features. This code is called
# for each cell visited between the previous branching node and the next switch / target / dead-end.
if position in self.location_has_agent:
if self.location_has_agent.get(position, 0) == 1:
if tot_dist < other_agent_encountered:
other_agent_encountered = tot_dist
......@@ -295,21 +306,16 @@ class TreeObsForRailEnv(ObservationBuilder):
if self.location_has_agent_direction[position] == direction:
# Cummulate the number of agents on branch with same direction
other_agent_same_direction += self.location_has_agent_direction.get((position, direction), 0)
other_agent_same_direction += 1
# Check fractional speed of agents
current_fractional_speed = self.location_has_agent_speed[position]
if current_fractional_speed < min_fractional_speed:
min_fractional_speed = current_fractional_speed
# Other direction agents
# TODO: Test that this behavior is as expected
other_agent_opposite_direction += \
self.location_has_agent[position] - self.location_has_agent_direction.get((position, direction),
0)
else:
# If no agent in the same direction was found all agents in that position are other direction
# Attention this counts to many agents as a few might be going off on a switch.
other_agent_opposite_direction += self.location_has_agent[position]
# Check number of possible transitions for agent and total number of transitions in cell (type)
......@@ -330,36 +336,34 @@ class TreeObsForRailEnv(ObservationBuilder):
post_step = min(self.max_prediction_depth - 1, predicted_time + 1)
# Look for conflicting paths at distance tot_dist
if int_position in np.delete(self.predicted_pos[predicted_time], handle, 0):
conflicting_agent = np.where(self.predicted_pos[predicted_time] == int_position)
for ca in conflicting_agent[0]:
if int_position in fast_delete(self.predicted_pos[predicted_time], handle):
conflicting_agent = fast_where(self.predicted_pos[predicted_time] == int_position)
for ca in conflicting_agent:
if direction != self.predicted_dir[predicted_time][ca] and cell_transitions[
self._reverse_dir(
self.predicted_dir[predicted_time][ca])] == 1 and tot_dist < potential_conflict:
potential_conflict = tot_dist
if self.env.agents[ca].status == RailAgentStatus.DONE and tot_dist < potential_conflict:
if self.env.agents[ca].state == TrainState.DONE and tot_dist < potential_conflict:
potential_conflict = tot_dist
# Look for conflicting paths at distance num_step-1
elif int_position in np.delete(self.predicted_pos[pre_step], handle, 0):
conflicting_agent = np.where(self.predicted_pos[pre_step] == int_position)
for ca in conflicting_agent[0]:
elif int_position in fast_delete(self.predicted_pos[pre_step], handle):
conflicting_agent = fast_where(self.predicted_pos[pre_step] == int_position)
for ca in conflicting_agent:
if direction != self.predicted_dir[pre_step][ca] \
and cell_transitions[self._reverse_dir(self.predicted_dir[pre_step][ca])] == 1 \
and tot_dist < potential_conflict: # noqa: E125
potential_conflict = tot_dist
if self.env.agents[ca].status == RailAgentStatus.DONE and tot_dist < potential_conflict:
if self.env.agents[ca].state == TrainState.DONE and tot_dist < potential_conflict:
potential_conflict = tot_dist
# Look for conflicting paths at distance num_step+1
elif int_position in np.delete(self.predicted_pos[post_step], handle, 0):
conflicting_agent = np.where(self.predicted_pos[post_step] == int_position)
for ca in conflicting_agent[0]:
elif int_position in fast_delete(self.predicted_pos[post_step], handle):
conflicting_agent = fast_where(self.predicted_pos[post_step] == int_position)
for ca in conflicting_agent:
if direction != self.predicted_dir[post_step][ca] and cell_transitions[self._reverse_dir(
self.predicted_dir[post_step][ca])] == 1 \
and tot_dist < potential_conflict: # noqa: E125
potential_conflict = tot_dist
if self.env.agents[ca].status == RailAgentStatus.DONE and tot_dist < potential_conflict:
if self.env.agents[ca].state == TrainState.DONE and tot_dist < potential_conflict:
potential_conflict = tot_dist
if position in self.location_has_target and position != agent.target:
......@@ -377,7 +381,7 @@ class TreeObsForRailEnv(ObservationBuilder):
visited.add((position[0], position[1], direction))
# If the target node is encountered, pick that as node. Also, no further branching is possible.
if np.array_equal(position, self.env.agents[handle].target):
if fast_position_equal(position, self.env.agents[handle].target):
last_is_target = True
break
......@@ -385,7 +389,7 @@ class TreeObsForRailEnv(ObservationBuilder):
if crossing_found:
# Treat the crossing as a straight rail cell
total_transitions = 2
num_transitions = np.count_nonzero(cell_transitions)
num_transitions = fast_count_nonzero(cell_transitions)
exploring = False
......@@ -404,7 +408,7 @@ class TreeObsForRailEnv(ObservationBuilder):
# Keep walking through the tree along `direction`
exploring = True
# convert one-hot encoding to 0,1,2,3
direction = np.argmax(cell_transitions)
direction = fast_argmax(cell_transitions)
position = get_new_position(position, direction)
num_steps += 1
tot_dist += 1
......@@ -431,24 +435,25 @@ class TreeObsForRailEnv(ObservationBuilder):
dist_min_to_target = 0
elif last_is_terminal:
dist_to_next_branch = np.inf
dist_min_to_target = self.env.distance_map.get()[handle, position[0], position[1], direction]
dist_min_to_target = distance_map_handle[position[0], position[1], direction]
else:
dist_to_next_branch = tot_dist
dist_min_to_target = self.env.distance_map.get()[handle, position[0], position[1], direction]
node = TreeObsForRailEnv.Node(dist_own_target_encountered=own_target_encountered,
dist_other_target_encountered=other_target_encountered,
dist_other_agent_encountered=other_agent_encountered,
dist_potential_conflict=potential_conflict,
dist_unusable_switch=unusable_switch,
dist_to_next_branch=dist_to_next_branch,
dist_min_to_target=dist_min_to_target,
num_agents_same_direction=other_agent_same_direction,
num_agents_opposite_direction=other_agent_opposite_direction,
num_agents_malfunctioning=malfunctioning_agent,
speed_min_fractional=min_fractional_speed,
num_agents_ready_to_depart=other_agent_ready_to_depart_encountered,
childs={})
dist_min_to_target = distance_map_handle[position[0], position[1], direction]
# TreeObsForRailEnv.Node
node = Node(dist_own_target_encountered=own_target_encountered,
dist_other_target_encountered=other_target_encountered,
dist_other_agent_encountered=other_agent_encountered,
dist_potential_conflict=potential_conflict,
dist_unusable_switch=unusable_switch,
dist_to_next_branch=dist_to_next_branch,
dist_min_to_target=dist_min_to_target,
num_agents_same_direction=other_agent_same_direction,
num_agents_opposite_direction=other_agent_opposite_direction,
num_agents_malfunctioning=malfunctioning_agent,
speed_min_fractional=min_fractional_speed,
num_agents_ready_to_depart=other_agent_ready_to_depart_encountered,
childs={})
# #############################
# #############################
......@@ -563,11 +568,11 @@ class GlobalObsForRailEnv(ObservationBuilder):
def get(self, handle: int = 0) -> (np.ndarray, np.ndarray, np.ndarray):
agent = self.env.agents[handle]
if agent.status == RailAgentStatus.READY_TO_DEPART:
if agent.state.is_off_map_state():
agent_virtual_position = agent.initial_position
elif agent.status == RailAgentStatus.ACTIVE:
elif agent.state.is_on_map_state():
agent_virtual_position = agent.position
elif agent.status == RailAgentStatus.DONE:
elif agent.state == TrainState.DONE:
agent_virtual_position = agent.target
else:
return None
......@@ -588,7 +593,7 @@ class GlobalObsForRailEnv(ObservationBuilder):
other_agent: EnvAgent = self.env.agents[i]
# ignore other agents not in the grid any more
if other_agent.status == RailAgentStatus.DONE_REMOVED:
if other_agent.state == TrainState.DONE:
continue
obs_targets[other_agent.target][1] = 1
......@@ -598,10 +603,10 @@ class GlobalObsForRailEnv(ObservationBuilder):
# second channel only for other agents
if i != handle:
obs_agents_state[other_agent.position][1] = other_agent.direction
obs_agents_state[other_agent.position][2] = other_agent.malfunction_data['malfunction']
obs_agents_state[other_agent.position][3] = other_agent.speed_data['speed']
obs_agents_state[other_agent.position][2] = other_agent.malfunction_handler.malfunction_down_counter
obs_agents_state[other_agent.position][3] = other_agent.speed_counter.speed
# fifth channel: all ready to depart on this position
if other_agent.status == RailAgentStatus.READY_TO_DEPART:
if other_agent.state.is_off_map_state():
obs_agents_state[other_agent.initial_position][4] += 1
return self.rail_obs, obs_agents_state, obs_targets
......
import pickle
import msgpack
import numpy as np
import msgpack_numpy
msgpack_numpy.patch()
from flatland.envs import rail_env
from flatland.core.env_observation_builder import DummyObservationBuilder
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.agent_utils import EnvAgent, load_env_agent
# cannot import objects / classes directly because of circular import
from flatland.envs import malfunction_generators as mal_gen
from flatland.envs import rail_generators as rail_gen
from flatland.envs import line_generators as line_gen
class RailEnvPersister(object):
@classmethod
def save(cls, env, filename, save_distance_maps=False):
"""
Saves environment and distance map information in a file
Parameters:
---------
filename: string
save_distance_maps: bool
"""
env_dict = cls.get_full_state(env)
# We have an unresolved problem with msgpack loading the list of agents
# see also 20 lines below.
# print(f"env save - agents: {env_dict['agents'][0]}")
# a0 = env_dict["agents"][0]
# print("agent type:", type(a0))
if save_distance_maps is True:
oDistMap = env.distance_map.get()
if oDistMap is not None:
if len(oDistMap) > 0:
env_dict["distance_map"] = oDistMap
else:
print("[WARNING] Unable to save the distance map for this environment, as none was found !")
else:
print("[WARNING] Unable to save the distance map for this environment, as none was found !")
with open(filename, "wb") as file_out:
if filename.endswith("mpk"):
data = msgpack.packb(env_dict)
elif filename.endswith("pkl"):
data = pickle.dumps(env_dict)
#pickle.dump(env_dict, file_out)
file_out.write(data)
# We have an unresovled problem with msgpack loading the list of Agents
# with open(filename, "rb") as file_in:
# if filename.endswith("mpk"):
# bytes_in = file_in.read()
# dIn = msgpack.unpackb(data, encoding="utf-8")
# print(f"msgpack check - {dIn.keys()}")
# print(f"msgpack check - {dIn['agents'][0]}")
@classmethod
def save_episode(cls, env, filename):
dict_env = cls.get_full_state(env)
# Add additional info to dict_env before saving
dict_env["episode"] = env.cur_episode
dict_env["actions"] = env.list_actions
dict_env["shape"] = (env.width, env.height)
dict_env["max_episode_steps"] = env._max_episode_steps
with open(filename, "wb") as file_out:
if filename.endswith(".mpk"):
file_out.write(msgpack.packb(dict_env))
elif filename.endswith(".pkl"):
pickle.dump(dict_env, file_out)
@classmethod
def load(cls, env, filename, load_from_package=None):
"""
Load environment with distance map from a file
Parameters:
-------
filename: string
"""
env_dict = cls.load_env_dict(filename, load_from_package=load_from_package)
cls.set_full_state(env, env_dict)
@classmethod
def load_new(cls, filename, load_from_package=None):
env_dict = cls.load_env_dict(filename, load_from_package=load_from_package)
llGrid = env_dict["grid"]
height = len(llGrid)
width = len(llGrid[0])
# TODO: inefficient - each one of these generators loads the complete env file.
env = rail_env.RailEnv(#width=1, height=1,
width=width, height=height,
rail_generator=rail_gen.rail_from_file(filename,
load_from_package=load_from_package),
line_generator=line_gen.line_from_file(filename,
load_from_package=load_from_package),
#malfunction_generator_and_process_data=mal_gen.malfunction_from_file(filename,
# load_from_package=load_from_package),
malfunction_generator=mal_gen.FileMalfunctionGen(env_dict),
obs_builder_object=DummyObservationBuilder(),
record_steps=True)
env.rail = GridTransitionMap(1,1) # dummy
cls.set_full_state(env, env_dict)
return env, env_dict
@classmethod
def load_env_dict(cls, filename, load_from_package=None):
if load_from_package is not None:
from importlib_resources import read_binary
load_data = read_binary(load_from_package, filename)
else:
with open(filename, "rb") as file_in:
load_data = file_in.read()
if filename.endswith("mpk"):
env_dict = msgpack.unpackb(load_data, use_list=False, raw=False)
elif filename.endswith("pkl"):
try:
env_dict = pickle.loads(load_data)
except ValueError:
print("pickle failed to load file:", filename, " trying msgpack (deprecated)...")
env_dict = msgpack.unpackb(load_data, use_list=False, raw=False)
else:
print(f"filename {filename} must end with either pkl or mpk")
env_dict = {}
# Replace the agents tuple with EnvAgent objects
if "agents_static" in env_dict:
env_dict["agents"] = EnvAgent.load_legacy_static_agent(env_dict["agents_static"])
# remove the legacy key
del env_dict["agents_static"]
elif "agents" in env_dict:
# env_dict["agents"] = [EnvAgent(*d[0:len(d)]) for d in env_dict["agents"]]
env_dict["agents"] = [load_env_agent(d) for d in env_dict["agents"]]
return env_dict
@classmethod
def load_resource(cls, package, resource):
"""
Load environment (with distance map?) from a binary
"""
#from importlib_resources import read_binary
#load_data = read_binary(package, resource)
#if resource.endswith("pkl"):
# env_dict = pickle.loads(load_data)
#elif resource.endswith("mpk"):
# env_dict = msgpack.unpackb(load_data, encoding="utf-8")
#cls.set_full_state(env, env_dict)
return cls.load_new(resource, load_from_package=package)
@classmethod
def set_full_state(cls, env, env_dict):
"""
Sets environment state from env_dict
Parameters
-------
env_dict: dict
"""
env.rail.grid = np.array(env_dict["grid"])
# Initialise the env with the frozen agents in the file
env.agents = env_dict.get("agents", [])
# For consistency, set number_of_agents, which is the number which will be generated on reset
env.number_of_agents = env.get_num_agents()
env.height, env.width = env.rail.grid.shape
env.rail.height = env.height
env.rail.width = env.width
env.dones = dict.fromkeys(list(range(env.get_num_agents())) + ["__all__"], False)
@classmethod
def get_full_state(cls, env):
"""
Returns state of environment in dict object, ready for serialization
"""
grid_data = env.rail.grid.tolist()
# msgpack cannot persist EnvAgent so use the Agent namedtuple.
agent_data = [agent.to_agent() for agent in env.agents]
#print("get_full_state - agent_data:", agent_data)
malfunction_data: mal_gen.MalfunctionProcessData = env.malfunction_process_data
msg_data_dict = {
"grid": grid_data,
"agents": agent_data,
"malfunction": malfunction_data,
"max_episode_steps": env._max_episode_steps,
}
return msg_data_dict
################################################################################################
# deprecated methods moved from RailEnv. Most likely broken.
def deprecated_get_full_state_msg(self) -> msgpack.Packer:
"""
Returns state of environment in msgpack object
"""
msg_data_dict = self.get_full_state_dict()
return msgpack.packb(msg_data_dict, use_bin_type=True)
def deprecated_get_agent_state_msg(self) -> msgpack.Packer:
"""
Returns agents information in msgpack object
"""
agent_data = [agent.to_agent() for agent in self.agents]
msg_data = {
"agents": agent_data}
return msgpack.packb(msg_data, use_bin_type=True)
def deprecated_get_full_state_dist_msg(self) -> msgpack.Packer:
"""
Returns environment information with distance map information as msgpack object
"""
grid_data = self.rail.grid.tolist()
agent_data = [agent.to_agent() for agent in self.agents]
# I think these calls do nothing - they create packed data and it is discarded
#msgpack.packb(grid_data, use_bin_type=True)
#msgpack.packb(agent_data, use_bin_type=True)
distance_map_data = self.distance_map.get()
malfunction_data: mal_gen.MalfunctionProcessData = self.malfunction_process_data
#msgpack.packb(distance_map_data, use_bin_type=True) # does nothing
msg_data = {
"grid": grid_data,
"agents": agent_data,
"distance_map": distance_map_data,
"malfunction": malfunction_data}
return msgpack.packb(msg_data, use_bin_type=True)
def deprecated_set_full_state_msg(self, msg_data):
"""
Sets environment state with msgdata object passed as argument
Parameters
-------
msg_data: msgpack object
"""
data = msgpack.unpackb(msg_data, use_list=False, encoding='utf-8')
self.rail.grid = np.array(data["grid"])
# agents are always reset as not moving
if "agents_static" in data:
self.agents = EnvAgent.load_legacy_static_agent(data["agents_static"])
else:
self.agents = [EnvAgent(*d[0:12]) for d in data["agents"]]
# setup with loaded data
self.height, self.width = self.rail.grid.shape
self.rail.height = self.height
self.rail.width = self.width
self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)
def deprecated_set_full_state_dist_msg(self, msg_data):
"""
Sets environment grid state and distance map with msgdata object passed as argument
Parameters
-------
msg_data: msgpack object
"""
data = msgpack.unpackb(msg_data, use_list=False, encoding='utf-8')
self.rail.grid = np.array(data["grid"])
# agents are always reset as not moving
if "agents_static" in data:
self.agents = EnvAgent.load_legacy_static_agent(data["agents_static"])
else:
self.agents = [EnvAgent(*d[0:12]) for d in data["agents"]]
if "distance_map" in data.keys():
self.distance_map.set(data["distance_map"])
# setup with loaded data
self.height, self.width = self.rail.grid.shape
self.rail.height = self.height
self.rail.width = self.width
self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)
......@@ -5,11 +5,12 @@ Collection of environment-specific PredictionBuilder.
import numpy as np
from flatland.core.env_prediction_builder import PredictionBuilder
from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.distance_map import DistanceMap
from flatland.envs.rail_env import RailEnvActions
from flatland.envs.rail_env_action import RailEnvActions
from flatland.envs.rail_env_shortest_paths import get_shortest_paths
from flatland.utils.ordered_set import OrderedSet
from flatland.envs.step_utils.states import TrainState
from flatland.envs.step_utils import transition_utils
class DummyPredictorForRailEnv(PredictionBuilder):
......@@ -48,7 +49,7 @@ class DummyPredictorForRailEnv(PredictionBuilder):
prediction_dict = {}
for agent in agents:
if agent.status != RailAgentStatus.ACTIVE:
if not agent.state.is_on_map_state():
# TODO make this generic
continue
action_priorities = [RailEnvActions.MOVE_FORWARD, RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT]
......@@ -64,8 +65,8 @@ class DummyPredictorForRailEnv(PredictionBuilder):
continue
for action in action_priorities:
cell_is_free, new_cell_isValid, new_direction, new_position, transition_isValid = \
self.env._check_action_on_agent(action, agent)
new_cell_isValid, new_direction, new_position, transition_isValid = \
transition_utils.check_action_on_agent(action, self.env.rail, agent.position, agent.direction)
if all([new_cell_isValid, transition_isValid]):
# move and change direction to face the new_direction that was
# performed
......@@ -126,12 +127,11 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
prediction_dict = {}
for agent in agents:
if agent.status == RailAgentStatus.READY_TO_DEPART:
if agent.state.is_off_map_state():
agent_virtual_position = agent.initial_position
elif agent.status == RailAgentStatus.ACTIVE:
elif agent.state.is_on_map_state():
agent_virtual_position = agent.position
elif agent.status == RailAgentStatus.DONE:
elif agent.state == TrainState.DONE:
agent_virtual_position = agent.target
else:
......@@ -142,7 +142,7 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
continue
agent_virtual_direction = agent.direction
agent_speed = agent.speed_data["speed"]
agent_speed = agent.speed_counter.speed
times_per_cell = int(np.reciprocal(agent_speed))
prediction = np.zeros(shape=(self.max_depth + 1, 5))
prediction[0] = [0, *agent_virtual_position, agent_virtual_direction, 0]
......
......@@ -2,52 +2,36 @@
Definition of the RailEnv environment.
"""
import random
# TODO: _ this is a global method --> utils or remove later
from enum import IntEnum
from typing import List, NamedTuple, Optional, Dict
import msgpack
import msgpack_numpy as m
from typing import List, Optional, Dict, Tuple
import numpy as np
from gym.utils import seeding
from flatland.utils.rendertools import RenderTool, AgentRenderVariant
from flatland.core.env import Environment
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.grid.grid4 import Grid4TransitionsEnum, Grid4Transitions
from flatland.core.grid.grid4_utils import get_new_position
from flatland.core.grid.grid_utils import IntVector2D
from flatland.core.grid.grid4 import Grid4Transitions
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.agent_utils import EnvAgent, RailAgentStatus
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.distance_map import DistanceMap
from flatland.envs.observations import GlobalObsForRailEnv
from flatland.envs.rail_generators import random_rail_generator, RailGenerator
from flatland.envs.schedule_generators import random_schedule_generator, ScheduleGenerator
m.patch()
from flatland.envs.rail_env_action import RailEnvActions
class RailEnvActions(IntEnum):
DO_NOTHING = 0 # implies change of direction in a dead-end!
MOVE_LEFT = 1
MOVE_FORWARD = 2
MOVE_RIGHT = 3
STOP_MOVING = 4
from flatland.envs import malfunction_generators as mal_gen
from flatland.envs import rail_generators as rail_gen
from flatland.envs import line_generators as line_gen
from flatland.envs.timetable_generators import timetable_generator
from flatland.envs import persistence
from flatland.envs import agent_chains as ac
from flatland.envs.fast_methods import fast_position_equal
@staticmethod
def to_char(a: int):
return {
0: 'B',
1: 'L',
2: 'F',
3: 'R',
4: 'S',
}[a]
RailEnvGridPos = NamedTuple('RailEnvGridPos', [('r', int), ('c', int)])
RailEnvNextAction = NamedTuple('RailEnvNextAction', [('action', RailEnvActions), ('next_position', RailEnvGridPos),
('next_direction', Grid4TransitionsEnum)])
from flatland.envs.observations import GlobalObsForRailEnv
from flatland.envs.timetable_generators import timetable_generator
from flatland.envs.step_utils.states import TrainState, StateTransitionSignals
from flatland.envs.step_utils.transition_utils import check_valid_action
from flatland.envs.step_utils import action_preprocessing
from flatland.envs.step_utils import env_utils
class RailEnv(Environment):
"""
......@@ -78,8 +62,8 @@ class RailEnv(Environment):
It costs each agent a step_penalty for every time-step taken in the environment. Independent of the movement
of the agent. Currently all other penalties such as penalty for stopping, starting and invalid actions are set to 0.
alpha = 1
beta = 1
alpha = 0
beta = 0
Reward function parameters:
- invalid_action_penalty = 0
......@@ -101,26 +85,31 @@ class RailEnv(Environment):
For Round 2, they will be passed to the constructor as arguments, to allow for more flexibility.
"""
alpha = 1.0
beta = 1.0
# Epsilon to avoid rounding errors
epsilon = 0.01
invalid_action_penalty = 0 # previously -2; GIACOMO: we decided that invalid actions will carry no penalty
# NEW : REW: Sparse Reward
alpha = 0
beta = 0
step_penalty = -1 * alpha
global_reward = 1 * beta
invalid_action_penalty = 0 # previously -2; GIACOMO: we decided that invalid actions will carry no penalty
stop_penalty = 0 # penalty for stopping a moving agent
start_penalty = 0 # penalty for starting a stopped agent
cancellation_factor = 1
cancellation_time_buffer = 0
def __init__(self,
width,
height,
rail_generator: RailGenerator = random_rail_generator(),
schedule_generator: ScheduleGenerator = random_schedule_generator(),
number_of_agents=1,
rail_generator=None,
line_generator=None, # : line_gen.LineGenerator = line_gen.random_line_generator(),
number_of_agents=2,
obs_builder_object: ObservationBuilder = GlobalObsForRailEnv(),
stochastic_data=None,
malfunction_generator_and_process_data=None, # mal_gen.no_malfunction_generator(),
malfunction_generator=None,
remove_agents_at_target=True,
random_seed=1
random_seed=None,
record_steps=False,
):
"""
Environment init.
......@@ -132,12 +121,12 @@ class RailEnv(Environment):
height and agents handles of a rail environment, along with the number of times
the env has been reset, and returns a GridTransitionMap object and a list of
starting positions, targets, and initial orientations for agent handle.
The rail_generator can pass a distance map in the hints or information for specific schedule_generators.
The rail_generator can pass a distance map in the hints or information for specific line_generators.
Implementations can be found in flatland/envs/rail_generators.py
schedule_generator : function
The schedule_generator function is a function that takes the grid, the number of agents and optional hints
line_generator : function
The line_generator function is a function that takes the grid, the number of agents and optional hints
and returns a list of starting positions, targets, initial orientations and speed for all agent handles.
Implementations can be found in flatland/envs/schedule_generators.py
Implementations can be found in flatland/envs/line_generators.py
width : int
The width of the rail map. Potentially in the future,
a range of widths to sample from.
......@@ -159,68 +148,76 @@ class RailEnv(Environment):
"""
super().__init__()
self.rail_generator: RailGenerator = rail_generator
self.schedule_generator: ScheduleGenerator = schedule_generator
if malfunction_generator_and_process_data is not None:
print("DEPRECATED - RailEnv arg: malfunction_and_process_data - use malfunction_generator")
self.malfunction_generator, self.malfunction_process_data = malfunction_generator_and_process_data
elif malfunction_generator is not None:
self.malfunction_generator = malfunction_generator
# malfunction_process_data is not used
# self.malfunction_generator, self.malfunction_process_data = malfunction_generator_and_process_data
self.malfunction_process_data = self.malfunction_generator.get_process_data()
# replace default values here because we can't use default args values because of cyclic imports
else:
self.malfunction_generator = mal_gen.NoMalfunctionGen()
self.malfunction_process_data = self.malfunction_generator.get_process_data()
self.number_of_agents = number_of_agents
if rail_generator is None:
rail_generator = rail_gen.sparse_rail_generator()
self.rail_generator = rail_generator
if line_generator is None:
line_generator = line_gen.sparse_line_generator()
self.line_generator = line_generator
self.rail: Optional[GridTransitionMap] = None
self.width = width
self.height = height
self.remove_agents_at_target = remove_agents_at_target
self.rewards = [0] * number_of_agents
self.done = False
self.obs_builder = obs_builder_object
self.obs_builder.set_env(self)
self._max_episode_steps: Optional[int] = None
self._elapsed_steps = 0
self.dones = dict.fromkeys(list(range(number_of_agents)) + ["__all__"], False)
self.obs_dict = {}
self.rewards_dict = {}
self.dev_obs_dict = {}
self.dev_pred_dict = {}
self.agents: List[EnvAgent] = []
self.number_of_agents = number_of_agents
self.num_resets = 0
self.distance_map = DistanceMap(self.agents, self.height, self.width)
self.action_space = [5]
self._seed()
self._seed()
self.random_seed = random_seed
if self.random_seed:
if random_seed:
self._seed(seed=random_seed)
# Stochastic train malfunctioning parameters
if stochastic_data is not None:
mean_malfunction_rate = stochastic_data['malfunction_rate']
malfunction_min_duration = stochastic_data['min_duration']
malfunction_max_duration = stochastic_data['max_duration']
else:
mean_malfunction_rate = 0.
malfunction_min_duration = 0.
malfunction_max_duration = 0.
# Mean malfunction in number of time steps
self.mean_malfunction_rate = mean_malfunction_rate
self.agent_positions = None
# Uniform distribution parameters for malfunction duration
self.min_number_of_steps_broken = malfunction_min_duration
self.max_number_of_steps_broken = malfunction_max_duration
# save episode timesteps ie agent positions, orientations. (not yet actions / observations)
self.record_steps = record_steps # whether to save timesteps
# save timesteps in here: [[[row, col, dir, malfunction],...nAgents], ...nSteps]
self.cur_episode = []
self.list_actions = [] # save actions in here
self.valid_positions = None
# global numpy array of agents position, -1 means that the cell is free, otherwise the agent handle is placed
# inside the cell
self.agent_positions: np.ndarray = np.zeros((height, width), dtype=int) - 1
self.motionCheck = ac.MotionCheck()
def _seed(self, seed=None):
self.np_random, seed = seeding.np_random(seed)
random.seed(seed)
self.random_seed = seed
# Keep track of all the seeds in order
if not hasattr(self, 'seed_history'):
self.seed_history = [seed]
if self.seed_history[-1] != seed:
self.seed_history.append(seed)
return [seed]
# no more agent_handles
......@@ -237,45 +234,13 @@ class RailEnv(Environment):
self.agents.append(agent)
return len(self.agents) - 1
def set_agent_active(self, handle: int):
agent = self.agents[handle]
if agent.status == RailAgentStatus.READY_TO_DEPART and self.cell_free(agent.initial_position):
agent.status = RailAgentStatus.ACTIVE
self._set_agent_to_initial_position(agent, agent.initial_position)
def restart_agents(self):
def reset_agents(self):
""" Reset the agents to their starting positions
"""
for agent in self.agents:
agent.reset()
self.active_agents = [i for i in range(len(self.agents))]
@staticmethod
def compute_max_episode_steps(width: int, height: int, ratio_nr_agents_to_nr_cities: float = 20.0) -> int:
"""
compute_max_episode_steps(width, height, ratio_nr_agents_to_nr_cities, timedelay_factor, alpha)
The method computes the max number of episode steps allowed
Parameters
----------
width : int
width of environment
height : int
height of environment
ratio_nr_agents_to_nr_cities : float, optional
number_of_agents/number_of_cities
Returns
-------
max_episode_steps: int
maximum number of episode steps
"""
timedelay_factor = 4
alpha = 2
return int(timedelay_factor * alpha * (width + height + ratio_nr_agents_to_nr_cities))
def action_required(self, agent):
"""
Check if an agent needs to provide an action
......@@ -290,12 +255,11 @@ class RailEnv(Environment):
True: Agent needs to provide an action
False: Agent cannot provide an action
"""
return (agent.status == RailAgentStatus.READY_TO_DEPART or (
agent.status == RailAgentStatus.ACTIVE and np.isclose(agent.speed_data['position_fraction'], 0.0,
rtol=1e-03)))
return agent.state == TrainState.READY_TO_DEPART or \
( agent.state.is_on_map_state() and agent.speed_counter.is_cell_entry )
def reset(self, regenerate_rail: bool = True, regenerate_schedule: bool = True, activate_agents: bool = False,
random_seed: bool = None) -> (Dict, Dict):
def reset(self, regenerate_rail: bool = True, regenerate_schedule: bool = True, *,
random_seed: int = None) -> Tuple[Dict, Dict]:
"""
reset(regenerate_rail, regenerate_schedule, activate_agents, random_seed)
......@@ -307,9 +271,7 @@ class RailEnv(Environment):
regenerate the rails
regenerate_schedule : bool, optional
regenerate the schedule and the static agents
activate_agents : bool, optional
activate the agents
random_seed : bool, optional
random_seed : int, optional
random seed for environment
Returns
......@@ -325,7 +287,15 @@ class RailEnv(Environment):
optionals = {}
if regenerate_rail or self.rail is None:
rail, optionals = self.rail_generator(self.width, self.height, self.number_of_agents, self.num_resets)
if "__call__" in dir(self.rail_generator):
rail, optionals = self.rail_generator(
self.width, self.height, self.number_of_agents, self.num_resets, self.np_random)
elif "generate" in dir(self.rail_generator):
rail, optionals = self.rail_generator.generate(
self.width, self.height, self.number_of_agents, self.num_resets, self.np_random)
else:
raise ValueError("Could not invoke __call__ or generate on rail_generator")
self.rail = rail
self.height, self.width = self.rail.grid.shape
......@@ -343,601 +313,359 @@ class RailEnv(Environment):
if optionals and 'agents_hints' in optionals:
agents_hints = optionals['agents_hints']
schedule = self.schedule_generator(self.rail, self.number_of_agents, agents_hints, self.num_resets)
self.agents = EnvAgent.from_schedule(schedule)
line = self.line_generator(self.rail, self.number_of_agents, agents_hints,
self.num_resets, self.np_random)
self.agents = EnvAgent.from_line(line)
if agents_hints and 'city_orientations' in agents_hints:
ratio_nr_agents_to_nr_cities = self.get_num_agents() / len(agents_hints['city_orientations'])
self._max_episode_steps = self.compute_max_episode_steps(
width=self.width, height=self.height,
ratio_nr_agents_to_nr_cities=ratio_nr_agents_to_nr_cities)
else:
self._max_episode_steps = self.compute_max_episode_steps(width=self.width, height=self.height)
self.agent_positions = np.zeros((self.height, self.width), dtype=int) - 1
# Reset distance map - basically initializing
self.distance_map.reset(self.agents, self.rail)
self.restart_agents()
# NEW : Time Schedule Generation
timetable = timetable_generator(self.agents, self.distance_map,
agents_hints, self.np_random)
if activate_agents:
for i_agent in range(self.get_num_agents()):
self.set_agent_active(i_agent)
self._max_episode_steps = timetable.max_episode_steps
for agent in self.agents:
# Induce malfunctions
self._break_agent(self.mean_malfunction_rate, agent)
if agent.malfunction_data["malfunction"] > 0:
agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.DO_NOTHING
for agent_i, agent in enumerate(self.agents):
agent.earliest_departure = timetable.earliest_departures[agent_i]
agent.latest_arrival = timetable.latest_arrivals[agent_i]
else:
self.distance_map.reset(self.agents, self.rail)
# Fix agents that finished their malfunction
self._fix_agent_after_malfunction(agent)
# Reset agents to initial states
self.reset_agents()
self.num_resets += 1
self._elapsed_steps = 0
# TODO perhaps dones should be part of each agent.
# Agent positions map
self.agent_positions = np.zeros((self.height, self.width), dtype=int) - 1
self._update_agent_positions_map(ignore_old_positions=False)
self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)
# Reset the state of the observation builder with the new environment
self.obs_builder.reset()
self.distance_map.reset(self.agents, self.rail)
info_dict: Dict = {
'action_required': {i: self.action_required(agent) for i, agent in enumerate(self.agents)},
'malfunction': {
i: agent.malfunction_data['malfunction'] for i, agent in enumerate(self.agents)
},
'speed': {i: agent.speed_data['speed'] for i, agent in enumerate(self.agents)},
'status': {i: agent.status for i, agent in enumerate(self.agents)}
}
# Empty the episode store of agent positions
self.cur_episode = []
info_dict = self.get_info_dict()
# Return the new observation vectors for each agent
observation_dict: Dict = self._get_observations()
if hasattr(self, "renderer") and self.renderer is not None:
self.renderer = None
return observation_dict, info_dict
def _fix_agent_after_malfunction(self, agent: EnvAgent):
"""
Updates agent malfunction variables and fixes broken agents
Parameters
----------
agent
"""
# Ignore agents that are OK
if self._is_agent_ok(agent):
return
def _update_agent_positions_map(self, ignore_old_positions=True):
""" Update the agent_positions array for agents that changed positions """
for agent in self.agents:
if not ignore_old_positions or agent.old_position != agent.position:
if agent.position is not None:
self.agent_positions[agent.position] = agent.handle
if agent.old_position is not None:
self.agent_positions[agent.old_position] = -1
# Reduce number of malfunction steps left
if agent.malfunction_data['malfunction'] > 1:
agent.malfunction_data['malfunction'] -= 1
return
def generate_state_transition_signals(self, agent, preprocessed_action, movement_allowed):
""" Generate State Transitions Signals used in the state machine """
st_signals = StateTransitionSignals()
# Restart agents at the end of their malfunction
agent.malfunction_data['malfunction'] -= 1
if 'moving_before_malfunction' in agent.malfunction_data:
agent.moving = agent.malfunction_data['moving_before_malfunction']
return
# Malfunction starts when in_malfunction is set to true
st_signals.in_malfunction = agent.malfunction_handler.in_malfunction
def _break_agent(self, rate: float, agent) -> bool:
"""
Malfunction generator that breaks agents at a given rate.
# Malfunction counter complete - Malfunction ends next timestep
st_signals.malfunction_counter_complete = agent.malfunction_handler.malfunction_counter_complete
Parameters
----------
agent
# Earliest departure reached - Train is allowed to move now
st_signals.earliest_departure_reached = self._elapsed_steps >= agent.earliest_departure
"""
if agent.malfunction_data['malfunction'] < 1:
if self.np_random.rand() < self._malfunction_prob(rate):
num_broken_steps = self.np_random.randint(self.min_number_of_steps_broken,
self.max_number_of_steps_broken + 1) + 1
agent.malfunction_data['malfunction'] = num_broken_steps
agent.malfunction_data['moving_before_malfunction'] = agent.moving
agent.malfunction_data['nr_malfunctions'] += 1
return
# Stop Action Given
st_signals.stop_action_given = (preprocessed_action == RailEnvActions.STOP_MOVING)
def step(self, action_dict_: Dict[int, RailEnvActions]):
"""
Updates rewards for the agents at a step.
Parameters
----------
action_dict_ : Dict[int,RailEnvActions]
"""
self._elapsed_steps += 1
# If we're done, set reward and info_dict and step() is done.
if self.dones["__all__"]:
self.rewards_dict = {}
info_dict = {
"action_required": {},
"malfunction": {},
"speed": {},
"status": {},
}
for i_agent, agent in enumerate(self.agents):
self.rewards_dict[i_agent] = self.global_reward
info_dict["action_required"][i_agent] = False
info_dict["malfunction"][i_agent] = 0
info_dict["speed"][i_agent] = 0
info_dict["status"][i_agent] = agent.status
# Valid Movement action Given
st_signals.valid_movement_action_given = preprocessed_action.is_moving_action() and movement_allowed
return self._get_observations(), self.rewards_dict, self.dones, info_dict
# Target Reached
st_signals.target_reached = fast_position_equal(agent.position, agent.target)
# Reset the step rewards
self.rewards_dict = dict()
info_dict = {
"action_required": {},
"malfunction": {},
"speed": {},
"status": {},
}
have_all_agents_ended = True # boolean flag to check if all agents are done
# Movement conflict - Multiple trains trying to move into same cell
# If speed counter is not in cell exit, the train can enter the cell
st_signals.movement_conflict = (not movement_allowed) and agent.speed_counter.is_cell_exit
for i_agent, agent in enumerate(self.agents):
# Reset the step rewards
self.rewards_dict[i_agent] = 0
return st_signals
# Induce malfunction before we do a step, thus a broken agent can't move in this step
self._break_agent(self.mean_malfunction_rate, agent)
def _handle_end_reward(self, agent: EnvAgent) -> int:
'''
Handles end-of-episode reward for a particular agent.
# Perform step on the agent
self._step_agent(i_agent, action_dict_.get(i_agent))
Parameters
----------
agent : EnvAgent
'''
reward = None
# agent done? (arrival_time is not None)
if agent.state == TrainState.DONE:
# if agent arrived earlier or on time = 0
# if agent arrived later = -ve reward based on how late
reward = min(agent.latest_arrival - agent.arrival_time, 0)
# Agents not done (arrival_time is None)
else:
# CANCELLED check (never departed)
if (agent.state.is_off_map_state()):
reward = -1 * self.cancellation_factor * \
(agent.get_travel_time_on_shortest_path(self.distance_map) + self.cancellation_time_buffer)
# manage the boolean flag to check if all agents are indeed done (or done_removed)
have_all_agents_ended &= (agent.status in [RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED])
# Departed but never reached
if (agent.state.is_on_map_state()):
reward = agent.get_current_delay(self._elapsed_steps, self.distance_map)
# Build info dict
info_dict["action_required"][i_agent] = self.action_required(agent)
info_dict["malfunction"][i_agent] = agent.malfunction_data['malfunction']
info_dict["speed"][i_agent] = agent.speed_data['speed']
info_dict["status"][i_agent] = agent.status
return reward
# Fix agents that finished their malfunction such that they can perform an action in the next step
self._fix_agent_after_malfunction(agent)
def preprocess_action(self, action, agent):
"""
Preprocess the provided action
* Change to DO_NOTHING if illegal action
* Block all actions when in waiting state
* Check MOVE_LEFT/MOVE_RIGHT actions on current position else try MOVE_FORWARD
"""
action = action_preprocessing.preprocess_raw_action(action, agent.state, agent.action_saver.saved_action)
action = action_preprocessing.preprocess_action_when_waiting(action, agent.state)
# Check for end of episode + set global reward to all rewards!
if have_all_agents_ended:
self.dones["__all__"] = True
self.rewards_dict = {i: self.global_reward for i in range(self.get_num_agents())}
if (self._max_episode_steps is not None) and (self._elapsed_steps >= self._max_episode_steps):
self.dones["__all__"] = True
for i_agent in range(self.get_num_agents()):
self.dones[i_agent] = True
# Try moving actions on current position
current_position, current_direction = agent.position, agent.direction
if current_position is None: # Agent not added on map yet
current_position, current_direction = agent.initial_position, agent.initial_direction
return self._get_observations(), self.rewards_dict, self.dones, info_dict
action = action_preprocessing.preprocess_moving_action(action, self.rail, current_position, current_direction)
def _step_agent(self, i_agent, action: Optional[RailEnvActions] = None):
"""
Performs a step and step, start and stop penalty on a single agent in the following sub steps:
- malfunction
- action handling if at the beginning of cell
- movement
# Check transitions, bounts for executing the action in the given position and directon
if action.is_moving_action() and not check_valid_action(action, self.rail, current_position, current_direction):
action = RailEnvActions.STOP_MOVING
Parameters
----------
i_agent : int
action_dict_ : Dict[int,RailEnvActions]
"""
agent = self.agents[i_agent]
if agent.status in [RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED]: # this agent has already completed...
return
# agent gets active by a MOVE_* action and if c
if agent.status == RailAgentStatus.READY_TO_DEPART:
if action in [RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT,
RailEnvActions.MOVE_FORWARD] and self.cell_free(agent.initial_position):
agent.status = RailAgentStatus.ACTIVE
self._set_agent_to_initial_position(agent, agent.initial_position)
self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
return
else:
# TODO: Here we need to check for the departure time in future releases with full schedules
self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
return
agent.old_direction = agent.direction
agent.old_position = agent.position
# if agent is broken, actions are ignored and agent does not move.
# full step penalty in this case
if agent.malfunction_data['malfunction'] > 0:
self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
return
# Is the agent at the beginning of the cell? Then, it can take an action.
# As long as the agent is malfunctioning or stopped at the beginning of the cell,
# different actions may be taken!
if np.isclose(agent.speed_data['position_fraction'], 0.0, rtol=1e-03):
# No action has been supplied for this agent -> set DO_NOTHING as default
if action is None:
action = RailEnvActions.DO_NOTHING
if action < 0 or action > len(RailEnvActions):
print('ERROR: illegal action=', action,
'for agent with index=', i_agent,
'"DO NOTHING" will be executed instead')
action = RailEnvActions.DO_NOTHING
if action == RailEnvActions.DO_NOTHING and agent.moving:
# Keep moving
action = RailEnvActions.MOVE_FORWARD
if action == RailEnvActions.STOP_MOVING and agent.moving:
# Only allow halting an agent on entering new cells.
agent.moving = False
self.rewards_dict[i_agent] += self.stop_penalty
if not agent.moving and not (
action == RailEnvActions.DO_NOTHING or action == RailEnvActions.STOP_MOVING):
# Allow agent to start with any forward or direction action
agent.moving = True
self.rewards_dict[i_agent] += self.start_penalty
# Store the action if action is moving
# If not moving, the action will be stored when the agent starts moving again.
if agent.moving:
_action_stored = False
_, new_cell_valid, new_direction, new_position, transition_valid = \
self._check_action_on_agent(action, agent)
if all([new_cell_valid, transition_valid]):
agent.speed_data['transition_action_on_cellexit'] = action
_action_stored = True
else:
# But, if the chosen invalid action was LEFT/RIGHT, and the agent is moving,
# try to keep moving forward!
if (action == RailEnvActions.MOVE_LEFT or action == RailEnvActions.MOVE_RIGHT):
_, new_cell_valid, new_direction, new_position, transition_valid = \
self._check_action_on_agent(RailEnvActions.MOVE_FORWARD, agent)
if all([new_cell_valid, transition_valid]):
agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.MOVE_FORWARD
_action_stored = True
if not _action_stored:
# If the agent cannot move due to an invalid transition, we set its state to not moving
self.rewards_dict[i_agent] += self.invalid_action_penalty
self.rewards_dict[i_agent] += self.stop_penalty
agent.moving = False
# Now perform a movement.
# If agent.moving, increment the position_fraction by the speed of the agent
# If the new position fraction is >= 1, reset to 0, and perform the stored
# transition_action_on_cellexit if the cell is free.
if agent.moving:
agent.speed_data['position_fraction'] += agent.speed_data['speed']
if agent.speed_data['position_fraction'] >= 1.0:
# Perform stored action to transition to the next cell as soon as cell is free
# Notice that we've already checked new_cell_valid and transition valid when we stored the action,
# so we only have to check cell_free now!
# cell and transition validity was checked when we stored transition_action_on_cellexit!
cell_free, new_cell_valid, new_direction, new_position, transition_valid = self._check_action_on_agent(
agent.speed_data['transition_action_on_cellexit'], agent)
# N.B. validity of new_cell and transition should have been verified before the action was stored!
assert new_cell_valid
assert transition_valid
if cell_free:
self._move_agent_to_new_position(agent, new_position)
agent.direction = new_direction
agent.speed_data['position_fraction'] = 0.0
# has the agent reached its target?
if np.equal(agent.position, agent.target).all():
agent.status = RailAgentStatus.DONE
self.dones[i_agent] = True
self.active_agents.remove(i_agent)
agent.moving = False
self._remove_agent_from_scene(agent)
else:
self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
else:
# step penalty if not moving (stopped now or before)
self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
return action
def _set_agent_to_initial_position(self, agent: EnvAgent, new_position: IntVector2D):
"""
Sets the agent to its initial position. Updates the agent object and the position
of the agent inside the global agent_position numpy array
def clear_rewards_dict(self):
""" Reset the rewards dictionary """
self.rewards_dict = {i_agent: 0 for i_agent in range(len(self.agents))}
Parameters
-------
agent: EnvAgent object
new_position: IntVector2D
def get_info_dict(self):
"""
agent.position = new_position
self.agent_positions[agent.position] = agent.handle
def _move_agent_to_new_position(self, agent: EnvAgent, new_position: IntVector2D):
Returns dictionary of infos for all agents
dict_keys : action_required -
malfunction - Counter value for malfunction > 0 means train is in malfunction
speed - Speed of the train
state - State from the trains's state machine
"""
Move the agent to the a new position. Updates the agent object and the position
of the agent inside the global agent_position numpy array
info_dict = {
'action_required': {i: self.action_required(agent) for i, agent in enumerate(self.agents)},
'malfunction': {
i: agent.malfunction_handler.malfunction_down_counter for i, agent in enumerate(self.agents)
},
'speed': {i: agent.speed_counter.speed for i, agent in enumerate(self.agents)},
'state': {i: agent.state for i, agent in enumerate(self.agents)}
}
return info_dict
Parameters
-------
agent: EnvAgent object
new_position: IntVector2D
def update_step_rewards(self, i_agent):
"""
agent.position = new_position
self.agent_positions[agent.old_position] = -1
self.agent_positions[agent.position] = agent.handle
def _remove_agent_from_scene(self, agent: EnvAgent):
Update the rewards dict for agent id i_agent for every timestep
"""
Remove the agent from the scene. Updates the agent object and the position
of the agent inside the global agent_position numpy array
pass
Parameters
-------
agent: EnvAgent object
def end_of_episode_update(self, have_all_agents_ended):
"""
self.agent_positions[agent.position] = -1
if self.remove_agents_at_target:
agent.position = None
agent.status = RailAgentStatus.DONE_REMOVED
def _check_action_on_agent(self, action: RailEnvActions, agent: EnvAgent):
Updates made when episode ends
Parameters: have_all_agents_ended - Indicates if all agents have reached done state
"""
if have_all_agents_ended or \
( (self._max_episode_steps is not None) and (self._elapsed_steps >= self._max_episode_steps)):
Parameters
----------
action : RailEnvActions
agent : EnvAgent
for i_agent, agent in enumerate(self.agents):
Returns
-------
bool
Is it a legal move?
1) transition allows the new_direction in the cell,
2) the new cell is not empty (case 0),
3) the cell is free, i.e., no agent is currently in that cell
"""
# compute number of possible transitions in the current
# cell used to check for invalid actions
new_direction, transition_valid = self.check_action(agent, action)
new_position = get_new_position(agent.position, new_direction)
new_cell_valid = (
np.array_equal( # Check the new position is still in the grid
new_position,
np.clip(new_position, [0, 0], [self.height - 1, self.width - 1]))
and # check the new position has some transitions (ie is not an empty cell)
self.rail.get_full_transitions(*new_position) > 0)
# If transition validity hasn't been checked yet.
if transition_valid is None:
transition_valid = self.rail.get_transition(
(*agent.position, agent.direction),
new_direction)
# only call cell_free() if new cell is inside the scene
if new_cell_valid:
# Check the new position is not the same as any of the existing agent positions
# (including itself, for simplicity, since it is moving)
cell_free = self.cell_free(new_position)
else:
# if new cell is outside of scene -> cell_free is False
cell_free = False
return cell_free, new_cell_valid, new_direction, new_position, transition_valid
reward = self._handle_end_reward(agent)
self.rewards_dict[i_agent] += reward
def cell_free(self, position: IntVector2D) -> bool:
"""
Utility to check if a cell is free
self.dones[i_agent] = True
Parameters:
--------
position : Tuple[int, int]
self.dones["__all__"] = True
Returns
-------
bool
is the cell free or not?
def handle_done_state(self, agent):
""" Any updates to agent to be made in Done state """
if agent.state == TrainState.DONE and agent.arrival_time is None:
agent.arrival_time = self._elapsed_steps
self.dones[agent.handle] = True
if self.remove_agents_at_target:
agent.position = None
def step(self, action_dict_: Dict[int, RailEnvActions]):
"""
return self.agent_positions[position] == -1
def check_action(self, agent: EnvAgent, action: RailEnvActions):
Updates rewards for the agents at a step.
"""
self._elapsed_steps += 1
Parameters
----------
agent : EnvAgent
action : RailEnvActions
# Not allowed to step further once done
if self.dones["__all__"]:
raise Exception("Episode is done, cannot call step()")
Returns
-------
Tuple[Grid4TransitionsEnum,Tuple[int,int]]
self.clear_rewards_dict()
have_all_agents_ended = True # Boolean flag to check if all agents are done
self.motionCheck = ac.MotionCheck() # reset the motion check
"""
transition_valid = None
possible_transitions = self.rail.get_transitions(*agent.position, agent.direction)
num_transitions = np.count_nonzero(possible_transitions)
temp_transition_data = {}
new_direction = agent.direction
if action == RailEnvActions.MOVE_LEFT:
new_direction = agent.direction - 1
if num_transitions <= 1:
transition_valid = False
for agent in self.agents:
i_agent = agent.handle
agent.old_position = agent.position
agent.old_direction = agent.direction
# Generate malfunction
agent.malfunction_handler.generate_malfunction(self.malfunction_generator, self.np_random)
# Get action for the agent
action = action_dict_.get(i_agent, RailEnvActions.DO_NOTHING)
preprocessed_action = self.preprocess_action(action, agent)
# Save moving actions in not already saved
agent.action_saver.save_action_if_allowed(preprocessed_action, agent.state)
# Train's next position can change if train is at cell's exit and train is not in malfunction
position_update_allowed = agent.speed_counter.is_cell_exit and \
not agent.malfunction_handler.malfunction_down_counter > 0 and \
not preprocessed_action == RailEnvActions.STOP_MOVING
# Calculate new position
# Keep agent in same place if already done
if agent.state == TrainState.DONE:
new_position, new_direction = agent.position, agent.direction
# Add agent to the map if not on it yet
elif agent.position is None and agent.action_saver.is_action_saved:
new_position = agent.initial_position
new_direction = agent.initial_direction
# If movement is allowed apply saved action independent of other agents
elif agent.action_saver.is_action_saved and position_update_allowed:
saved_action = agent.action_saver.saved_action
# Apply action independent of other agents and get temporary new position and direction
new_position, new_direction = env_utils.apply_action_independent(saved_action,
self.rail,
agent.position,
agent.direction)
preprocessed_action = saved_action
else:
new_position, new_direction = agent.position, agent.direction
elif action == RailEnvActions.MOVE_RIGHT:
new_direction = agent.direction + 1
if num_transitions <= 1:
transition_valid = False
temp_transition_data[i_agent] = env_utils.AgentTransitionData(position=new_position,
direction=new_direction,
preprocessed_action=preprocessed_action)
new_direction %= 4
# This is for storing and later checking for conflicts of agents trying to occupy same cell
self.motionCheck.addAgent(i_agent, agent.position, new_position)
if action == RailEnvActions.MOVE_FORWARD and num_transitions == 1:
# - dead-end, straight line or curved line;
# new_direction will be the only valid transition
# - take only available transition
new_direction = np.argmax(possible_transitions)
transition_valid = True
return new_direction, transition_valid
# Find conflicts between trains trying to occupy same cell
self.motionCheck.find_conflicts()
def _get_observations(self):
"""
Utility which returns the observations for an agent with respect to environment
for agent in self.agents:
i_agent = agent.handle
Returns
------
Dict object
"""
self.obs_dict = self.obs_builder.get_many(list(range(self.get_num_agents())))
return self.obs_dict
## Update positions
if agent.malfunction_handler.in_malfunction:
movement_allowed = False
else:
movement_allowed = self.motionCheck.check_motion(i_agent, agent.position)
def get_valid_directions_on_grid(self, row: int, col: int) -> List[int]:
"""
Returns directions in which the agent can move
movement_inside_cell = agent.state == TrainState.STOPPED and not agent.speed_counter.is_cell_exit
movement_allowed = movement_allowed or movement_inside_cell
Parameters:
---------
row : int
col : int
# Fetch the saved transition data
agent_transition_data = temp_transition_data[i_agent]
preprocessed_action = agent_transition_data.preprocessed_action
Returns:
-------
List[int]
"""
return Grid4Transitions.get_entry_directions(self.rail.get_full_transitions(row, col))
## Update states
state_transition_signals = self.generate_state_transition_signals(agent, preprocessed_action, movement_allowed)
agent.state_machine.set_transition_signals(state_transition_signals)
agent.state_machine.step()
def get_full_state_msg(self):
"""
Returns state of environment in msgpack object
"""
grid_data = self.rail.grid.tolist()
agent_data = [agent.to_agent() for agent in self.agents]
msgpack.packb(grid_data, use_bin_type=True)
msgpack.packb(agent_data, use_bin_type=True)
msg_data = {
"grid": grid_data,
"agents": agent_data}
return msgpack.packb(msg_data, use_bin_type=True)
# Needed when not removing agents at target
movement_allowed = movement_allowed and agent.state != TrainState.DONE
def get_agent_state_msg(self):
"""
Returns agents information in msgpack object
"""
agent_data = [agent.to_agent() for agent in self.agents]
msg_data = {
"agents": agent_data}
return msgpack.packb(msg_data, use_bin_type=True)
# Agent is being added to map
if agent.state.is_on_map_state():
if agent.state_machine.previous_state.is_off_map_state():
agent.position = agent.initial_position
agent.direction = agent.initial_direction
# Speed counter completes
elif movement_allowed and (agent.speed_counter.is_cell_exit):
agent.position = agent_transition_data.position
agent.direction = agent_transition_data.direction
agent.state_machine.update_if_reached(agent.position, agent.target)
def set_full_state_msg(self, msg_data):
"""
Sets environment state with msgdata object passed as argument
# Off map or on map state and position should match
env_utils.state_position_sync_check(agent.state, agent.position, agent.handle)
Parameters
-------
msg_data: msgpack object
"""
data = msgpack.unpackb(msg_data, use_list=False, encoding='utf-8')
self.rail.grid = np.array(data["grid"])
# agents are always reset as not moving
self.agents = [EnvAgent(*d[0:12]) for d in data["agents"]]
# setup with loaded data
self.height, self.width = self.rail.grid.shape
self.rail.height = self.height
self.rail.width = self.width
self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)
# Handle done state actions, optionally remove agents
self.handle_done_state(agent)
def set_full_state_dist_msg(self, msg_data):
"""
Sets environment grid state and distance map with msgdata object passed as argument
have_all_agents_ended &= (agent.state == TrainState.DONE)
Parameters
-------
msg_data: msgpack object
"""
data = msgpack.unpackb(msg_data, use_list=False, encoding='utf-8')
self.rail.grid = np.array(data["grid"])
# agents are always reset as not moving
self.agents = [EnvAgent(*d[0:12]) for d in data["agents"]]
if "distance_map" in data.keys():
self.distance_map.set(data["distance_map"])
# setup with loaded data
self.height, self.width = self.rail.grid.shape
self.rail.height = self.height
self.rail.width = self.width
self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)
## Update rewards
self.update_step_rewards(i_agent)
def get_full_state_dist_msg(self):
"""
Returns environment information with distance map information as msgpack object
"""
grid_data = self.rail.grid.tolist()
agent_data = [agent.to_agent() for agent in self.agents]
msgpack.packb(grid_data, use_bin_type=True)
msgpack.packb(agent_data, use_bin_type=True)
distance_map_data = self.distance_map.get()
msgpack.packb(distance_map_data, use_bin_type=True)
msg_data = {
"grid": grid_data,
"agents": agent_data,
"distance_map": distance_map_data}
## Update counters (malfunction and speed)
agent.speed_counter.update_counter(agent.state, agent.old_position)
# agent.state_machine.previous_state)
agent.malfunction_handler.update_counter()
return msgpack.packb(msg_data, use_bin_type=True)
# Clear old action when starting in new cell
if agent.speed_counter.is_cell_entry and agent.position is not None:
agent.action_saver.clear_saved_action()
def save(self, filename, save_distance_maps=False):
"""
Saves environment and distance map information in a file
# Check if episode has ended and update rewards and dones
self.end_of_episode_update(have_all_agents_ended)
Parameters:
---------
filename: string
save_distance_maps: bool
"""
if save_distance_maps is True:
if self.distance_map.get() is not None:
if len(self.distance_map.get()) > 0:
with open(filename, "wb") as file_out:
file_out.write(self.get_full_state_dist_msg())
else:
print("[WARNING] Unable to save the distance map for this environment, as none was found !")
self._update_agent_positions_map()
else:
print("[WARNING] Unable to save the distance map for this environment, as none was found !")
return self._get_observations(), self.rewards_dict, self.dones, self.get_info_dict()
else:
with open(filename, "wb") as file_out:
file_out.write(self.get_full_state_msg())
def load(self, filename):
def record_timestep(self, dActions):
"""
Load environment with distance map from a file
Parameters:
-------
filename: string
Record the positions and orientations of all agents in memory, in the cur_episode
"""
with open(filename, "rb") as file_in:
load_data = file_in.read()
self.set_full_state_dist_msg(load_data)
list_agents_state = []
for i_agent in range(self.get_num_agents()):
agent = self.agents[i_agent]
# the int cast is to avoid numpy types which may cause problems with msgpack
# in env v2, agents may have position None, before starting
if agent.position is None:
pos = (0, 0)
else:
pos = (int(agent.position[0]), int(agent.position[1]))
# print("pos:", pos, type(pos[0]))
list_agents_state.append([
*pos, int(agent.direction),
agent.malfunction_handler.malfunction_down_counter,
int(agent.status),
int(agent.position in self.motionCheck.svDeadlocked)
])
self.cur_episode.append(list_agents_state)
self.list_actions.append(dActions)
def load_pkl(self, pkl_data):
def _get_observations(self):
"""
Load environment with distance map from a pickle file
Parameters:
-------
pkl_data: pickle file
Utility which returns the dictionary of observations for an agent with respect to environment
"""
self.set_full_state_msg(pkl_data)
# print(f"_get_obs - num agents: {self.get_num_agents()} {list(range(self.get_num_agents()))}")
self.obs_dict = self.obs_builder.get_many(list(range(self.get_num_agents())))
return self.obs_dict
def load_resource(self, package, resource):
def get_valid_directions_on_grid(self, row: int, col: int) -> List[int]:
"""
Load environment with distance map from a binary
Returns directions in which the agent can move
"""
from importlib_resources import read_binary
load_data = read_binary(package, resource)
self.set_full_state_msg(load_data)
return Grid4Transitions.get_entry_directions(self.rail.get_full_transitions(row, col))
def _exp_distirbution_synced(self, rate: float) -> float:
"""
......@@ -950,17 +678,6 @@ class RailEnv(Environment):
x = - np.log(1 - u) * rate
return x
def _malfunction_prob(self, rate: float) -> float:
"""
Probability of a single agent to break. According to Poisson process with given rate
:param rate:
:return:
"""
if rate <= 0:
return 0.
else:
return 1 - np.exp(- (1 / rate))
def _is_agent_ok(self, agent: EnvAgent) -> bool:
"""
Check if an agent is ok, meaning it can move and is not malfuncitoinig
......@@ -973,4 +690,84 @@ class RailEnv(Environment):
True if agent is ok, False otherwise
"""
return agent.malfunction_data['malfunction'] < 1
return agent.malfunction_handler.in_malfunction
def save(self, filename):
print("DEPRECATED call to env.save() - pls call RailEnvPersister.save()")
persistence.RailEnvPersister.save(self, filename)
def render(self, mode="rgb_array", gl="PGL", agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
show_debug=False, clear_debug_text=True, show=False,
screen_height=600, screen_width=800,
show_observations=False, show_predictions=False,
show_rowcols=False, return_image=True):
"""
This methods provides the option to render the
environment's behavior as an image or to a window.
Parameters
----------
mode
Returns
-------
Image if mode is rgb_array, opens a window otherwise
"""
if not hasattr(self, "renderer") or self.renderer is None:
self.initialize_renderer(mode=mode, gl=gl, # gl="TKPILSVG",
agent_render_variant=agent_render_variant,
show_debug=show_debug,
clear_debug_text=clear_debug_text,
show=show,
screen_height=screen_height, # Adjust these parameters to fit your resolution
screen_width=screen_width)
return self.update_renderer(mode=mode, show=show, show_observations=show_observations,
show_predictions=show_predictions,
show_rowcols=show_rowcols, return_image=return_image)
def initialize_renderer(self, mode, gl,
agent_render_variant,
show_debug,
clear_debug_text,
show,
screen_height,
screen_width):
# Initiate the renderer
self.renderer = RenderTool(self, gl=gl, # gl="TKPILSVG",
agent_render_variant=agent_render_variant,
show_debug=show_debug,
clear_debug_text=clear_debug_text,
screen_height=screen_height, # Adjust these parameters to fit your resolution
screen_width=screen_width) # Adjust these parameters to fit your resolution
self.renderer.show = show
self.renderer.reset()
def update_renderer(self, mode, show, show_observations, show_predictions,
show_rowcols, return_image):
"""
This method updates the render.
Parameters
----------
mode
Returns
-------
Image if mode is rgb_array, None otherwise
"""
image = self.renderer.render_env(show=show, show_observations=show_observations,
show_predictions=show_predictions,
show_rowcols=show_rowcols, return_image=return_image)
if mode == 'rgb_array':
return image[:, :, :3]
def close(self):
"""
This methods closes any renderer window.
"""
if hasattr(self, "renderer") and self.renderer is not None:
try:
if self.renderer.show:
self.renderer.close_window()
except Exception as e:
print("Could Not close window due to:",e)
self.renderer = None
from enum import IntEnum
from typing import NamedTuple
from flatland.core.grid.grid4 import Grid4TransitionsEnum
class RailEnvActions(IntEnum):
DO_NOTHING = 0 # implies change of direction in a dead-end!
MOVE_LEFT = 1
MOVE_FORWARD = 2
MOVE_RIGHT = 3
STOP_MOVING = 4
@staticmethod
def to_char(a: int):
return {
0: 'B',
1: 'L',
2: 'F',
3: 'R',
4: 'S',
}[a]
@classmethod
def is_action_valid(cls, action):
return action in cls._value2member_map_
def is_moving_action(self):
return self.value in [self.MOVE_RIGHT, self.MOVE_LEFT, self.MOVE_FORWARD]
RailEnvGridPos = NamedTuple('RailEnvGridPos', [('r', int), ('c', int)])
RailEnvNextAction = NamedTuple('RailEnvNextAction', [('action', RailEnvActions), ('next_position', RailEnvGridPos),
('next_direction', Grid4TransitionsEnum)])
import math
from typing import Dict, List, Optional, NamedTuple, Tuple, Set
from typing import Dict, List, Optional, Tuple, Set
import matplotlib.pyplot as plt
import numpy as np
......@@ -7,15 +7,13 @@ import numpy as np
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.core.grid.grid4_utils import get_new_position
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.step_utils.states import TrainState
from flatland.envs.distance_map import DistanceMap
from flatland.envs.rail_env import RailEnvNextAction, RailEnvActions
from flatland.envs.fast_methods import fast_count_nonzero
from flatland.envs.rail_env_action import RailEnvActions, RailEnvNextAction
from flatland.envs.rail_trainrun_data_structures import Waypoint
from flatland.utils.ordered_set import OrderedSet
WalkingElement = \
NamedTuple('WalkingElement',
[('position', Tuple[int, int]), ('direction', int), ('next_action_element', RailEnvNextAction)])
def get_valid_move_actions_(agent_direction: Grid4TransitionsEnum,
agent_position: Tuple[int, int],
......@@ -23,6 +21,9 @@ def get_valid_move_actions_(agent_direction: Grid4TransitionsEnum,
"""
Get the valid move actions (forward, left, right) for an agent.
TODO https://gitlab.aicrowd.com/flatland/flatland/issues/299 The implementation could probably be more efficient
and more elegant. But given the few calls this has no priority now.
Parameters
----------
agent_direction : Grid4TransitionsEnum
......@@ -36,9 +37,9 @@ def get_valid_move_actions_(agent_direction: Grid4TransitionsEnum,
Possible move actions (forward,left,right) and the next position/direction they lead to.
It is not checked that the next cell is free.
"""
valid_actions: Set[RailEnvNextAction] = OrderedSet()
valid_actions: Set[RailEnvNextAction] = []
possible_transitions = rail.get_transitions(*agent_position, agent_direction)
num_transitions = np.count_nonzero(possible_transitions)
num_transitions = fast_count_nonzero(possible_transitions)
# Start from the current orientation, and see which transitions are available;
# organize them as [left, forward, right], relative to the current orientation
# If only one transition is possible, the forward branch is aligned with it.
......@@ -47,13 +48,13 @@ def get_valid_move_actions_(agent_direction: Grid4TransitionsEnum,
exit_direction = (agent_direction + 2) % 4
if possible_transitions[exit_direction]:
new_position = get_new_position(agent_position, exit_direction)
valid_actions.add(RailEnvNextAction(action, new_position, exit_direction))
valid_actions = [(RailEnvNextAction(action, new_position, exit_direction))]
elif num_transitions == 1:
action = RailEnvActions.MOVE_FORWARD
for new_direction in [(agent_direction + i) % 4 for i in range(-1, 2)]:
if possible_transitions[new_direction]:
new_position = get_new_position(agent_position, new_direction)
valid_actions.add(RailEnvNextAction(action, new_position, new_direction))
valid_actions = [(RailEnvNextAction(action, new_position, new_direction))]
else:
for new_direction in [(agent_direction + i) % 4 for i in range(-1, 2)]:
if possible_transitions[new_direction]:
......@@ -67,13 +68,141 @@ def get_valid_move_actions_(agent_direction: Grid4TransitionsEnum,
raise Exception("Illegal state")
new_position = get_new_position(agent_position, new_direction)
valid_actions.add(RailEnvNextAction(action, new_position, new_direction))
valid_actions.append(RailEnvNextAction(action, new_position, new_direction))
return valid_actions
def get_new_position_for_action(
agent_position: Tuple[int, int],
agent_direction: Grid4TransitionsEnum,
action: RailEnvActions,
rail: GridTransitionMap) -> Tuple[int, int, int]:
"""
Get the next position for this action.
TODO https://gitlab.aicrowd.com/flatland/flatland/issues/299 The implementation could probably be more efficient
and more elegant. But given the few calls this has no priority now.
Parameters
----------
agent_position
agent_direction
action
rail
Returns
-------
Tuple[int,int,int]
row, column, direction
"""
possible_transitions = rail.get_transitions(*agent_position, agent_direction)
num_transitions = fast_count_nonzero(possible_transitions)
# Start from the current orientation, and see which transitions are available;
# organize them as [left, forward, right], relative to the current orientation
# If only one transition is possible, the forward branch is aligned with it.
if rail.is_dead_end(agent_position):
valid_action = RailEnvActions.MOVE_FORWARD
exit_direction = (agent_direction + 2) % 4
if possible_transitions[exit_direction]:
new_position = get_new_position(agent_position, exit_direction)
if valid_action == action:
return new_position, exit_direction
elif num_transitions == 1:
valid_action = RailEnvActions.MOVE_FORWARD
for new_direction in [(agent_direction + i) % 4 for i in range(-1, 2)]:
if possible_transitions[new_direction]:
new_position = get_new_position(agent_position, new_direction)
if valid_action == action:
return new_position, new_direction
else:
for new_direction in [(agent_direction + i) % 4 for i in range(-1, 2)]:
if possible_transitions[new_direction]:
if new_direction == agent_direction:
valid_action = RailEnvActions.MOVE_FORWARD
if valid_action == action:
new_position = get_new_position(agent_position, new_direction)
return new_position, new_direction
elif new_direction == (agent_direction + 1) % 4:
valid_action = RailEnvActions.MOVE_RIGHT
if valid_action == action:
new_position = get_new_position(agent_position, new_direction)
return new_position, new_direction
elif new_direction == (agent_direction - 1) % 4:
valid_action = RailEnvActions.MOVE_LEFT
if valid_action == action:
new_position = get_new_position(agent_position, new_direction)
return new_position, new_direction
def get_action_for_move(
agent_position: Tuple[int, int],
agent_direction: Grid4TransitionsEnum,
next_agent_position: Tuple[int, int],
next_agent_direction: int,
rail: GridTransitionMap) -> Optional[RailEnvActions]:
"""
Get the action (if any) to move from a position and direction to another.
TODO https://gitlab.aicrowd.com/flatland/flatland/issues/299 The implementation could probably be more efficient
and more elegant. But given the few calls this has no priority now.
Parameters
----------
agent_position
agent_direction
next_agent_position
next_agent_direction
rail
Returns
-------
Optional[RailEnvActions]
the action (if direct transition possible) or None.
"""
possible_transitions = rail.get_transitions(*agent_position, agent_direction)
num_transitions = fast_count_nonzero(possible_transitions)
# Start from the current orientation, and see which transitions are available;
# organize them as [left, forward, right], relative to the current orientation
# If only one transition is possible, the forward branch is aligned with it.
if rail.is_dead_end(agent_position):
valid_action = RailEnvActions.MOVE_FORWARD
new_direction = (agent_direction + 2) % 4
if possible_transitions[new_direction]:
new_position = get_new_position(agent_position, new_direction)
if new_position == next_agent_position and new_direction == next_agent_direction:
return valid_action
elif num_transitions == 1:
valid_action = RailEnvActions.MOVE_FORWARD
for new_direction in [(agent_direction + i) % 4 for i in range(-1, 2)]:
if possible_transitions[new_direction]:
new_position = get_new_position(agent_position, new_direction)
if new_position == next_agent_position and new_direction == next_agent_direction:
return valid_action
else:
for new_direction in [(agent_direction + i) % 4 for i in range(-1, 2)]:
if possible_transitions[new_direction]:
if new_direction == agent_direction:
valid_action = RailEnvActions.MOVE_FORWARD
new_position = get_new_position(agent_position, new_direction)
if new_position == next_agent_position and new_direction == next_agent_direction:
return valid_action
elif new_direction == (agent_direction + 1) % 4:
valid_action = RailEnvActions.MOVE_RIGHT
new_position = get_new_position(agent_position, new_direction)
if new_position == next_agent_position and new_direction == next_agent_direction:
return valid_action
elif new_direction == (agent_direction - 1) % 4:
valid_action = RailEnvActions.MOVE_LEFT
new_position = get_new_position(agent_position, new_direction)
if new_position == next_agent_position and new_direction == next_agent_direction:
return valid_action
# N.B. get_shortest_paths is not part of distance_map since it refers to RailEnvActions (would lead to circularity!)
def get_shortest_paths(distance_map: DistanceMap, max_depth: Optional[int] = None, agent_handle: Optional[int] = None) \
-> Dict[int, Optional[List[WalkingElement]]]:
-> Dict[int, Optional[List[Waypoint]]]:
"""
Computes the shortest path for each agent to its target and the action to be taken to do so.
The paths are derived from a `DistanceMap`.
......@@ -99,11 +228,11 @@ def get_shortest_paths(distance_map: DistanceMap, max_depth: Optional[int] = Non
shortest_paths = dict()
def _shortest_path_for_agent(agent):
if agent.status == RailAgentStatus.READY_TO_DEPART:
if agent.state.is_off_map_state():
position = agent.initial_position
elif agent.status == RailAgentStatus.ACTIVE:
elif agent.state.is_on_map_state():
position = agent.position
elif agent.status == RailAgentStatus.DONE:
elif agent.state == TrainState.DONE:
position = agent.target
else:
shortest_paths[agent.handle] = None
......@@ -123,7 +252,7 @@ def get_shortest_paths(distance_map: DistanceMap, max_depth: Optional[int] = Non
best_next_action = next_action
distance = next_action_distance
shortest_paths[agent.handle].append(WalkingElement(position, direction, best_next_action))
shortest_paths[agent.handle].append(Waypoint(position, direction))
depth += 1
# if there is no way to continue, the rail must be disconnected!
......@@ -135,9 +264,7 @@ def get_shortest_paths(distance_map: DistanceMap, max_depth: Optional[int] = Non
position = best_next_action.next_position
direction = best_next_action.next_direction
if max_depth is None or depth < max_depth:
shortest_paths[agent.handle].append(
WalkingElement(position, direction,
RailEnvNextAction(RailEnvActions.STOP_MOVING, position, direction)))
shortest_paths[agent.handle].append(Waypoint(position, direction))
if agent_handle is not None:
_shortest_path_for_agent(distance_map.agents[agent_handle])
......@@ -148,6 +275,106 @@ def get_shortest_paths(distance_map: DistanceMap, max_depth: Optional[int] = Non
return shortest_paths
def get_k_shortest_paths(env,
source_position: Tuple[int, int],
source_direction: int,
target_position=Tuple[int, int],
k: int = 1, debug=False) -> List[Tuple[Waypoint]]:
"""
Computes the k shortest paths using modified Dijkstra
following pseudo-code https://en.wikipedia.org/wiki/K_shortest_path_routing
In contrast to the pseudo-code in wikipedia, we do not a allow for loopy paths.
Parameters
----------
env : RailEnv
source_position: Tuple[int,int]
source_direction: int
target_position: Tuple[int,int]
k : int
max number of shortest paths
debug: bool
print debug statements
Returns
-------
List[Tuple[WalkingElement]]
We use tuples since we need the path elements to be hashable.
We use a list of paths in order to keep the order of length.
"""
# P: set of shortest paths from s to t
# P =empty,
shortest_paths: List[Tuple[Waypoint]] = []
# countu: number of shortest paths found to node u
# countu = 0, for all u in V
count = {(r, c, d): 0 for r in range(env.height) for c in range(env.width) for d in range(4)}
# B is a heap data structure containing paths
# N.B. use OrderedSet to make result deterministic!
heap: OrderedSet[Tuple[Waypoint]] = OrderedSet()
# insert path Ps = {s} into B with cost 0
heap.add((Waypoint(source_position, source_direction),))
# while B is not empty and countt < K:
while len(heap) > 0 and len(shortest_paths) < k:
if debug:
print("iteration heap={}, shortest_paths={}".format(heap, shortest_paths))
# – let Pu be the shortest cost path in B with cost C
cost = np.inf
pu = None
for path in heap:
if len(path) < cost:
pu = path
cost = len(path)
u: Waypoint = pu[-1]
if debug:
print(" looking at pu={}".format(pu))
# – B = B − {Pu }
heap.remove(pu)
# – countu = countu + 1
urcd = (*u.position, u.direction)
count[urcd] += 1
# – if u = t then P = P U {Pu}
if u.position == target_position:
if debug:
print(" found of length {} {}".format(len(pu), pu))
shortest_paths.append(pu)
# – if countu ≤ K then
# CAVEAT: do not allow for loopy paths
elif count[urcd] <= k:
possible_transitions = env.rail.get_transitions(*urcd)
if debug:
print(" looking at neighbors of u={}, transitions are {}".format(u, possible_transitions))
# for each vertex v adjacent to u:
for new_direction in range(4):
if debug:
print(" looking at new_direction={}".format(new_direction))
if possible_transitions[new_direction]:
new_position = get_new_position(u.position, new_direction)
if debug:
print(" looking at neighbor v={}".format((*new_position, new_direction)))
v = Waypoint(position=new_position, direction=new_direction)
# CAVEAT: do not allow for loopy paths
if v in pu:
continue
# – let Pv be a new path with cost C + w(u, v) formed by concatenating edge (u, v) to path Pu
pv = pu + (v,)
# – insert Pv into B
heap.add(pv)
# return P
return shortest_paths
def visualize_distance_map(distance_map: DistanceMap, agent_handle: int = 0):
if agent_handle >= distance_map.get().shape[0]:
print("Error: agent_handle cannot be larger than actual number of agents")
......
......@@ -3,12 +3,14 @@ from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import rail_from_file
from flatland.envs.schedule_generators import schedule_from_file
from flatland.envs.line_generators import line_from_file
def load_flatland_environment_from_file(file_name: str,
load_from_package: str = None,
obs_builder_object: ObservationBuilder = None) -> RailEnv:
obs_builder_object: ObservationBuilder = None,
record_steps = False,
) -> RailEnv:
"""
Parameters
----------
......@@ -30,10 +32,10 @@ def load_flatland_environment_from_file(file_name: str,
obs_builder_object = TreeObsForRailEnv(
max_depth=2,
predictor=ShortestPathPredictorForRailEnv(max_depth=10))
environment = RailEnv(width=1, # will be overridden when loading from file
height=1, # will be overridden when loading from file
rail_generator=rail_from_file(file_name, load_from_package),
number_of_agents=1, # will be overridden when loading from file
schedule_generator=schedule_from_file(file_name, load_from_package),
obs_builder_object=obs_builder_object)
environment = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name, load_from_package),
line_generator=line_from_file(file_name, load_from_package),
number_of_agents=1,
obs_builder_object=obs_builder_object,
record_steps=record_steps,
)
return environment
"""Rail generators (infrastructure manager, "Infrastrukturbetreiber")."""
import sys
import warnings
from typing import Callable, Tuple, Optional, Dict, List
import msgpack
import numpy as np
from numpy.random.mtrand import RandomState
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.core.grid.grid4_utils import get_direction, mirror, direction_to_point
from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
from flatland.core.grid.grid_utils import distance_on_rail, IntVector2DArray, IntVector2D, \
from flatland.core.grid.grid4_utils import direction_to_point
from flatland.core.grid.grid_utils import IntVector2DArray, IntVector2D, \
Vec2dOperations
from flatland.core.grid.rail_env_grid import RailEnvTransitions
from flatland.core.transition_map import GridTransitionMap
from flatland.envs import persistence
from flatland.envs.grid4_generators_utils import connect_rail_in_grid_map, connect_straight_line_in_grid_map, \
fix_inner_nodes, align_cell_to_city
RailGeneratorProduct = Tuple[GridTransitionMap, Optional[Dict]]
RailGenerator = Callable[[int, int, int, int], RailGeneratorProduct]
def empty_rail_generator() -> RailGenerator:
"""
Returns a generator which returns an empty rail mail with no agents.
Primarily used by the editor
"""
def generator(width: int, height: int, num_agents: int = 0, num_resets: int = 0) -> RailGeneratorProduct:
rail_trans = RailEnvTransitions()
grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans)
rail_array = grid_map.grid
rail_array.fill(0)
return grid_map, None
return generator
""" A rail generator returns a RailGenerator Product, which is just
a GridTransitionMap followed by an (optional) dict/
"""
RailGenerator = Callable[[int, int, int, int], RailGeneratorProduct]
def complex_rail_generator(nr_start_goal=1,
nr_extra=100,
min_dist=20,
max_dist=99999,
seed=1) -> RailGenerator:
"""
complex_rail_generator
Parameters
----------
width : int
The width (number of cells) of the grid to generate.
height : int
The height (number of cells) of the grid to generate.
class RailGen(object):
""" Base class for RailGen(erator) replacement
Returns
-------
numpy.ndarray of type numpy.uint16
The matrix with the correct 16-bit bitmaps for each cell.
WIP to replace bare generators with classes / objects without unnamed local variables
which prevent pickling.
"""
def generator(width, height, num_agents, num_resets=0):
if num_agents > nr_start_goal:
num_agents = nr_start_goal
print("complex_rail_generator: num_agents > nr_start_goal, changing num_agents")
grid_map = GridTransitionMap(width=width, height=height, transitions=RailEnvTransitions())
rail_array = grid_map.grid
rail_array.fill(0)
np.random.seed(seed + num_resets)
# generate rail array
# step 1:
# - generate a start and goal position
# - validate min/max distance allowed
# - validate that start/goals are not placed too close to other start/goals
# - draw a rail from [start,goal]
# - if rail crosses existing rail then validate new connection
# - possibility that this fails to create a path to goal
# - on failure generate new start/goal
#
# step 2:
# - add more rails to map randomly between cells that have rails
# - validate all new rails, on failure don't add new rails
#
# step 3:
# - return transition map + list of [start_pos, start_dir, goal_pos] points
#
rail_trans = grid_map.transitions
start_goal = []
start_dir = []
nr_created = 0
created_sanity = 0
sanity_max = 9000
while nr_created < nr_start_goal and created_sanity < sanity_max:
all_ok = False
for _ in range(sanity_max):
start = (np.random.randint(0, height), np.random.randint(0, width))
goal = (np.random.randint(0, height), np.random.randint(0, width))
# check to make sure start,goal pos is empty?
if rail_array[goal] != 0 or rail_array[start] != 0:
continue
# check min/max distance
dist_sg = distance_on_rail(start, goal)
if dist_sg < min_dist:
continue
if dist_sg > max_dist:
continue
# check distance to existing points
sg_new = [start, goal]
def check_all_dist(sg_new):
"""
Function to check the distance betweens start and goal
:param sg_new: start and goal tuple
:return: True if distance is larger than 2, False otherwise
"""
for sg in start_goal:
for i in range(2):
for j in range(2):
dist = distance_on_rail(sg_new[i], sg[j])
if dist < 2:
return False
return True
if check_all_dist(sg_new):
all_ok = True
break
def __init__(self, *args, **kwargs):
""" constructor to record any state to be reused in each "generation"
"""
pass
if not all_ok:
# we might as well give up at this point
break
def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0,
np_random: RandomState = None) -> RailGeneratorProduct:
pass
new_path = connect_rail_in_grid_map(grid_map, start, goal, rail_trans, Vec2d.get_chebyshev_distance,
flip_start_node_trans=True, flip_end_node_trans=True,
respect_transition_validity=True, forbidden_cells=None)
if len(new_path) >= 2:
nr_created += 1
start_goal.append([start, goal])
start_dir.append(mirror(get_direction(new_path[0], new_path[1])))
else:
# after too many failures we will give up
created_sanity += 1
# add extra connections between existing rail
created_sanity = 0
nr_created = 0
while nr_created < nr_extra and created_sanity < sanity_max:
all_ok = False
for _ in range(sanity_max):
start = (np.random.randint(0, height), np.random.randint(0, width))
goal = (np.random.randint(0, height), np.random.randint(0, width))
# check to make sure start,goal pos are not empty
if rail_array[goal] == 0 or rail_array[start] == 0:
continue
else:
all_ok = True
break
if not all_ok:
break
new_path = connect_rail_in_grid_map(grid_map, start, goal, rail_trans, Vec2d.get_chebyshev_distance,
flip_start_node_trans=True, flip_end_node_trans=True,
respect_transition_validity=True, forbidden_cells=None)
def __call__(self, *args, **kwargs) -> RailGeneratorProduct:
return self.generate(*args, **kwargs)
if len(new_path) >= 2:
nr_created += 1
return grid_map, {'agents_hints': {
'start_goal': start_goal,
'start_dir': start_dir
}}
return generator
def empty_rail_generator() -> RailGenerator:
return EmptyRailGen()
def rail_from_manual_specifications_generator(rail_spec):
class EmptyRailGen(RailGen):
"""
Utility to convert a rail given by manual specification as a map of tuples
(cell_type, rotation), to a transition map with the correct 16-bit
transitions specifications.
Parameters
----------
rail_spec : list of list of tuples
List (rows) of lists (columns) of tuples, each specifying a rail_spec_of_cell for
the RailEnv environment as (cell_type, rotation), with rotation being
clock-wise and in [0, 90, 180, 270].
Returns
-------
function
Generator function that always returns a GridTransitionMap object with
the matrix of correct 16-bit bitmaps for each rail_spec_of_cell.
Returns a generator which returns an empty rail mail with no agents.
Primarily used by the editor
"""
def generator(width, height, num_agents, num_resets=0):
rail_env_transitions = RailEnvTransitions()
height = len(rail_spec)
width = len(rail_spec[0])
rail = GridTransitionMap(width=width, height=height, transitions=rail_env_transitions)
for r in range(height):
for c in range(width):
rail_spec_of_cell = rail_spec[r][c]
index_basic_type_of_cell_ = rail_spec_of_cell[0]
rotation_cell_ = rail_spec_of_cell[1]
if index_basic_type_of_cell_ < 0 or index_basic_type_of_cell_ >= len(rail_env_transitions.transitions):
print("ERROR - invalid rail_spec_of_cell type=", index_basic_type_of_cell_)
return []
basic_type_of_cell_ = rail_env_transitions.transitions[index_basic_type_of_cell_]
effective_transition_cell = rail_env_transitions.rotate_transition(basic_type_of_cell_, rotation_cell_)
rail.set_transitions((r, c), effective_transition_cell)
return [rail, None]
def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0,
np_random: RandomState = None) -> RailGenerator:
rail_trans = RailEnvTransitions()
grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans)
rail_array = grid_map.grid
rail_array.fill(0)
return generator
return grid_map, None
def rail_from_file(filename, load_from_package=None) -> RailGenerator:
......@@ -234,21 +78,16 @@ def rail_from_file(filename, load_from_package=None) -> RailGenerator:
the matrix of correct 16-bit bitmaps for each rail_spec_of_cell.
"""
def generator(width, height, num_agents, num_resets):
def generator(width: int, height: int, num_agents: int, num_resets: int = 0,
np_random: RandomState = None) -> List:
env_dict = persistence.RailEnvPersister.load_env_dict(filename, load_from_package=load_from_package)
rail_env_transitions = RailEnvTransitions()
if load_from_package is not None:
from importlib_resources import read_binary
load_data = read_binary(load_from_package, filename)
else:
with open(filename, "rb") as file_in:
load_data = file_in.read()
data = msgpack.unpackb(load_data, use_list=False)
grid = np.array(data[b"grid"])
grid = np.array(env_dict["grid"])
rail = GridTransitionMap(width=np.shape(grid)[1], height=np.shape(grid)[0], transitions=rail_env_transitions)
rail.grid = grid
if b"distance_map" in data.keys():
distance_map = data[b"distance_map"]
if "distance_map" in env_dict:
distance_map = env_dict["distance_map"]
if len(distance_map) > 0:
return rail, {'distance_map': distance_map}
return [rail, None]
......@@ -256,334 +95,57 @@ def rail_from_file(filename, load_from_package=None) -> RailGenerator:
return generator
def rail_from_grid_transition_map(rail_map) -> RailGenerator:
"""
Utility to convert a rail given by a GridTransitionMap map with the correct
16-bit transitions specifications.
Parameters
----------
rail_map : GridTransitionMap object
GridTransitionMap object to return when the generator is called.
Returns
-------
function
Generator function that always returns the given `rail_map` object.
"""
def generator(width: int, height: int, num_agents: int, num_resets: int = 0) -> RailGeneratorProduct:
return rail_map, None
return generator
def random_rail_generator(cell_type_relative_proportion=[1.0] * 11, seed=1) -> RailGenerator:
"""
Dummy random level generator:
- fill in cells at random in [width-2, height-2]
- keep filling cells in among the unfilled ones, such that all transitions\
are legit; if no cell can be filled in without violating some\
transitions, pick one among those that can satisfy most transitions\
(1,2,3 or 4), and delete (+mark to be re-filled) the cells that were\
incompatible.
- keep trying for a total number of insertions\
(e.g., (W-2)*(H-2)*MAX_REPETITIONS ); if no solution is found, empty the\
board and try again from scratch.
- finally pad the border of the map with dead-ends to avoid border issues.
Dead-ends are not allowed inside the grid, only at the border; however, if
no cell type can be inserted in a given cell (because of the neighboring
transitions), deadends are allowed if they solve the problem. This was
found to turn most un-genereatable levels into valid ones.
class RailFromGridGen(RailGen):
def __init__(self, rail_map, optionals=None):
self.rail_map = rail_map
self.optionals = optionals
Parameters
----------
width : int
The width (number of cells) of the grid to generate.
height : int
The height (number of cells) of the grid to generate.
def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0,
np_random: RandomState = None) -> RailGeneratorProduct:
return self.rail_map, self.optionals
Returns
-------
numpy.ndarray of type numpy.uint16
The matrix with the correct 16-bit bitmaps for each cell.
"""
def generator(width: int, height: int, num_agents: int, num_resets: int = 0) -> RailGeneratorProduct:
np.random.seed(seed + num_resets)
t_utils = RailEnvTransitions()
transition_probability = cell_type_relative_proportion
transitions_templates_ = []
transition_probabilities = []
for i in range(len(t_utils.transitions)): # don't include dead-ends
if t_utils.transitions[i] == int('0010000000000000', 2):
continue
all_transitions = 0
for dir_ in range(4):
trans = t_utils.get_transitions(t_utils.transitions[i], dir_)
all_transitions |= (trans[0] << 3) | \
(trans[1] << 2) | \
(trans[2] << 1) | \
(trans[3])
template = [int(x) for x in bin(all_transitions)[2:]]
template = [0] * (4 - len(template)) + template
# add all rotations
for rot in [0, 90, 180, 270]:
transitions_templates_.append((template,
t_utils.rotate_transition(
t_utils.transitions[i],
rot)))
transition_probabilities.append(transition_probability[i])
template = [template[-1]] + template[:-1]
def get_matching_templates(template):
"""
Returns a list of possible transition maps for a given template
Parameters:
------
template:List[int]
Returns:
------
List[int]
"""
ret = []
for i in range(len(transitions_templates_)):
is_match = True
for j in range(4):
if template[j] >= 0 and template[j] != transitions_templates_[i][0][j]:
is_match = False
break
if is_match:
ret.append((transitions_templates_[i][1], transition_probabilities[i]))
return ret
MAX_INSERTIONS = (width - 2) * (height - 2) * 10
MAX_ATTEMPTS_FROM_SCRATCH = 10
attempt_number = 0
while attempt_number < MAX_ATTEMPTS_FROM_SCRATCH:
cells_to_fill = []
rail = []
for r in range(height):
rail.append([None] * width)
if r > 0 and r < height - 1:
cells_to_fill = cells_to_fill + [(r, c) for c in range(1, width - 1)]
num_insertions = 0
while num_insertions < MAX_INSERTIONS and len(cells_to_fill) > 0:
cell = cells_to_fill[np.random.choice(len(cells_to_fill), 1)[0]]
cells_to_fill.remove(cell)
row = cell[0]
col = cell[1]
# look at its neighbors and see what are the possible transitions
# that can be chosen from, if any.
valid_template = [-1, -1, -1, -1]
for el in [(0, 2, (-1, 0)),
(1, 3, (0, 1)),
(2, 0, (1, 0)),
(3, 1, (0, -1))]: # N, E, S, W
neigh_trans = rail[row + el[2][0]][col + el[2][1]]
if neigh_trans is not None:
# select transition coming from facing direction el[1] and
# moving to direction el[1]
max_bit = 0
for k in range(4):
max_bit |= t_utils.get_transition(neigh_trans, k, el[1])
if max_bit:
valid_template[el[0]] = 1
else:
valid_template[el[0]] = 0
possible_cell_transitions = get_matching_templates(valid_template)
if len(possible_cell_transitions) == 0: # NO VALID TRANSITIONS
# no cell can be filled in without violating some transitions
# can a dead-end solve the problem?
if valid_template.count(1) == 1:
for k in range(4):
if valid_template[k] == 1:
rot = 0
if k == 0:
rot = 180
elif k == 1:
rot = 270
elif k == 2:
rot = 0
elif k == 3:
rot = 90
rail[row][col] = t_utils.rotate_transition(int('0010000000000000', 2), rot)
num_insertions += 1
break
else:
# can I get valid transitions by removing a single
# neighboring cell?
bestk = -1
besttrans = []
for k in range(4):
tmp_template = valid_template[:]
tmp_template[k] = -1
possible_cell_transitions = get_matching_templates(tmp_template)
if len(possible_cell_transitions) > len(besttrans):
besttrans = possible_cell_transitions
bestk = k
if bestk >= 0:
# Replace the corresponding cell with None, append it
# to cells to fill, fill in a transition in the current
# cell.
replace_row = row - 1
replace_col = col
if bestk == 1:
replace_row = row
replace_col = col + 1
elif bestk == 2:
replace_row = row + 1
replace_col = col
elif bestk == 3:
replace_row = row
replace_col = col - 1
cells_to_fill.append((replace_row, replace_col))
rail[replace_row][replace_col] = None
possible_transitions, possible_probabilities = zip(*besttrans)
possible_probabilities = [p / sum(possible_probabilities) for p in possible_probabilities]
rail[row][col] = np.random.choice(possible_transitions,
p=possible_probabilities)
num_insertions += 1
else:
print('WARNING: still nothing!')
rail[row][col] = int('0000000000000000', 2)
num_insertions += 1
pass
else:
possible_transitions, possible_probabilities = zip(*possible_cell_transitions)
possible_probabilities = [p / sum(possible_probabilities) for p in possible_probabilities]
rail[row][col] = np.random.choice(possible_transitions,
p=possible_probabilities)
num_insertions += 1
if num_insertions == MAX_INSERTIONS:
# Failed to generate a valid level; try again for a number of times
attempt_number += 1
else:
break
def rail_from_grid_transition_map(rail_map, optionals=None) -> RailGenerator:
return RailFromGridGen(rail_map, optionals)
if attempt_number == MAX_ATTEMPTS_FROM_SCRATCH:
print('ERROR: failed to generate level')
# Finally pad the border of the map with dead-ends to avoid border issues;
# at most 1 transition in the neigh cell
for r in range(height):
# Check for transitions coming from [r][1] to WEST
max_bit = 0
neigh_trans = rail[r][1]
if neigh_trans is not None:
for k in range(4):
neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2 ** 4 - 1)
max_bit = max_bit | (neigh_trans_from_direction & 1)
if max_bit:
rail[r][0] = t_utils.rotate_transition(int('0010000000000000', 2), 270)
else:
rail[r][0] = int('0000000000000000', 2)
# Check for transitions coming from [r][-2] to EAST
max_bit = 0
neigh_trans = rail[r][-2]
if neigh_trans is not None:
for k in range(4):
neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2 ** 4 - 1)
max_bit = max_bit | (neigh_trans_from_direction & (1 << 2))
if max_bit:
rail[r][-1] = t_utils.rotate_transition(int('0010000000000000', 2),
90)
else:
rail[r][-1] = int('0000000000000000', 2)
for c in range(width):
# Check for transitions coming from [1][c] to NORTH
max_bit = 0
neigh_trans = rail[1][c]
if neigh_trans is not None:
for k in range(4):
neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2 ** 4 - 1)
max_bit = max_bit | (neigh_trans_from_direction & (1 << 3))
if max_bit:
rail[0][c] = int('0010000000000000', 2)
else:
rail[0][c] = int('0000000000000000', 2)
# Check for transitions coming from [-2][c] to SOUTH
max_bit = 0
neigh_trans = rail[-2][c]
if neigh_trans is not None:
for k in range(4):
neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2 ** 4 - 1)
max_bit = max_bit | (neigh_trans_from_direction & (1 << 1))
if max_bit:
rail[-1][c] = t_utils.rotate_transition(int('0010000000000000', 2), 180)
else:
rail[-1][c] = int('0000000000000000', 2)
# For display only, wrong levels
for r in range(height):
for c in range(width):
if rail[r][c] is None:
rail[r][c] = int('0000000000000000', 2)
tmp_rail = np.asarray(rail, dtype=np.uint16)
return_rail = GridTransitionMap(width=width, height=height, transitions=t_utils)
return_rail.grid = tmp_rail
return return_rail, None
return generator
def sparse_rail_generator(*args, **kwargs):
return SparseRailGen(*args, **kwargs)
def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_rails_between_cities: int = 4,
max_rails_in_city: int = 4, seed: int = 1) -> RailGenerator:
"""
Generates railway networks with cities and inner city rails
class SparseRailGen(RailGen):
Parameters
----------
max_num_cities : int
Max number of cities to build. The generator tries to achieve this numbers given all the parameters
grid_mode: Bool
How to distribute the cities in the path, either equally in a grid or random
max_rails_between_cities: int
Max number of rails connecting to a city. This is only the number of connection points at city boarder.
Number of tracks drawn inbetween cities can still vary
max_rails_in_city: int
Number of parallel tracks in the city. This represents the number of tracks in the trainstations
seed: int
Initiate the seed
def __init__(self, max_num_cities: int = 2, grid_mode: bool = False, max_rails_between_cities: int = 2,
max_rail_pairs_in_city: int = 2, seed=None) -> RailGenerator:
"""
Generates railway networks with cities and inner city rails
Returns
-------
Returns the rail generator object to the rail env constructor
"""
Parameters
----------
max_num_cities : int
Max number of cities to build. The generator tries to achieve this numbers given all the parameters
grid_mode: Bool
How to distribute the cities in the path, either equally in a grid or random
max_rails_between_cities: int
Max number of rails connecting to a city. This is only the number of connection points at city boarder.
Number of tracks drawn inbetween cities can still vary
max_rail_pairs_in_city: int
Number of parallel tracks in the city. This represents the number of tracks in the trainstations
seed: int
Initiate the seed
def generator(width: int, height: int, num_agents: int, num_resets: int = 0) -> RailGeneratorProduct:
Returns
-------
Returns the rail generator object to the rail env constructor
"""
self.max_num_cities = max_num_cities
self.grid_mode = grid_mode
self.max_rails_between_cities = max_rails_between_cities
self.max_rail_pairs_in_city = max_rail_pairs_in_city
self.seed = seed
def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0,
np_random: RandomState = None) -> RailGenerator:
"""
Parameters
......@@ -607,78 +169,83 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
'train_stations': locations of train stations for start and targets
'city_orientations' : orientation of cities
"""
np.random.seed(seed + num_resets)
if self.seed is not None:
np_random = RandomState(self.seed)
elif np_random is None:
np_random = RandomState(np.random.randint(2 ** 32))
rail_trans = RailEnvTransitions()
grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans)
# NEW : SCHED CONST (Pairs of rails (1,2,3 pairs))
min_nr_rail_pairs_in_city = 1 # (min pair must be 1)
rail_pairs_in_city = min_nr_rail_pairs_in_city if self.max_rail_pairs_in_city < min_nr_rail_pairs_in_city else self.max_rail_pairs_in_city # (pairs can be 1,2,3)
rails_between_cities = (rail_pairs_in_city * 2) if self.max_rails_between_cities > (
rail_pairs_in_city * 2) else self.max_rails_between_cities
# We compute the city radius by the given max number of rails it can contain.
# The radius is equal to the number of tracks divided by 2
# We add 2 cells to avoid that track lenght is to shot
# We add 2 cells to avoid that track lenght is to short
city_padding = 2
city_radius = int(np.ceil((max_rails_in_city) // 2)) + city_padding
# We use ceil if we get uneven numbers of city radius. This is to guarantee that all rails fit within the city.
city_radius = int(np.ceil((rail_pairs_in_city * 2) / 2)) + city_padding
vector_field = np.zeros(shape=(height, width)) - 1.
min_nr_rails_in_city = 2
rails_in_city = min_nr_rails_in_city if max_rails_in_city < min_nr_rails_in_city else max_rails_in_city
rails_between_cities = rails_in_city if max_rails_between_cities > rails_in_city else max_rails_between_cities
# Calculate the max number of cities allowed
# and reduce the number of cities to build to avoid problems
max_feasible_cities = min(max_num_cities,
max_feasible_cities = min(self.max_num_cities,
((height - 2) // (2 * (city_radius + 1))) * ((width - 2) // (2 * (city_radius + 1))))
if max_feasible_cities < 2:
sys.exit("[ABORT] Cannot fit more than one city in this map, no feasible environment possible! Aborting.")
# sys.exit("[ABORT] Cannot fit more than one city in this map, no feasible environment possible! Aborting.")
raise ValueError("ERROR: Cannot fit more than one city in this map, no feasible environment possible!")
# Evenly distribute cities
if grid_mode:
city_positions = _generate_evenly_distr_city_positions(max_feasible_cities, city_radius, width,
height)
if self.grid_mode:
city_positions = self._generate_evenly_distr_city_positions(max_feasible_cities, city_radius, width,
height)
# Distribute cities randomlz
else:
city_positions = _generate_random_city_positions(max_feasible_cities, city_radius, width, height)
city_positions = self._generate_random_city_positions(max_feasible_cities, city_radius, width, height,
np_random=np_random)
# reduce num_cities if less were generated in random mode
num_cities = len(city_positions)
# If random generation failed just put the cities evenly
if num_cities < 2:
warnings.warn("[WARNING] Changing to Grid mode to place at least 2 cities.")
city_positions = _generate_evenly_distr_city_positions(max_feasible_cities, city_radius, width,
height)
city_positions = self._generate_evenly_distr_city_positions(max_feasible_cities, city_radius, width,
height)
num_cities = len(city_positions)
# Set up connection points for all cities
inner_connection_points, outer_connection_points, city_orientations, city_cells = \
_generate_city_connection_points(
self._generate_city_connection_points(
city_positions, city_radius, vector_field, rails_between_cities,
rails_in_city)
rail_pairs_in_city, np_random=np_random)
# Connect the cities through the connection points
inter_city_lines = _connect_cities(city_positions, outer_connection_points, city_cells,
rail_trans, grid_map)
inter_city_lines = self._connect_cities(city_positions, outer_connection_points, city_cells,
rail_trans, grid_map)
# Build inner cities
free_rails = _build_inner_cities(city_positions, inner_connection_points,
outer_connection_points,
rail_trans,
grid_map)
free_rails = self._build_inner_cities(city_positions, inner_connection_points,
outer_connection_points,
rail_trans,
grid_map)
# Populate cities
train_stations = _set_trainstation_positions(city_positions, city_radius, free_rails)
train_stations = self._set_trainstation_positions(city_positions, city_radius, free_rails)
# Fix all transition elements
_fix_transitions(city_cells, inter_city_lines, grid_map, vector_field)
self._fix_transitions(city_cells, inter_city_lines, grid_map, vector_field)
return grid_map, {'agents_hints': {
'num_agents': num_agents,
'city_positions': city_positions,
'train_stations': train_stations,
'city_orientations': city_orientations
}}
def _generate_random_city_positions(num_cities: int, city_radius: int, width: int,
height: int) -> (IntVector2DArray, IntVector2DArray):
def _generate_random_city_positions(self, num_cities: int, city_radius: int, width: int,
height: int, np_random: RandomState = None) -> Tuple[
IntVector2DArray, IntVector2DArray]:
"""
Distribute the cities randomly in the environment while respecting city sizes and guaranteeing that they
don't overlap.
......@@ -699,32 +266,45 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
Returns a list of all city positions as coordinates (x,y)
"""
city_positions: IntVector2DArray = []
for city_idx in range(num_cities):
too_close = True
tries = 0
while too_close:
row = city_radius + 1 + np.random.randint(height - 2 * (city_radius + 1))
col = city_radius + 1 + np.random.randint(width - 2 * (city_radius + 1))
too_close = False
# Check distance to cities
for city_pos in city_positions:
if _are_cities_overlapping((row, col), city_pos, 2 * (city_radius + 1) + 1):
too_close = True
if not too_close:
city_positions.append((row, col))
tries += 1
if tries > 200:
warnings.warn(
"Could not set all required cities!")
break
# We track a grid of allowed indexes that can be sampled from for creating a new city
# This removes the old sampling method of retrying a random sample on failure
allowed_grid = np.zeros((height, width), dtype=np.uint8)
city_radius_pad1 = city_radius + 1
# Borders have to be not allowed from the start
# allowed_grid == 1 indicates locations that are allowed
allowed_grid[city_radius_pad1:-city_radius_pad1, city_radius_pad1:-city_radius_pad1] = 1
for _ in range(num_cities):
allowed_indexes = np.where(allowed_grid == 1)
num_allowed_points = len(allowed_indexes[0])
if num_allowed_points == 0:
break
# Sample one of the allowed indexes
point_index = np_random.randint(num_allowed_points)
row = int(allowed_indexes[0][point_index])
col = int(allowed_indexes[1][point_index])
# Need to block city radius and extra margin so that next sampling is correct
# Clipping handles the case for negative indexes being generated
row_start = max(0, row - 2 * city_radius_pad1)
col_start = max(0, col - 2 * city_radius_pad1)
row_end = row + 2 * city_radius_pad1 + 1
col_end = col + 2 * city_radius_pad1 + 1
allowed_grid[row_start: row_end, col_start: col_end] = 0
city_positions.append((row, col))
created_cites = len(city_positions)
if created_cites < num_cities:
city_warning = f"Could not set all required cities! Created {created_cites}/{num_cities}"
warnings.warn(city_warning)
return city_positions
def _generate_evenly_distr_city_positions(num_cities: int, city_radius: int, width: int, height: int
) -> (IntVector2DArray, IntVector2DArray):
def _generate_evenly_distr_city_positions(self, num_cities: int, city_radius: int, width: int, height: int
) -> Tuple[IntVector2DArray, IntVector2DArray]:
"""
Distribute the cities in an evenly spaced grid
......@@ -745,13 +325,12 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
"""
aspect_ratio = height / width
# Compute max numbe of possible cities per row and col.
# Respect padding at edges of environment
# Respect padding between cities
padding = 2
city_size = 2 * (city_radius + 1)
max_cities_per_row =int((height - padding) // city_size)
max_cities_per_row = int((height - padding) // city_size)
max_cities_per_col = int((width - padding) // city_size)
# Choose number of cities per row.
......@@ -770,12 +349,13 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
city_positions.append((row, col))
return city_positions
def _generate_city_connection_points(city_positions: IntVector2DArray, city_radius: int,
def _generate_city_connection_points(self, city_positions: IntVector2DArray, city_radius: int,
vector_field: IntVector2DArray, rails_between_cities: int,
rails_in_city: int = 2) -> (List[List[List[IntVector2D]]],
List[List[List[IntVector2D]]],
List[np.ndarray],
List[Grid4TransitionsEnum]):
rail_pairs_in_city: int = 1, np_random: RandomState = None) -> Tuple[
List[List[List[IntVector2D]]],
List[List[List[IntVector2D]]],
List[np.ndarray],
List[Grid4TransitionsEnum]]:
"""
Generate the city connection points. Internal connection points are used to generate the parallel paths
within the city.
......@@ -792,7 +372,7 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
Each cell contains the prefered orientation of cells. If no prefered orientation is present it is set to -1
rails_between_cities: int
Number of rails that connect out from the city
rails_in_city: int
rail_pairs_in_city: int
Number of rails within the city
Returns
......@@ -813,39 +393,43 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
outer_connection_points: List[List[List[IntVector2D]]] = []
city_orientations: List[Grid4TransitionsEnum] = []
city_cells: IntVector2DArray = []
for city_position in city_positions:
# Chose the directions where close cities are situated
neighb_dist = []
for neighbour_city in city_positions:
neighb_dist.append(Vec2dOperations.get_manhattan_distance(city_position, neighbour_city))
closest_neighb_idx = argsort(neighb_dist)
closest_neighb_idx = self.__class__.argsort(neighb_dist)
# Store the directions to these neighbours and orient city to face closest neighbour
connection_sides_idx = []
idx = 1
if grid_mode:
current_closest_direction = np.random.randint(4)
if self.grid_mode:
current_closest_direction = np_random.randint(4)
else:
current_closest_direction = direction_to_point(city_position, city_positions[closest_neighb_idx[idx]])
connection_sides_idx.append(current_closest_direction)
connection_sides_idx.append((current_closest_direction + 2) % 4)
city_orientations.append(current_closest_direction)
city_cells.extend(_get_cells_in_city(city_position, city_radius, city_orientations[-1], vector_field))
city_cells.extend(self._get_cells_in_city(city_position, city_radius, city_orientations[-1], vector_field))
# set the number of tracks within a city, at least 2 tracks per city
connections_per_direction = np.zeros(4, dtype=int)
nr_of_connection_points = np.random.randint(2, rails_in_city + 1)
# NEW : SCHED CONST
nr_of_connection_points = np_random.randint(1, rail_pairs_in_city + 1) * 2 # can be (1,2,3)*2 = (2,4,6)
for idx in connection_sides_idx:
connections_per_direction[idx] = nr_of_connection_points
connection_points_coordinates_inner: List[List[IntVector2D]] = [[] for i in range(4)]
connection_points_coordinates_outer: List[List[IntVector2D]] = [[] for i in range(4)]
number_of_out_rails = np.random.randint(1, min(rails_between_cities, nr_of_connection_points) + 1)
number_of_out_rails = np_random.randint(1, min(rails_between_cities, nr_of_connection_points) + 1)
start_idx = int((nr_of_connection_points - number_of_out_rails) / 2)
for direction in range(4):
connection_slots = np.arange(nr_of_connection_points) - start_idx
# Offset the rails away from the center of the city
offset_distances = np.arange(nr_of_connection_points) - int(nr_of_connection_points / 2)
# The clipping helps ofsetting one side more than the other to avoid switches at same locations
# The magic number plus one is added such that all points have at least one offset
inner_point_offset = np.abs(offset_distances) + np.clip(offset_distances, 0, 1) + 1
for connection_idx in range(connections_per_direction[direction]):
if direction == 0:
tmp_coordinates = (
......@@ -879,7 +463,7 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
outer_connection_points.append(connection_points_coordinates_outer)
return inner_connection_points, outer_connection_points, city_orientations, city_cells
def _connect_cities(city_positions: IntVector2DArray, connection_points: List[List[List[IntVector2D]]],
def _connect_cities(self, city_positions: IntVector2DArray, connection_points: List[List[List[IntVector2D]]],
city_cells: IntVector2DArray,
rail_trans: RailEnvTransitions, grid_map: RailEnvTransitions) -> List[IntVector2DArray]:
"""
......@@ -909,12 +493,11 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
grid4_directions = [Grid4TransitionsEnum.NORTH, Grid4TransitionsEnum.EAST, Grid4TransitionsEnum.SOUTH,
Grid4TransitionsEnum.WEST]
for current_city_idx in np.arange(len(city_positions)):
closest_neighbours = _closest_neighbour_in_grid4_directions(current_city_idx, city_positions)
closest_neighbours = self._closest_neighbour_in_grid4_directions(current_city_idx, city_positions)
for out_direction in grid4_directions:
neighbour_idx = get_closest_neighbour_for_direction(closest_neighbours, out_direction)
neighbour_idx = self.get_closest_neighbour_for_direction(closest_neighbours, out_direction)
for city_out_connection_point in connection_points[current_city_idx][out_direction]:
......@@ -927,17 +510,19 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
if tmp_dist < min_connection_dist:
min_connection_dist = tmp_dist
neighbour_connection_point = tmp_in_connection_point
new_line = connect_rail_in_grid_map(grid_map, city_out_connection_point, neighbour_connection_point,
rail_trans, flip_start_node_trans=False,
flip_end_node_trans=False, respect_transition_validity=False,
avoid_rail=True,
forbidden_cells=city_cells)
if len(new_line) == 0:
warnings.warn("[WARNING] No line added between stations")
elif new_line[-1] != neighbour_connection_point or new_line[0] != city_out_connection_point:
warnings.warn("[WARNING] Unable to connect requested stations")
all_paths.extend(new_line)
return all_paths
def get_closest_neighbour_for_direction(closest_neighbours, out_direction):
def get_closest_neighbour_for_direction(self, closest_neighbours, out_direction):
"""
Given a list of clostest neighbours in each direction this returns the city index of the neighbor in a given
direction. Direction is a 90 degree cone facing the desired directiont.
......@@ -974,9 +559,11 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
return closest_neighbours[(out_direction + 2) % 4] # clockwise
def _build_inner_cities(city_positions: IntVector2DArray, inner_connection_points: List[List[List[IntVector2D]]],
def _build_inner_cities(self, city_positions: IntVector2DArray,
inner_connection_points: List[List[List[IntVector2D]]],
outer_connection_points: List[List[List[IntVector2D]]], rail_trans: RailEnvTransitions,
grid_map: GridTransitionMap) -> (List[IntVector2DArray], List[List[List[IntVector2D]]]):
grid_map: GridTransitionMap) -> Tuple[
List[IntVector2DArray], List[List[List[IntVector2D]]]]:
"""
Set the parallel tracks within the city. The center track of the city is of the length of the city, the lenght
of the tracks decrease by 2 for every parallel track away from the center
......@@ -1024,8 +611,8 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
source = inner_connection_points[current_city][boarder][track_id]
target = inner_connection_points[current_city][opposite_boarder][track_id]
current_track = connect_straight_line_in_grid_map(grid_map, source, target, rail_trans)
free_rails[current_city].append(current_track)
for track_id in range(nr_of_connection_points):
source = inner_connection_points[current_city][boarder][track_id]
target = inner_connection_points[current_city][opposite_boarder][track_id]
......@@ -1042,10 +629,9 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
target_outer = outer_connection_points[current_city][opposite_boarder][track_id - start_idx]
connect_straight_line_in_grid_map(grid_map, source, source_outer, rail_trans)
connect_straight_line_in_grid_map(grid_map, target, target_outer, rail_trans)
return free_rails
def _set_trainstation_positions(city_positions: IntVector2DArray, city_radius: int,
def _set_trainstation_positions(self, city_positions: IntVector2DArray, city_radius: int,
free_rails: List[List[List[IntVector2D]]]) -> List[List[Tuple[IntVector2D, int]]]:
"""
Populate the cities with possible start and end positions. Trainstations are set on the center of each paralell
......@@ -1074,7 +660,7 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
train_stations[current_city].append((possible_location, track_nbr))
return train_stations
def _fix_transitions(city_cells: IntVector2DArray, inter_city_lines: List[IntVector2DArray],
def _fix_transitions(self, city_cells: IntVector2DArray, inter_city_lines: List[IntVector2DArray],
grid_map: GridTransitionMap, vector_field):
"""
Check and fix transitions of all the cells that were modified. This is necessary because we ignore validity
......@@ -1111,7 +697,8 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
for cell in range(rails_to_fix_cnt):
grid_map.fix_transitions((rails_to_fix[3 * cell], rails_to_fix[3 * cell + 1]), rails_to_fix[3 * cell + 2])
def _closest_neighbour_in_grid4_directions(current_city_idx: int, city_positions: IntVector2DArray) -> List[int]:
def _closest_neighbour_in_grid4_directions(self, current_city_idx: int, city_positions: IntVector2DArray) -> List[
int]:
"""
Finds the closest city in each direction of the current city
Parameters
......@@ -1146,6 +733,7 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
return closest_neighbour
@staticmethod
def argsort(seq):
"""
Same as Numpy sort but for lists
......@@ -1162,7 +750,7 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
# http://stackoverflow.com/questions/3071415/efficient-method-to-calculate-the-rank-vector-of-a-list-in-python
return sorted(range(len(seq)), key=seq.__getitem__)
def _get_cells_in_city(center: IntVector2D, radius: int, city_orientation: int,
def _get_cells_in_city(self, center: IntVector2D, radius: int, city_orientation: int,
vector_field: IntVector2DArray) -> IntVector2DArray:
"""
Function the collect cells of a city. It also populates the vector field accoring to the orientation of the
......@@ -1199,6 +787,7 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
vector_field[cell] = align_cell_to_city(center, city_orientation, cell)
return city_cells
@staticmethod
def _are_cities_overlapping(center_1, center_2, radius):
"""
Check if two cities overlap. That is we check if two squares with certain edge length and position overlap
......@@ -1217,5 +806,3 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
Returns True if the cities overlap and False otherwise
"""
return np.abs(center_1[0] - center_2[0]) < radius and np.abs(center_1[1] - center_2[1]) < radius
return generator
from typing import NamedTuple, Tuple, List, Dict
# A way point is the entry into a cell defined by
# - the row and column coordinates of the cell entered
# - direction, in which the agent is facing to enter the cell.
# This induces a graph on top of the FLATland cells:
# - four possible way points per cell
# - edges are the possible transitions in the cell.
Waypoint = NamedTuple('Waypoint', [('position', Tuple[int, int]), ('direction', int)])
# A train run is represented by the waypoints traversed and the times of traversal
# The terminology follows https://github.com/crowdAI/train-schedule-optimisation-challenge-starter-kit/blob/master/documentation/output_data_model.md
TrainrunWaypoint = NamedTuple('TrainrunWaypoint', [
('scheduled_at', int),
('waypoint', Waypoint)
])
# A train run is the list of an agent's way points and their scheduled time
Trainrun = List[TrainrunWaypoint]
TrainrunDict = Dict[int, Trainrun]
"""Schedule generators (railway undertaking, "EVU")."""
import warnings
from typing import Tuple, List, Callable, Mapping, Optional, Any
import msgpack
import numpy as np
from flatland.core.grid.grid4_utils import get_new_position
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.schedule_utils import Schedule
AgentPosition = Tuple[int, int]
ScheduleGenerator = Callable[[GridTransitionMap, int, Optional[Any], Optional[int]], Schedule]
def speed_initialization_helper(nb_agents: int, speed_ratio_map: Mapping[float, float] = None,
seed: int = None) -> List[float]:
"""
Parameters
----------
nb_agents : int
The number of agents to generate a speed for
speed_ratio_map : Mapping[float,float]
A map of speeds mappint to their ratio of appearance. The ratios must sum up to 1.
Returns
-------
List[float]
A list of size nb_agents of speeds with the corresponding probabilistic ratios.
"""
if seed:
np.random.seed(seed)
if speed_ratio_map is None:
return [1.0] * nb_agents
nb_classes = len(speed_ratio_map.keys())
speed_ratio_map_as_list: List[Tuple[float, float]] = list(speed_ratio_map.items())
speed_ratios = list(map(lambda t: t[1], speed_ratio_map_as_list))
speeds = list(map(lambda t: t[0], speed_ratio_map_as_list))
return list(map(lambda index: speeds[index], np.random.choice(nb_classes, nb_agents, p=speed_ratios)))
def complex_schedule_generator(speed_ratio_map: Mapping[float, float] = None, seed: int = 1) -> ScheduleGenerator:
"""
Generator used to generate the levels of Round 1 in the Flatland Challenge. It can only be used together
with complex_rail_generator. It places agents at end and start points provided by the rail generator.
It assigns speeds to the different agents according to the speed_ratio_map
:param speed_ratio_map: Speed ratios of all agents. They are probabilities of all different speeds and have to
add up to 1.
:param seed: Initiate random seed generator
:return:
"""
def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None, num_resets: int = 0) -> Schedule:
"""
The generator that assigns tasks to all the agents
:param rail: Rail infrastructure given by the rail_generator
:param num_agents: Number of agents to include in the schedule
:param hints: Hints provided by the rail_generator These include positions of start/target positions
:param num_resets: How often the generator has been reset.
:return: Returns the generator to the rail constructor
"""
_runtime_seed = seed + num_resets
np.random.seed(_runtime_seed)
start_goal = hints['start_goal']
start_dir = hints['start_dir']
agents_position = [sg[0] for sg in start_goal[:num_agents]]
agents_target = [sg[1] for sg in start_goal[:num_agents]]
agents_direction = start_dir[:num_agents]
if speed_ratio_map:
speeds = speed_initialization_helper(num_agents, speed_ratio_map, seed=_runtime_seed)
else:
speeds = [1.0] * len(agents_position)
return Schedule(agent_positions=agents_position, agent_directions=agents_direction,
agent_targets=agents_target, agent_speeds=speeds, agent_malfunction_rates=None)
return generator
def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None, seed: int = 1) -> ScheduleGenerator:
"""
This is the schedule generator which is used for Round 2 of the Flatland challenge. It produces schedules
to railway networks provided by sparse_rail_generator.
:param speed_ratio_map: Speed ratios of all agents. They are probabilities of all different speeds and have to
add up to 1.
:param seed: Initiate random seed generator
"""
def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None, num_resets: int = 0) -> Schedule:
"""
The generator that assigns tasks to all the agents
:param rail: Rail infrastructure given by the rail_generator
:param num_agents: Number of agents to include in the schedule
:param hints: Hints provided by the rail_generator These include positions of start/target positions
:param num_resets: How often the generator has been reset.
:return: Returns the generator to the rail constructor
"""
_runtime_seed = seed + num_resets
np.random.seed(_runtime_seed)
train_stations = hints['train_stations']
city_positions = hints['city_positions']
city_orientation = hints['city_orientations']
max_num_agents = hints['num_agents']
city_orientations = hints['city_orientations']
if num_agents > max_num_agents:
num_agents = max_num_agents
warnings.warn("Too many agents! Changes number of agents.")
# Place agents and targets within available train stations
agents_position = []
agents_target = []
agents_direction = []
for agent_idx in range(num_agents):
infeasible_agent = True
tries = 0
while infeasible_agent:
tries += 1
infeasible_agent = False
# Set target for agent
city_idx = np.random.choice(len(city_positions), 2, replace=False)
start_city = city_idx[0]
target_city = city_idx[1]
start_idx = np.random.choice(np.arange(len(train_stations[start_city])))
target_idx = np.random.choice(np.arange(len(train_stations[target_city])))
start = train_stations[start_city][start_idx]
target = train_stations[target_city][target_idx]
while start[1] % 2 != 0:
start_idx = np.random.choice(np.arange(len(train_stations[start_city])))
start = train_stations[start_city][start_idx]
while target[1] % 2 != 1:
target_idx = np.random.choice(np.arange(len(train_stations[target_city])))
target = train_stations[target_city][target_idx]
possible_orientations = [city_orientation[start_city],
(city_orientation[start_city] + 2) % 4]
agent_orientation = np.random.choice(possible_orientations)
if not rail.check_path_exists(start[0], agent_orientation, target[0]):
agent_orientation = (agent_orientation + 2) % 4
if not (rail.check_path_exists(start[0], agent_orientation, target[0])):
infeasible_agent = True
if tries >= 100:
warnings.warn("Did not find any possible path, check your parameters!!!")
break
agents_position.append((start[0][0], start[0][1]))
agents_target.append((target[0][0], target[0][1]))
agents_direction.append(agent_orientation)
# Orient the agent correctly
if speed_ratio_map:
speeds = speed_initialization_helper(num_agents, speed_ratio_map, seed=_runtime_seed)
else:
speeds = [1.0] * len(agents_position)
return Schedule(agent_positions=agents_position, agent_directions=agents_direction,
agent_targets=agents_target, agent_speeds=speeds, agent_malfunction_rates=None)
return generator
def random_schedule_generator(speed_ratio_map: Optional[Mapping[float, float]] = None,
seed: int = 1) -> ScheduleGenerator:
"""
Given a `rail` GridTransitionMap, return a random placement of agents (initial position, direction and target).
Parameters
----------
speed_ratio_map : Optional[Mapping[float, float]]
A map of speeds mapping to their ratio of appearance. The ratios must sum up to 1.
Returns
-------
Tuple[List[Tuple[int,int]], List[Tuple[int,int]], List[Tuple[int,int]], List[float]]
initial positions, directions, targets speeds
"""
def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None,
num_resets: int = 0) -> Schedule:
_runtime_seed = seed + num_resets
np.random.seed(_runtime_seed)
valid_positions = []
for r in range(rail.height):
for c in range(rail.width):
if rail.get_full_transitions(r, c) > 0:
valid_positions.append((r, c))
if len(valid_positions) == 0:
return Schedule(agent_positions=[], agent_directions=[],
agent_targets=[], agent_speeds=[], agent_malfunction_rates=None)
if len(valid_positions) < num_agents:
warnings.warn("schedule_generators: len(valid_positions) < num_agents")
return Schedule(agent_positions=[], agent_directions=[],
agent_targets=[], agent_speeds=[], agent_malfunction_rates=None)
agents_position_idx = [i for i in np.random.choice(len(valid_positions), num_agents, replace=False)]
agents_position = [valid_positions[agents_position_idx[i]] for i in range(num_agents)]
agents_target_idx = [i for i in np.random.choice(len(valid_positions), num_agents, replace=False)]
agents_target = [valid_positions[agents_target_idx[i]] for i in range(num_agents)]
update_agents = np.zeros(num_agents)
re_generate = True
cnt = 0
while re_generate:
cnt += 1
if cnt > 1:
print("re_generate cnt={}".format(cnt))
if cnt > 1000:
raise Exception("After 1000 re_generates still not success, giving up.")
# update position
for i in range(num_agents):
if update_agents[i] == 1:
x = np.setdiff1d(np.arange(len(valid_positions)), agents_position_idx)
agents_position_idx[i] = np.random.choice(x)
agents_position[i] = valid_positions[agents_position_idx[i]]
x = np.setdiff1d(np.arange(len(valid_positions)), agents_target_idx)
agents_target_idx[i] = np.random.choice(x)
agents_target[i] = valid_positions[agents_target_idx[i]]
update_agents = np.zeros(num_agents)
# agents_direction must be a direction for which a solution is
# guaranteed.
agents_direction = [0] * num_agents
re_generate = False
for i in range(num_agents):
valid_movements = []
for direction in range(4):
position = agents_position[i]
moves = rail.get_transitions(position[0], position[1], direction)
for move_index in range(4):
if moves[move_index]:
valid_movements.append((direction, move_index))
valid_starting_directions = []
for m in valid_movements:
new_position = get_new_position(agents_position[i], m[1])
if m[0] not in valid_starting_directions and rail.check_path_exists(new_position, m[1],
agents_target[i]):
valid_starting_directions.append(m[0])
if len(valid_starting_directions) == 0:
update_agents[i] = 1
warnings.warn(
"reset position for agent[{}]: {} -> {}".format(i, agents_position[i], agents_target[i]))
re_generate = True
break
else:
agents_direction[i] = valid_starting_directions[
np.random.choice(len(valid_starting_directions), 1)[0]]
agents_speed = speed_initialization_helper(num_agents, speed_ratio_map, seed=_runtime_seed)
return Schedule(agent_positions=agents_position, agent_directions=agents_direction,
agent_targets=agents_target, agent_speeds=agents_speed, agent_malfunction_rates=None)
return generator
def schedule_from_file(filename, load_from_package=None) -> ScheduleGenerator:
"""
Utility to load pickle file
Parameters
----------
input_file : Pickle file generated by env.save() or editor
Returns
-------
Tuple[List[Tuple[int,int]], List[Tuple[int,int]], List[Tuple[int,int]], List[float]]
initial positions, directions, targets speeds
"""
def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None,
num_resets: int = 0) -> Schedule:
if load_from_package is not None:
from importlib_resources import read_binary
load_data = read_binary(load_from_package, filename)
else:
with open(filename, "rb") as file_in:
load_data = file_in.read()
data = msgpack.unpackb(load_data, use_list=False, encoding='utf-8')
agents = [EnvAgent(*d[0:12]) for d in data["agents"]]
# setup with loaded data
agents_position = [a.initial_position for a in agents]
agents_direction = [a.direction for a in agents]
agents_target = [a.target for a in agents]
agents_speed = [a.speed_data['speed'] for a in agents]
agents_malfunction = [a.malfunction_data['malfunction_rate'] for a in agents]
return Schedule(agent_positions=agents_position, agent_directions=agents_direction,
agent_targets=agents_target, agent_speeds=agents_speed, agent_malfunction_rates=None)
return generator
raise ImportError(" Schedule Generators is now renamed to line_generators, any reference to schedule should be replaced with line")
\ No newline at end of file
from typing import List, NamedTuple
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.core.grid.grid_utils import IntVector2DArray
Schedule = NamedTuple('Schedule', [('agent_positions', IntVector2DArray),
('agent_directions', List[Grid4TransitionsEnum]),
('agent_targets', IntVector2DArray),
('agent_speeds', List[float]),
('agent_malfunction_rates', List[int])])
"""Rail generators (infrastructure manager, "Infrastrukturbetreiber")."""
import sys
import warnings
from typing import Callable, Tuple, Optional, Dict, List
import numpy as np
from numpy.random.mtrand import RandomState
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.core.grid.grid4_utils import get_direction, mirror, direction_to_point
from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
from flatland.core.grid.grid_utils import distance_on_rail, IntVector2DArray, IntVector2D, \
Vec2dOperations
from flatland.core.grid.rail_env_grid import RailEnvTransitions
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.grid4_generators_utils import connect_rail_in_grid_map, connect_straight_line_in_grid_map, \
fix_inner_nodes, align_cell_to_city
from flatland.envs import persistence
from flatland.envs.rail_generators import RailGeneratorProduct, RailGenerator
from flatland.core.grid.grid_utils import position_to_coordinate
from flatland.envs.agent_utils import TrainState
from flatland.envs.rail_env_action import RailEnvActions
from flatland.envs.step_utils.transition_utils import check_valid_action
def process_illegal_action(action: RailEnvActions):
if not RailEnvActions.is_action_valid(action):
return RailEnvActions.DO_NOTHING
else:
return RailEnvActions(action)
def process_do_nothing(state: TrainState, saved_action: RailEnvActions):
if state == TrainState.MOVING:
action = RailEnvActions.MOVE_FORWARD
elif saved_action:
action = saved_action
else:
action = RailEnvActions.DO_NOTHING
return action
def process_left_right(action, rail, position, direction):
if not check_valid_action(action, rail, position, direction):
action = RailEnvActions.MOVE_FORWARD
return action
def preprocess_action_when_waiting(action, state):
"""
Set action to DO_NOTHING if in waiting state
"""
if state == TrainState.WAITING:
action = RailEnvActions.DO_NOTHING
return action
def preprocess_raw_action(action, state, saved_action):
"""
Preprocesses actions to handle different situations of usage of action based on context
- DO_NOTHING is converted to FORWARD if train is moving
"""
action = process_illegal_action(action)
if action == RailEnvActions.DO_NOTHING:
action = process_do_nothing(state, saved_action)
return action
def preprocess_moving_action(action, rail, position, direction):
"""
LEFT/RIGHT is converted to FORWARD if left/right is not available and train is moving
FORWARD is converted to STOP_MOVING if leading to dead end?
"""
if action in [RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT]:
action = process_left_right(action, rail, position, direction)
return action
\ No newline at end of file
from flatland.envs.rail_env_action import RailEnvActions
from flatland.envs.step_utils.states import TrainState
class ActionSaver:
def __init__(self):
self.saved_action = None
@property
def is_action_saved(self):
return self.saved_action is not None
def __repr__(self):
return f"is_action_saved: {self.is_action_saved}, saved_action: {str(self.saved_action)}"
def save_action_if_allowed(self, action, state):
"""
Save the action if all conditions are met
1. It is a movement based action -> Forward, Left, Right
2. Action is not already saved
3. Agent is not already done
"""
if action.is_moving_action() and not self.is_action_saved and not state == TrainState.DONE:
self.saved_action = action
def clear_saved_action(self):
self.saved_action = None
def to_dict(self):
return {"saved_action": self.saved_action}
def from_dict(self, load_dict):
self.saved_action = load_dict['saved_action']
def __eq__(self, other):
return self.saved_action == other.saved_action