Skip to content
Snippets Groups Projects
Commit b4c64e45 authored by hagrid67's avatar hagrid67
Browse files

first cut passing multiprocessing test

parent d583e91e
No related branches found
No related tags found
No related merge requests found
...@@ -15,6 +15,20 @@ from flatland.envs.agent_utils import RailAgentStatus, EnvAgent ...@@ -15,6 +15,20 @@ from flatland.envs.agent_utils import RailAgentStatus, EnvAgent
from flatland.utils.ordered_set import OrderedSet from flatland.utils.ordered_set import OrderedSet
Node = collections.namedtuple('Node', 'dist_own_target_encountered '
'dist_other_target_encountered '
'dist_other_agent_encountered '
'dist_potential_conflict '
'dist_unusable_switch '
'dist_to_next_branch '
'dist_min_to_target '
'num_agents_same_direction '
'num_agents_opposite_direction '
'num_agents_malfunctioning '
'speed_min_fractional '
'num_agents_ready_to_depart '
'childs')
class TreeObsForRailEnv(ObservationBuilder): class TreeObsForRailEnv(ObservationBuilder):
""" """
TreeObsForRailEnv object. TreeObsForRailEnv object.
...@@ -25,19 +39,7 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -25,19 +39,7 @@ class TreeObsForRailEnv(ObservationBuilder):
For details about the features in the tree observation see the get() function. For details about the features in the tree observation see the get() function.
""" """
Node = collections.namedtuple('Node', 'dist_own_target_encountered '
'dist_other_target_encountered '
'dist_other_agent_encountered '
'dist_potential_conflict '
'dist_unusable_switch '
'dist_to_next_branch '
'dist_min_to_target '
'num_agents_same_direction '
'num_agents_opposite_direction '
'num_agents_malfunctioning '
'speed_min_fractional '
'num_agents_ready_to_depart '
'childs')
tree_explored_actions_char = ['L', 'F', 'R', 'B'] tree_explored_actions_char = ['L', 'F', 'R', 'B']
...@@ -205,7 +207,8 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -205,7 +207,8 @@ class TreeObsForRailEnv(ObservationBuilder):
# Here information about the agent itself is stored # Here information about the agent itself is stored
distance_map = self.env.distance_map.get() distance_map = self.env.distance_map.get()
root_node_observation = TreeObsForRailEnv.Node(dist_own_target_encountered=0, dist_other_target_encountered=0, # was referring to TreeObsForRailEnv.Node
root_node_observation = Node(dist_own_target_encountered=0, dist_other_target_encountered=0,
dist_other_agent_encountered=0, dist_potential_conflict=0, dist_other_agent_encountered=0, dist_potential_conflict=0,
dist_unusable_switch=0, dist_to_next_branch=0, dist_unusable_switch=0, dist_to_next_branch=0,
dist_min_to_target=distance_map[ dist_min_to_target=distance_map[
...@@ -431,7 +434,8 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -431,7 +434,8 @@ class TreeObsForRailEnv(ObservationBuilder):
dist_to_next_branch = tot_dist dist_to_next_branch = tot_dist
dist_min_to_target = self.env.distance_map.get()[handle, position[0], position[1], direction] dist_min_to_target = self.env.distance_map.get()[handle, position[0], position[1], direction]
node = TreeObsForRailEnv.Node(dist_own_target_encountered=own_target_encountered, # TreeObsForRailEnv.Node
node = Node(dist_own_target_encountered=own_target_encountered,
dist_other_target_encountered=other_target_encountered, dist_other_target_encountered=other_target_encountered,
dist_other_agent_encountered=other_agent_encountered, dist_other_agent_encountered=other_agent_encountered,
dist_potential_conflict=potential_conflict, dist_potential_conflict=potential_conflict,
......
...@@ -17,17 +17,47 @@ from flatland.envs.grid4_generators_utils import connect_rail_in_grid_map, conne ...@@ -17,17 +17,47 @@ from flatland.envs.grid4_generators_utils import connect_rail_in_grid_map, conne
fix_inner_nodes, align_cell_to_city fix_inner_nodes, align_cell_to_city
from flatland.envs import persistence from flatland.envs import persistence
RailGeneratorProduct = Tuple[GridTransitionMap, Optional[Dict]] RailGeneratorProduct = Tuple[GridTransitionMap, Optional[Dict]]
""" A rail generator returns a RailGenerator Product, which is just
a GridTransitionMap followed by an (optional) dict/
"""
RailGenerator = Callable[[int, int, int, int], RailGeneratorProduct] RailGenerator = Callable[[int, int, int, int], RailGeneratorProduct]
class RailGen(object):
""" Base class for RailGen(erator) replacement
WIP to replace bare generators with classes / objects without unnamed local variables
which prevent pickling.
"""
def __init__(self, *args, **kwargs):
""" constructor to record any state to be reused in each "generation"
"""
pass
def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0,
np_random: RandomState = None) -> RailGeneratorProduct:
pass
def __call__(self, *args, **kwargs) -> RailGeneratorProduct:
return self.generate(*args, **kwargs)
def empty_rail_generator() -> RailGenerator: def empty_rail_generator() -> RailGenerator:
return EmptyRailGen()
class EmptyRailGen(RailGen):
""" """
Returns a generator which returns an empty rail mail with no agents. Returns a generator which returns an empty rail mail with no agents.
Primarily used by the editor Primarily used by the editor
""" """
def generator(width: int, height: int, num_agents: int, num_resets: int = 0, def generate(width: int, height: int, num_agents: int, num_resets: int = 0,
np_random: RandomState = None) -> RailGenerator: np_random: RandomState = None) -> RailGenerator:
rail_trans = RailEnvTransitions() rail_trans = RailEnvTransitions()
grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans) grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans)
...@@ -36,7 +66,6 @@ def empty_rail_generator() -> RailGenerator: ...@@ -36,7 +66,6 @@ def empty_rail_generator() -> RailGenerator:
return grid_map, None return grid_map, None
return generator
def complex_rail_generator(nr_start_goal=1, def complex_rail_generator(nr_start_goal=1,
...@@ -255,8 +284,19 @@ def rail_from_file(filename, load_from_package=None) -> RailGenerator: ...@@ -255,8 +284,19 @@ def rail_from_file(filename, load_from_package=None) -> RailGenerator:
return generator return generator
class RailFromGridGen(RailGen):
def __init__(self, rail_map):
self.rail_map = rail_map
def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0,
np_random: RandomState = None) -> RailGenerator:
return self.rail_map, None
def rail_from_grid_transition_map(rail_map) -> RailGenerator: def rail_from_grid_transition_map(rail_map) -> RailGenerator:
return RailFromGridGen(rail_map)
def rail_from_grid_transition_map_old(rail_map) -> RailGenerator:
""" """
Utility to convert a rail given by a GridTransitionMap map with the correct Utility to convert a rail given by a GridTransitionMap map with the correct
16-bit transitions specifications. 16-bit transitions specifications.
...@@ -561,13 +601,6 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11, seed=1) -> R ...@@ -561,13 +601,6 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11, seed=1) -> R
class RailGen(object):
def __init__(self):
pass
def generate(self):
pass
def sparse_rail_generator(*args, **kwargs): def sparse_rail_generator(*args, **kwargs):
return SparseRailGen(*args, **kwargs) return SparseRailGen(*args, **kwargs)
......
...@@ -40,6 +40,20 @@ def speed_initialization_helper(nb_agents: int, speed_ratio_map: Mapping[float, ...@@ -40,6 +40,20 @@ def speed_initialization_helper(nb_agents: int, speed_ratio_map: Mapping[float,
return list(map(lambda index: speeds[index], np_random.choice(nb_classes, nb_agents, p=speed_ratios))) return list(map(lambda index: speeds[index], np_random.choice(nb_classes, nb_agents, p=speed_ratios)))
class BaseSchedGen(object):
def __init__(self, speed_ratio_map: Mapping[float, float] = None, seed: int = 1):
self.speed_ratio_map = speed_ratio_map
self.seed = seed
def generate(self, rail: GridTransitionMap, num_agents: int, hints: Any=None, num_resets: int = 0,
np_random: RandomState = None) -> Schedule:
pass
def __call__(self, *args, **kwargs):
return self.generate(*args, **kwargs)
def complex_schedule_generator(speed_ratio_map: Mapping[float, float] = None, seed: int = 1) -> ScheduleGenerator: def complex_schedule_generator(speed_ratio_map: Mapping[float, float] = None, seed: int = 1) -> ScheduleGenerator:
""" """
...@@ -88,6 +102,10 @@ def complex_schedule_generator(speed_ratio_map: Mapping[float, float] = None, se ...@@ -88,6 +102,10 @@ def complex_schedule_generator(speed_ratio_map: Mapping[float, float] = None, se
def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None, seed: int = 1) -> ScheduleGenerator: def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None, seed: int = 1) -> ScheduleGenerator:
return SparseSchedGen(speed_ratio_map, seed)
class SparseSchedGen(BaseSchedGen):
""" """
This is the schedule generator which is used for Round 2 of the Flatland challenge. It produces schedules This is the schedule generator which is used for Round 2 of the Flatland challenge. It produces schedules
...@@ -97,7 +115,7 @@ def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None, see ...@@ -97,7 +115,7 @@ def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None, see
:param seed: Initiate random seed generator :param seed: Initiate random seed generator
""" """
def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None, num_resets: int = 0, def generate(self, rail: GridTransitionMap, num_agents: int, hints: Any = None, num_resets: int = 0,
np_random: RandomState = None) -> Schedule: np_random: RandomState = None) -> Schedule:
""" """
...@@ -109,7 +127,7 @@ def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None, see ...@@ -109,7 +127,7 @@ def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None, see
:return: Returns the generator to the rail constructor :return: Returns the generator to the rail constructor
""" """
_runtime_seed = seed + num_resets _runtime_seed = self.seed + num_resets
train_stations = hints['train_stations'] train_stations = hints['train_stations']
city_positions = hints['city_positions'] city_positions = hints['city_positions']
...@@ -162,8 +180,8 @@ def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None, see ...@@ -162,8 +180,8 @@ def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None, see
agents_direction.append(agent_orientation) agents_direction.append(agent_orientation)
# Orient the agent correctly # Orient the agent correctly
if speed_ratio_map: if self.speed_ratio_map:
speeds = speed_initialization_helper(num_agents, speed_ratio_map, seed=_runtime_seed, np_random=np_random) speeds = speed_initialization_helper(num_agents, self.speed_ratio_map, seed=_runtime_seed, np_random=np_random)
else: else:
speeds = [1.0] * len(agents_position) speeds = [1.0] * len(agents_position)
...@@ -178,11 +196,13 @@ def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None, see ...@@ -178,11 +196,13 @@ def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None, see
agent_targets=agents_target, agent_speeds=speeds, agent_malfunction_rates=None, agent_targets=agents_target, agent_speeds=speeds, agent_malfunction_rates=None,
max_episode_steps=max_episode_steps) max_episode_steps=max_episode_steps)
return generator
def random_schedule_generator(speed_ratio_map: Mapping[float, float] = None, seed: int = 1) -> ScheduleGenerator:
return RandomSchedGen(speed_ratio_map, seed)
def random_schedule_generator(speed_ratio_map: Optional[Mapping[float, float]] = None,
seed: int = 1) -> ScheduleGenerator: class RandomSchedGen(BaseSchedGen):
""" """
Given a `rail` GridTransitionMap, return a random placement of agents (initial position, direction and target). Given a `rail` GridTransitionMap, return a random placement of agents (initial position, direction and target).
...@@ -197,9 +217,9 @@ def random_schedule_generator(speed_ratio_map: Optional[Mapping[float, float]] = ...@@ -197,9 +217,9 @@ def random_schedule_generator(speed_ratio_map: Optional[Mapping[float, float]] =
initial positions, directions, targets speeds initial positions, directions, targets speeds
""" """
def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None, num_resets: int = 0, def generate(self, rail: GridTransitionMap, num_agents: int, hints: Any = None, num_resets: int = 0,
np_random: RandomState = None) -> Schedule: np_random: RandomState = None) -> Schedule:
_runtime_seed = seed + num_resets _runtime_seed = self.seed + num_resets
valid_positions = [] valid_positions = []
for r in range(rail.height): for r in range(rail.height):
...@@ -270,7 +290,8 @@ def random_schedule_generator(speed_ratio_map: Optional[Mapping[float, float]] = ...@@ -270,7 +290,8 @@ def random_schedule_generator(speed_ratio_map: Optional[Mapping[float, float]] =
agents_direction[i] = valid_starting_directions[ agents_direction[i] = valid_starting_directions[
np_random.choice(len(valid_starting_directions), 1)[0]] np_random.choice(len(valid_starting_directions), 1)[0]]
agents_speed = speed_initialization_helper(num_agents, speed_ratio_map, seed=_runtime_seed, np_random=np_random) agents_speed = speed_initialization_helper(num_agents, self.speed_ratio_map, seed=_runtime_seed,
np_random=np_random)
# Compute max number of steps with given schedule # Compute max number of steps with given schedule
extra_time_factor = 1.5 # Factor to allow for more then minimal time extra_time_factor = 1.5 # Factor to allow for more then minimal time
...@@ -280,7 +301,6 @@ def random_schedule_generator(speed_ratio_map: Optional[Mapping[float, float]] = ...@@ -280,7 +301,6 @@ def random_schedule_generator(speed_ratio_map: Optional[Mapping[float, float]] =
agent_targets=agents_target, agent_speeds=agents_speed, agent_malfunction_rates=None, agent_targets=agents_target, agent_speeds=agents_speed, agent_malfunction_rates=None,
max_episode_steps=max_episode_steps) max_episode_steps=max_episode_steps)
return generator
def schedule_from_file(filename, load_from_package=None) -> ScheduleGenerator: def schedule_from_file(filename, load_from_package=None) -> ScheduleGenerator:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment