diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index 07fdd8d2b26485de888af81240f9c9f8c6d0533d..ee586e2071f7df26e7e17f553fbe9dc6867597e4 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -2,7 +2,7 @@ import time import numpy as np -from flatland.envs.malfunction_generators import malfunction_from_params +from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv @@ -16,12 +16,10 @@ np.random.seed(1) # Training on simple small tasks is the best way to get familiar with the environment # Use a the malfunction generator to break agents from time to time -stochastic_data = {'prop_malfunction': 0.3, # Percentage of defective agents - 'malfunction_rate': 30, # Rate of malfunction occurence - 'min_duration': 3, # Minimal duration of malfunction - 'max_duration': 20 # Max duration of malfunction - } - +stochastic_data = MalfunctionParameters(malfunction_rate=30, # Rate of malfunction occurence + min_duration=3, # Minimal duration of malfunction + max_duration=20 # Max duration of malfunction + ) # Custom observation builder TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()) diff --git a/examples/introduction_flatland_2_1.py b/examples/introduction_flatland_2_1.py index 6c3313d801e31fb9fa32605c59c331198d1865d8..fa48acb23c9bf7e5f7c57df54abc733bc75ed2e7 100644 --- a/examples/introduction_flatland_2_1.py +++ b/examples/introduction_flatland_2_1.py @@ -3,7 +3,7 @@ import numpy as np # In Flatland you can use custom observation builders and predicitors # Observation builders generate the observation needed by the controller # Preditctors can be used to do short time prediction which can help in avoiding conflicts in the network -from 
flatland.envs.malfunction_generators import malfunction_from_params +from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters from flatland.envs.observations import GlobalObsForRailEnv # First of all we import the Flatland rail environment from flatland.envs.rail_env import RailEnv @@ -62,11 +62,10 @@ schedule_generator = sparse_schedule_generator(speed_ration_map) # We can furthermore pass stochastic data to the RailEnv constructor which will allow for stochastic malfunctions # during an episode. -stochastic_data = {'malfunction_rate': 12000, # Rate of malfunction occurence of single agent - 'min_duration': 15, # Minimal duration of malfunction - 'max_duration': 50 # Max duration of malfunction - } - +stochastic_data = MalfunctionParameters(malfunction_rate=10000, # Rate of malfunction occurence + min_duration=15, # Minimal duration of malfunction + max_duration=50 # Max duration of malfunction + ) # Custom observation builder without predictor observation_builder = GlobalObsForRailEnv() @@ -256,7 +255,7 @@ for step in range(500): next_obs, all_rewards, done, _ = env.step(action_dict) env_renderer.render_env(show=True, show_observations=False, show_predictions=False) - # env_renderer.gl.save_image('./misc/Fames2/flatland_frame_{:04d}.png'.format(step)) + env_renderer.gl.save_image('./misc/Fames2/flatland_frame_{:04d}.png'.format(step)) frame_step += 1 # Update replay buffer and train agent for a in range(env.get_num_agents()): diff --git a/flatland/envs/malfunction_generators.py b/flatland/envs/malfunction_generators.py index c877f9824526889a7d355d198f0422956324ad72..f6d0c78f9f37b9d179a2c42776d1b7411600d263 100644 --- a/flatland/envs/malfunction_generators.py +++ b/flatland/envs/malfunction_generators.py @@ -6,10 +6,12 @@ import msgpack import numpy as np from numpy.random.mtrand import RandomState -from flatland.envs.agent_utils import EnvAgent +from flatland.envs.agent_utils import EnvAgent, RailAgentStatus Malfunction = 
NamedTuple('Malfunction', [('num_broken_steps', int)]) -MalfunctionGenerator = Callable[[EnvAgent], Optional[Malfunction]] +MalfunctionParameters = NamedTuple('MalfunctionParameters', + [('malfunction_rate', float), ('min_duration', int), ('max_duration', int)]) +MalfunctionGenerator = Callable[[EnvAgent, RandomState, bool], Optional[Malfunction]] MalfunctionProcessData = NamedTuple('MalfunctionProcessData', [('malfunction_rate', float), ('min_duration', int), ('max_duration', int)]) @@ -36,7 +38,7 @@ def malfunction_from_file(filename: str) -> Tuple[MalfunctionGenerator, Malfunct Returns ------- - Tuple[float, int, int] with mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken + generator, Tuple[float, int, int] with mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken """ with open(filename, "rb") as file_in: load_data = file_in.read() @@ -57,7 +59,7 @@ def malfunction_from_file(filename: str) -> Tuple[MalfunctionGenerator, Malfunct min_number_of_steps_broken = 0 max_number_of_steps_broken = 0 - def generator(agent: EnvAgent, np_random: RandomState) -> Optional[Malfunction]: + def generator(agent: EnvAgent = None, np_random: RandomState = None, reset=False) -> Optional[Malfunction]: """ Generate malfunctions for agents Parameters @@ -69,6 +71,11 @@ def malfunction_from_file(filename: str) -> Tuple[MalfunctionGenerator, Malfunct ------- int: Number of time steps an agent is broken """ + + # Dummy reset function as we don't implement specific seeding here + if reset: + return Malfunction(0) + if agent.malfunction_data['malfunction'] < 1: if np_random.rand() < _malfunction_prob(mean_malfunction_rate): num_broken_steps = np_random.randint(min_number_of_steps_broken, @@ -80,26 +87,27 @@ def malfunction_from_file(filename: str) -> Tuple[MalfunctionGenerator, Malfunct max_number_of_steps_broken) -def malfunction_from_params(parameters: dict) -> Tuple[MalfunctionGenerator, MalfunctionProcessData]: +def 
malfunction_from_params(parameters: MalfunctionParameters) -> Tuple[MalfunctionGenerator, MalfunctionProcessData]: """ Utility to load malfunction from parameters Parameters ---------- - parameters containing - malfunction_rate : float how many time steps it takes for a sinlge agent befor it breaks - min_duration : int minimal duration of a failure - max_number_of_steps_broken : int maximal duration of a failure + + parameters : contains all the parameters of the malfunction + malfunction_rate : float how many time steps it takes for a single agent before it breaks + min_duration : int minimal duration of a failure + max_duration : int maximal duration of a failure Returns ------- - Tuple[float, int, int] with mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken + generator, Tuple[float, int, int] with mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken """ - mean_malfunction_rate = parameters['malfunction_rate'] - min_number_of_steps_broken = parameters['min_duration'] - max_number_of_steps_broken = parameters['max_duration'] + mean_malfunction_rate = parameters.malfunction_rate + min_number_of_steps_broken = parameters.min_duration + max_number_of_steps_broken = parameters.max_duration - def generator(agent: EnvAgent, np_random: RandomState) -> Optional[Malfunction]: + def generator(agent: EnvAgent = None, np_random: RandomState = None, reset=False) -> Optional[Malfunction]: """ Generate malfunctions for agents Parameters @@ -111,6 +119,11 @@ def malfunction_from_params(parameters: dict) -> Tuple[MalfunctionGenerator, Mal ------- int: Number of time steps an agent is broken """ + + # Dummy reset function as we don't implement specific seeding here + if reset: + return Malfunction(0) + if agent.malfunction_data['malfunction'] < 1: if np_random.rand() < _malfunction_prob(mean_malfunction_rate): num_broken_steps = np_random.randint(min_number_of_steps_broken, @@ -124,15 +137,15 @@ def 
malfunction_from_params(parameters: dict) -> Tuple[MalfunctionGenerator, Mal def no_malfunction_generator() -> Tuple[MalfunctionGenerator, MalfunctionProcessData]: """ - Utility to load malfunction from parameters + Malfunction generator which generates no malfunctions Parameters ---------- - input_file : Pickle file generated by env.save() or editor + Nothing Returns ------- - Tuple[float, int, int] with mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken + generator, Tuple[float, int, int] with mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken """ # Mean malfunction in number of time steps mean_malfunction_rate = 0. @@ -141,8 +154,68 @@ def no_malfunction_generator() -> Tuple[MalfunctionGenerator, MalfunctionProcess min_number_of_steps_broken = 0 max_number_of_steps_broken = 0 - def generator(agent: EnvAgent, np_random: RandomState) -> Optional[Malfunction]: + def generator(agent: EnvAgent = None, np_random: RandomState = None, reset=False) -> Optional[Malfunction]: return Malfunction(0) return generator, MalfunctionProcessData(mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken) + + +def single_malfunction_generator(earlierst_malfunction: int, malfunction_duration: int) -> Tuple[ + MalfunctionGenerator, MalfunctionProcessData]: + """ + Malfunction generator which guarantees exactly one malfunction during an episode of an ACTIVE agent. + + Parameters + ---------- + malfunction_duration: The duration of the single malfunction + + Returns + ------- + generator, Tuple[float, int, int] with mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken + """ + # Mean malfunction in number of time steps + mean_malfunction_rate = 0. 
+ + # Uniform distribution parameters for malfunction duration + min_number_of_steps_broken = 0 + max_number_of_steps_broken = 0 + + # Keep track of the total number of malfunctions in the env + global_nr_malfunctions = 0 + + # Malfunction calls per agent + malfunction_calls = dict() + + def generator(agent: EnvAgent = None, np_random: RandomState = None, reset=False) -> Optional[Malfunction]: + # We use the counter from the enclosing scope to ensure only a single malfunction in the env + nonlocal global_nr_malfunctions + nonlocal malfunction_calls + + # Reset malfunction generator + if reset: + nonlocal global_nr_malfunctions + nonlocal malfunction_calls + global_nr_malfunctions = 0 + malfunction_calls = dict() + return Malfunction(0) + + # No more malfunctions if we already had one, ignore all updates + if global_nr_malfunctions > 0: + return Malfunction(0) + + # Update number of calls per agent + if agent.handle in malfunction_calls: + malfunction_calls[agent.handle] += 1 + else: + malfunction_calls[agent.handle] = 1 + + # Break an agent that is active at the time of the malfunction + if agent.status == RailAgentStatus.ACTIVE and malfunction_calls[agent.handle] >= earlierst_malfunction: + global_nr_malfunctions += 1 + return Malfunction(malfunction_duration) + else: + return Malfunction(0) + + return generator, MalfunctionProcessData(mean_malfunction_rate, min_number_of_steps_broken, + max_number_of_steps_broken) diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index cab96ab0a6ac35db74ff74499c334867980d0a0a..5ec4db3c27749327e539b94543758aa31c69c707 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -370,6 +370,9 @@ class RailEnv(Environment): self.obs_builder.reset() self.distance_map.reset(self.agents, self.rail) + # Reset the malfunction generator + self.malfunction_generator(reset=True) + info_dict: Dict = { 'action_required': {i: self.action_required(agent) for i, agent in enumerate(self.agents)}, 'malfunction': { diff --git 
a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index df398a21583c7c24bc4cb5e1a08e7a517ea3483c..2d3fbd42d353e57b8d510c5bc3e0ef8118bdecaf 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -8,7 +8,7 @@ from flatland.core.env_observation_builder import ObservationBuilder from flatland.core.grid.grid4 import Grid4TransitionsEnum from flatland.core.grid.grid4_utils import get_new_position from flatland.envs.agent_utils import RailAgentStatus -from flatland.envs.malfunction_generators import malfunction_from_params +from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters from flatland.envs.rail_env import RailEnv, RailEnvActions from flatland.envs.rail_generators import rail_from_grid_transition_map from flatland.envs.schedule_generators import random_schedule_generator @@ -67,9 +67,10 @@ class SingleAgentNavigationObs(ObservationBuilder): def test_malfunction_process(): # Set fixed malfunction duration for this test - stochastic_data = {'malfunction_rate': 1, - 'min_duration': 3, - 'max_duration': 3} + stochastic_data = MalfunctionParameters(malfunction_rate=1, # Rate of malfunction occurence + min_duration=3, # Minimal duration of malfunction + max_duration=3 # Max duration of malfunction + ) rail, rail_map = make_simple_rail2() @@ -120,9 +121,10 @@ def test_malfunction_process(): def test_malfunction_process_statistically(): """Tests hat malfunctions are produced by stochastic_data!""" # Set fixed malfunction duration for this test - stochastic_data = {'malfunction_rate': 5, - 'min_duration': 5, - 'max_duration': 5} + stochastic_data = MalfunctionParameters(malfunction_rate=5, # Rate of malfunction occurence + min_duration=5, # Minimal duration of malfunction + max_duration=5 # Max duration of malfunction + ) rail, rail_map = make_simple_rail2() @@ -166,9 +168,10 @@ def test_malfunction_process_statistically(): def test_malfunction_before_entry(): """Tests that 
malfunctions are working properly for agents before entering the environment!""" # Set fixed malfunction duration for this test - stochastic_data = {'malfunction_rate': 2, - 'min_duration': 10, - 'max_duration': 10} + stochastic_data = MalfunctionParameters(malfunction_rate=2, # Rate of malfunction occurence + min_duration=10, # Minimal duration of malfunction + max_duration=10 # Max duration of malfunction + ) rail, rail_map = make_simple_rail2() @@ -212,9 +215,10 @@ def test_malfunction_values_and_behavior(): rail, rail_map = make_simple_rail2() action_dict: Dict[int, RailEnvActions] = {} - stochastic_data = {'malfunction_rate': 0.001, - 'min_duration': 10, - 'max_duration': 10} + stochastic_data = MalfunctionParameters(malfunction_rate=0.001, # Rate of malfunction occurence + min_duration=10, # Minimal duration of malfunction + max_duration=10 # Max duration of malfunction + ) env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), @@ -237,10 +241,10 @@ def test_malfunction_values_and_behavior(): def test_initial_malfunction(): - stochastic_data = {'malfunction_rate': 1000, # Rate of malfunction occurence - 'min_duration': 2, # Minimal duration of malfunction - 'max_duration': 5 # Max duration of malfunction - } + stochastic_data = MalfunctionParameters(malfunction_rate=1000, # Rate of malfunction occurence + min_duration=2, # Minimal duration of malfunction + max_duration=5 # Max duration of malfunction + ) rail, rail_map = make_simple_rail2() @@ -308,12 +312,6 @@ def test_initial_malfunction(): def test_initial_malfunction_stop_moving(): - stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents - 'malfunction_rate': 70, # Rate of malfunction occurence - 'min_duration': 2, # Minimal duration of malfunction - 'max_duration': 5 # Max duration of malfunction - } - rail, rail_map = make_simple_rail2() env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), @@ -394,11 +392,10 @@ def 
test_initial_malfunction_do_nothing(): random.seed(0) np.random.seed(0) - stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents - 'malfunction_rate': 70, # Rate of malfunction occurence - 'min_duration': 2, # Minimal duration of malfunction - 'max_duration': 5 # Max duration of malfunction - } + stochastic_data = MalfunctionParameters(malfunction_rate=70, # Rate of malfunction occurence + min_duration=2, # Minimal duration of malfunction + max_duration=5 # Max duration of malfunction + ) rail, rail_map = make_simple_rail2() @@ -479,10 +476,6 @@ def test_initial_malfunction_do_nothing(): def tests_random_interference_from_outside(): """Tests that malfunctions are produced by stochastic_data!""" # Set fixed malfunction duration for this test - stochastic_data = {'malfunction_rate': 1, - 'min_duration': 10, - 'max_duration': 10} - rail, rail_map = make_simple_rail2() env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), schedule_generator=random_schedule_generator(seed=2), number_of_agents=1, random_seed=1) @@ -537,9 +530,6 @@ def test_last_malfunction_step(): """ # Set fixed malfunction duration for this test - stochastic_data = {'malfunction_rate': 5, - 'min_duration': 4, - 'max_duration': 4} rail, rail_map = make_simple_rail2() diff --git a/tests/test_malfunction_generators.py b/tests/test_malfunction_generators.py index 075edc139b6786933a32c915998c0fe56cb7a76c..51839babe563943a609492bcad64243d36105b5c 100644 --- a/tests/test_malfunction_generators.py +++ b/tests/test_malfunction_generators.py @@ -1,8 +1,5 @@ -import numpy as np - -from flatland.core.env_observation_builder import ObservationBuilder -from flatland.core.grid.grid4_utils import get_new_position -from flatland.envs.malfunction_generators import malfunction_from_params, malfunction_from_file +from flatland.envs.malfunction_generators import malfunction_from_params, malfunction_from_file, \ + single_malfunction_generator, MalfunctionParameters from 
flatland.envs.rail_env import RailEnv, RailEnvActions from flatland.envs.rail_generators import rail_from_grid_transition_map from flatland.envs.schedule_generators import random_schedule_generator @@ -16,11 +13,10 @@ def test_malfanction_from_params(): ------- """ - stochastic_data = {'malfunction_rate': 1000, # Rate of malfunction occurence - 'min_duration': 2, # Minimal duration of malfunction - 'max_duration': 5 # Max duration of malfunction - } - + stochastic_data = MalfunctionParameters(malfunction_rate=1000, # Rate of malfunction occurence + min_duration=2, # Minimal duration of malfunction + max_duration=5 # Max duration of malfunction + ) rail, rail_map = make_simple_rail2() env = RailEnv(width=25, @@ -43,10 +39,10 @@ def test_malfanction_to_and_from_file(): ------- """ - stochastic_data = {'malfunction_rate': 1000, # Rate of malfunction occurence - 'min_duration': 2, # Minimal duration of malfunction - 'max_duration': 5 # Max duration of malfunction - } + stochastic_data = MalfunctionParameters(malfunction_rate=1000, # Rate of malfunction occurence + min_duration=2, # Minimal duration of malfunction + max_duration=5 # Max duration of malfunction + ) rail, rail_map = make_simple_rail2() @@ -62,17 +58,50 @@ def test_malfanction_to_and_from_file(): malfunction_generator, malfunction_process_data = malfunction_from_file("./malfunction_saving_loading_tests.pkl") env2 = RailEnv(width=25, - height=30, - rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), - number_of_agents=10, - malfunction_generator_and_process_data=malfunction_from_params(stochastic_data) - ) + height=30, + rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), + number_of_agents=10, + malfunction_generator_and_process_data=malfunction_from_params(stochastic_data) + ) env2.reset() - assert env2.malfunction_process_data == env.malfunction_process_data + assert env2.malfunction_process_data == 
env.malfunction_process_data assert env2.malfunction_process_data.malfunction_rate == 1000 assert env2.malfunction_process_data.min_duration == 2 assert env2.malfunction_process_data.max_duration == 5 + +def test_single_malfunction_generator(): + """ + Test that the single malfunction generator yields exactly one malfunction per episode + Returns + ------- + + """ + + rail, rail_map = make_simple_rail2() + env = RailEnv(width=25, + height=30, + rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), + number_of_agents=10, + malfunction_generator_and_process_data=single_malfunction_generator(earlierst_malfunction=10, + malfunction_duration=5) + ) + for test in range(10): + env.reset() + action_dict = dict() + tot_malfunctions = 0 + print(test) + for i in range(10): + for agent in env.agents: + # Go forward all the time + action_dict[agent.handle] = RailEnvActions(2) + + env.step(action_dict) + for agent in env.agents: + # Sum up the malfunctions recorded for every agent + tot_malfunctions += agent.malfunction_data['nr_malfunctions'] + assert tot_malfunctions == 1