diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py index 01ce2908c438b48654e23720fdb29c217712dce9..f5b84a7d11910e1bf11c93a4bef5a955ad734ad3 100644 --- a/flatland/envs/agent_utils.py +++ b/flatland/envs/agent_utils.py @@ -39,7 +39,7 @@ class EnvAgentStatic(object): # number of time the agent had to stop, since the last time it broke down malfunction_data = attrib( default=Factory( - lambda: dict({'malfunction': 0, 'nr_malfunctions': 0, + lambda: dict({'malfunction': 0, 'malfunction_rate': 0, 'next_malfunction': 0, 'nr_malfunctions': 0, 'moving_before_malfunction': False, 'fixed': True}))) status = attrib(default=RailAgentStatus.READY_TO_DEPART, type=RailAgentStatus) @@ -62,8 +62,10 @@ class EnvAgentStatic(object): malfunction_datas = [] for i in range(len(schedule.agent_positions)): malfunction_datas.append({'malfunction': 0, + 'malfunction_rate': schedule.agent_malfunction_rates[ + i] if schedule.agent_malfunction_rates is not None else 0., + 'next_malfunction': 0, 'nr_malfunctions': 0, - 'moving_before_malfunction': False, 'fixed': True}) return list(starmap(EnvAgentStatic, zip(schedule.agent_positions, diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py index 58a7be34197b4ae3a30a3c85c3e1e131a3d33845..d3c5d78ae2296a3f89ac862887669aeedefe2ebb 100644 --- a/flatland/envs/schedule_generators.py +++ b/flatland/envs/schedule_generators.py @@ -79,7 +79,7 @@ def complex_schedule_generator(speed_ratio_map: Mapping[float, float] = None, se speeds = [1.0] * len(agents_position) return Schedule(agent_positions=agents_position, agent_directions=agents_direction, - agent_targets=agents_target, agent_speeds=speeds) + agent_targets=agents_target, agent_speeds=speeds,agent_malfunction_rates=None) return generator @@ -165,7 +165,7 @@ def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None, see speeds = [1.0] * len(agents_position) return Schedule(agent_positions=agents_position, agent_directions=agents_direction, - agent_targets=agents_target, agent_speeds=speeds) + agent_targets=agents_target, agent_speeds=speeds,agent_malfunction_rates=None) return generator @@ -263,7 +263,7 @@ def random_schedule_generator(speed_ratio_map: Optional[Mapping[float, float]] = agents_speed = speed_initialization_helper(num_agents, speed_ratio_map, seed=_runtime_seed) return Schedule(agent_positions=agents_position, agent_directions=agents_direction, - agent_targets=agents_target, agent_speeds=agents_speed) + agent_targets=agents_target, agent_speeds=speeds,agent_malfunction_rates=None) return generator @@ -307,6 +307,6 @@ def schedule_from_file(filename, load_from_package=None) -> ScheduleGenerator: else: agents_speed = None return Schedule(agent_positions=agents_position, agent_directions=agents_direction, - agent_targets=agents_target, agent_speeds=agents_speed) + agent_targets=agents_target, agent_speeds=speeds,agent_malfunction_rates=None) return generator diff --git a/flatland/envs/schedule_utils.py b/flatland/envs/schedule_utils.py index c61d2f6b3ea1c28d5232fcfd54e664ee33b6238a..e89f170dbb87388bcecbc6b2e176ba277162a4db 100644 --- a/flatland/envs/schedule_utils.py +++ b/flatland/envs/schedule_utils.py @@ -6,4 +6,5 @@ from flatland.core.grid.grid_utils import IntVector2DArray Schedule = NamedTuple('Schedule', [('agent_positions', IntVector2DArray), ('agent_directions', List[Grid4TransitionsEnum]), ('agent_targets', IntVector2DArray), - ('agent_speeds', List[float])]) + ('agent_speeds', List[float]), + ('agent_malfunction_rates', List[int])])