Commit c939968d authored by Erik Nygren's avatar Erik Nygren 🚅 Committed by Christian Eichenberger
Browse files

298 single malfunction generator

parent be442f15
......@@ -2,7 +2,7 @@ import time
import numpy as np
from flatland.envs.malfunction_generators import malfunction_from_params
from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters
from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
......@@ -16,12 +16,10 @@ np.random.seed(1)
# Training on simple small tasks is the best way to get familiar with the environment
# Use the malfunction generator to break agents from time to time
stochastic_data = {'prop_malfunction': 0.3, # Percentage of defective agents
'malfunction_rate': 30, # Rate of malfunction occurence
'min_duration': 3, # Minimal duration of malfunction
'max_duration': 20 # Max duration of malfunction
}
stochastic_data = MalfunctionParameters(malfunction_rate=30, # Rate of malfunction occurence
min_duration=3, # Minimal duration of malfunction
max_duration=20 # Max duration of malfunction
)
# Custom observation builder
TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
......
......@@ -3,7 +3,7 @@ import numpy as np
# In Flatland you can use custom observation builders and predictors
# Observation builders generate the observation needed by the controller
# Predictors can be used to do short time prediction which can help in avoiding conflicts in the network
from flatland.envs.malfunction_generators import malfunction_from_params
from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters
from flatland.envs.observations import GlobalObsForRailEnv
# First of all we import the Flatland rail environment
from flatland.envs.rail_env import RailEnv
......@@ -62,11 +62,10 @@ schedule_generator = sparse_schedule_generator(speed_ration_map)
# We can furthermore pass stochastic data to the RailEnv constructor which will allow for stochastic malfunctions
# during an episode.
stochastic_data = {'malfunction_rate': 12000, # Rate of malfunction occurence of single agent
'min_duration': 15, # Minimal duration of malfunction
'max_duration': 50 # Max duration of malfunction
}
stochastic_data = MalfunctionParameters(malfunction_rate=10000, # Rate of malfunction occurence
min_duration=15, # Minimal duration of malfunction
max_duration=50 # Max duration of malfunction
)
# Custom observation builder without predictor
observation_builder = GlobalObsForRailEnv()
......@@ -256,7 +255,7 @@ for step in range(500):
next_obs, all_rewards, done, _ = env.step(action_dict)
env_renderer.render_env(show=True, show_observations=False, show_predictions=False)
# env_renderer.gl.save_image('./misc/Fames2/flatland_frame_{:04d}.png'.format(step))
env_renderer.gl.save_image('./misc/Fames2/flatland_frame_{:04d}.png'.format(step))
frame_step += 1
# Update replay buffer and train agent
for a in range(env.get_num_agents()):
......
......@@ -6,10 +6,12 @@ import msgpack
import numpy as np
from numpy.random.mtrand import RandomState
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.agent_utils import EnvAgent, RailAgentStatus
Malfunction = NamedTuple('Malfunction', [('num_broken_steps', int)])
MalfunctionGenerator = Callable[[EnvAgent], Optional[Malfunction]]
MalfunctionParameters = NamedTuple('MalfunctionParameters',
[('malfunction_rate', float), ('min_duration', int), ('max_duration', int)])
MalfunctionGenerator = Callable[[EnvAgent, RandomState, bool], Optional[Malfunction]]
MalfunctionProcessData = NamedTuple('MalfunctionProcessData',
[('malfunction_rate', float), ('min_duration', int), ('max_duration', int)])
......@@ -36,7 +38,7 @@ def malfunction_from_file(filename: str) -> Tuple[MalfunctionGenerator, Malfunct
Returns
-------
Tuple[float, int, int] with mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken
generator, Tuple[float, int, int] with mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken
"""
with open(filename, "rb") as file_in:
load_data = file_in.read()
......@@ -57,7 +59,7 @@ def malfunction_from_file(filename: str) -> Tuple[MalfunctionGenerator, Malfunct
min_number_of_steps_broken = 0
max_number_of_steps_broken = 0
def generator(agent: EnvAgent, np_random: RandomState) -> Optional[Malfunction]:
def generator(agent: EnvAgent = None, np_random: RandomState = None, reset=False) -> Optional[Malfunction]:
"""
Generate malfunctions for agents
Parameters
......@@ -69,6 +71,11 @@ def malfunction_from_file(filename: str) -> Tuple[MalfunctionGenerator, Malfunct
-------
int: Number of time steps an agent is broken
"""
# Dummy reset function as we don't implement specific seeding here
if reset:
return Malfunction(0)
if agent.malfunction_data['malfunction'] < 1:
if np_random.rand() < _malfunction_prob(mean_malfunction_rate):
num_broken_steps = np_random.randint(min_number_of_steps_broken,
......@@ -80,26 +87,27 @@ def malfunction_from_file(filename: str) -> Tuple[MalfunctionGenerator, Malfunct
max_number_of_steps_broken)
def malfunction_from_params(parameters: dict) -> Tuple[MalfunctionGenerator, MalfunctionProcessData]:
def malfunction_from_params(parameters: MalfunctionParameters) -> Tuple[MalfunctionGenerator, MalfunctionProcessData]:
"""
Utility to load malfunction from parameters
Parameters
----------
parameters containing
malfunction_rate : float how many time steps it takes for a sinlge agent befor it breaks
min_duration : int minimal duration of a failure
max_number_of_steps_broken : int maximal duration of a failure
parameters : contains all the parameters of the malfunction
malfunction_rate : float how many time steps it takes for a sinlge agent befor it breaks
min_duration : int minimal duration of a failure
max_number_of_steps_broken : int maximal duration of a failure
Returns
-------
Tuple[float, int, int] with mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken
generator, Tuple[float, int, int] with mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken
"""
mean_malfunction_rate = parameters['malfunction_rate']
min_number_of_steps_broken = parameters['min_duration']
max_number_of_steps_broken = parameters['max_duration']
mean_malfunction_rate = parameters.malfunction_rate
min_number_of_steps_broken = parameters.min_duration
max_number_of_steps_broken = parameters.max_duration
def generator(agent: EnvAgent, np_random: RandomState) -> Optional[Malfunction]:
def generator(agent: EnvAgent = None, np_random: RandomState = None, reset=False) -> Optional[Malfunction]:
"""
Generate malfunctions for agents
Parameters
......@@ -111,6 +119,11 @@ def malfunction_from_params(parameters: dict) -> Tuple[MalfunctionGenerator, Mal
-------
int: Number of time steps an agent is broken
"""
# Dummy reset function as we don't implement specific seeding here
if reset:
return Malfunction(0)
if agent.malfunction_data['malfunction'] < 1:
if np_random.rand() < _malfunction_prob(mean_malfunction_rate):
num_broken_steps = np_random.randint(min_number_of_steps_broken,
......@@ -124,15 +137,15 @@ def malfunction_from_params(parameters: dict) -> Tuple[MalfunctionGenerator, Mal
def no_malfunction_generator() -> Tuple[MalfunctionGenerator, MalfunctionProcessData]:
"""
Utility to load malfunction from parameters
Malfunction generator which generates no malfunctions
Parameters
----------
input_file : Pickle file generated by env.save() or editor
Nothing
Returns
-------
Tuple[float, int, int] with mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken
generator, Tuple[float, int, int] with mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken
"""
# Mean malfunction in number of time steps
mean_malfunction_rate = 0.
......@@ -141,8 +154,68 @@ def no_malfunction_generator() -> Tuple[MalfunctionGenerator, MalfunctionProcess
min_number_of_steps_broken = 0
max_number_of_steps_broken = 0
def generator(agent: EnvAgent, np_random: RandomState) -> Optional[Malfunction]:
def generator(agent: EnvAgent = None, np_random: RandomState = None, reset=False) -> Optional[Malfunction]:
return Malfunction(0)
return generator, MalfunctionProcessData(mean_malfunction_rate, min_number_of_steps_broken,
max_number_of_steps_broken)
def single_malfunction_generator(earlierst_malfunction: int, malfunction_duration: int) -> Tuple[
        MalfunctionGenerator, MalfunctionProcessData]:
    """
    Malfunction generator which issues at most one malfunction per episode, to an agent that is ACTIVE.

    The malfunction is handed to the first agent that is ACTIVE at the time of the call and for which the
    generator has already been called at least `earlierst_malfunction` times. Afterwards every call returns
    a zero-length malfunction until the generator is reset.

    Parameters
    ----------
    earlierst_malfunction: Minimum number of generator calls for an agent before that agent may be broken
    malfunction_duration: The duration of the single malfunction

    Returns
    -------
    generator, Tuple[float, int, int] with mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken
    """
    # This generator is not driven by a rate or a duration range, so the process data is reported as zeros
    mean_malfunction_rate = 0.
    min_number_of_steps_broken = 0
    max_number_of_steps_broken = 0

    # Closure state shared by all calls: total number of malfunctions issued since the last reset
    global_nr_malfunctions = 0

    # Closure state: number of generator calls seen per agent handle since the last reset
    malfunction_calls = dict()

    def generator(agent: EnvAgent = None, np_random: RandomState = None, reset=False) -> Optional[Malfunction]:
        """
        Return the malfunction for `agent` at this call, or reset the generator state.

        Parameters
        ----------
        agent: Agent to possibly break; not accessed when `reset` is True
        np_random: Unused, kept to match the signature of the other malfunction generators
        reset: If True, clear the internal counters and return a zero-length malfunction

        Returns
        -------
        Malfunction with `malfunction_duration` broken steps for the single break, otherwise Malfunction(0)
        """
        # Only the counter is rebound; the dict is mutated in place and needs no nonlocal declaration
        nonlocal global_nr_malfunctions

        # Reset malfunction generator
        if reset:
            global_nr_malfunctions = 0
            malfunction_calls.clear()
            return Malfunction(0)

        # No more malfunctions if we already had one, ignore all updates
        if global_nr_malfunctions > 0:
            return Malfunction(0)

        # Update number of calls per agent
        malfunction_calls[agent.handle] = malfunction_calls.get(agent.handle, 0) + 1

        # Break an agent that is active at the time of the malfunction
        if agent.status == RailAgentStatus.ACTIVE and malfunction_calls[agent.handle] >= earlierst_malfunction:
            global_nr_malfunctions += 1
            return Malfunction(malfunction_duration)
        return Malfunction(0)

    return generator, MalfunctionProcessData(mean_malfunction_rate, min_number_of_steps_broken,
                                             max_number_of_steps_broken)
......@@ -370,6 +370,9 @@ class RailEnv(Environment):
self.obs_builder.reset()
self.distance_map.reset(self.agents, self.rail)
# Reset the malfunction generator
self.malfunction_generator(reset=True)
info_dict: Dict = {
'action_required': {i: self.action_required(agent) for i, agent in enumerate(self.agents)},
'malfunction': {
......
......@@ -8,7 +8,7 @@ from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.malfunction_generators import malfunction_from_params
from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.schedule_generators import random_schedule_generator
......@@ -67,9 +67,10 @@ class SingleAgentNavigationObs(ObservationBuilder):
def test_malfunction_process():
# Set fixed malfunction duration for this test
stochastic_data = {'malfunction_rate': 1,
'min_duration': 3,
'max_duration': 3}
stochastic_data = MalfunctionParameters(malfunction_rate=1, # Rate of malfunction occurence
min_duration=3, # Minimal duration of malfunction
max_duration=3 # Max duration of malfunction
)
rail, rail_map = make_simple_rail2()
......@@ -120,9 +121,10 @@ def test_malfunction_process():
def test_malfunction_process_statistically():
"""Tests hat malfunctions are produced by stochastic_data!"""
# Set fixed malfunction duration for this test
stochastic_data = {'malfunction_rate': 5,
'min_duration': 5,
'max_duration': 5}
stochastic_data = MalfunctionParameters(malfunction_rate=5, # Rate of malfunction occurence
min_duration=5, # Minimal duration of malfunction
max_duration=5 # Max duration of malfunction
)
rail, rail_map = make_simple_rail2()
......@@ -166,9 +168,10 @@ def test_malfunction_process_statistically():
def test_malfunction_before_entry():
"""Tests that malfunctions are working properly for agents before entering the environment!"""
# Set fixed malfunction duration for this test
stochastic_data = {'malfunction_rate': 2,
'min_duration': 10,
'max_duration': 10}
stochastic_data = MalfunctionParameters(malfunction_rate=2, # Rate of malfunction occurence
min_duration=10, # Minimal duration of malfunction
max_duration=10 # Max duration of malfunction
)
rail, rail_map = make_simple_rail2()
......@@ -212,9 +215,10 @@ def test_malfunction_values_and_behavior():
rail, rail_map = make_simple_rail2()
action_dict: Dict[int, RailEnvActions] = {}
stochastic_data = {'malfunction_rate': 0.001,
'min_duration': 10,
'max_duration': 10}
stochastic_data = MalfunctionParameters(malfunction_rate=0.001, # Rate of malfunction occurence
min_duration=10, # Minimal duration of malfunction
max_duration=10 # Max duration of malfunction
)
env = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail),
......@@ -237,10 +241,10 @@ def test_malfunction_values_and_behavior():
def test_initial_malfunction():
stochastic_data = {'malfunction_rate': 1000, # Rate of malfunction occurence
'min_duration': 2, # Minimal duration of malfunction
'max_duration': 5 # Max duration of malfunction
}
stochastic_data = MalfunctionParameters(malfunction_rate=1000, # Rate of malfunction occurence
min_duration=2, # Minimal duration of malfunction
max_duration=5 # Max duration of malfunction
)
rail, rail_map = make_simple_rail2()
......@@ -308,12 +312,6 @@ def test_initial_malfunction():
def test_initial_malfunction_stop_moving():
stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents
'malfunction_rate': 70, # Rate of malfunction occurence
'min_duration': 2, # Minimal duration of malfunction
'max_duration': 5 # Max duration of malfunction
}
rail, rail_map = make_simple_rail2()
env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
......@@ -394,11 +392,10 @@ def test_initial_malfunction_do_nothing():
random.seed(0)
np.random.seed(0)
stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents
'malfunction_rate': 70, # Rate of malfunction occurence
'min_duration': 2, # Minimal duration of malfunction
'max_duration': 5 # Max duration of malfunction
}
stochastic_data = MalfunctionParameters(malfunction_rate=70, # Rate of malfunction occurence
min_duration=2, # Minimal duration of malfunction
max_duration=5 # Max duration of malfunction
)
rail, rail_map = make_simple_rail2()
......@@ -479,10 +476,6 @@ def test_initial_malfunction_do_nothing():
def tests_random_interference_from_outside():
"""Tests that malfunctions are produced by stochastic_data!"""
# Set fixed malfunction duration for this test
stochastic_data = {'malfunction_rate': 1,
'min_duration': 10,
'max_duration': 10}
rail, rail_map = make_simple_rail2()
env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(seed=2), number_of_agents=1, random_seed=1)
......@@ -537,9 +530,6 @@ def test_last_malfunction_step():
"""
# Set fixed malfunction duration for this test
stochastic_data = {'malfunction_rate': 5,
'min_duration': 4,
'max_duration': 4}
rail, rail_map = make_simple_rail2()
......
import numpy as np
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.malfunction_generators import malfunction_from_params, malfunction_from_file
from flatland.envs.malfunction_generators import malfunction_from_params, malfunction_from_file, \
single_malfunction_generator, MalfunctionParameters
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.schedule_generators import random_schedule_generator
......@@ -16,11 +13,10 @@ def test_malfanction_from_params():
-------
"""
stochastic_data = {'malfunction_rate': 1000, # Rate of malfunction occurence
'min_duration': 2, # Minimal duration of malfunction
'max_duration': 5 # Max duration of malfunction
}
stochastic_data = MalfunctionParameters(malfunction_rate=1000, # Rate of malfunction occurence
min_duration=2, # Minimal duration of malfunction
max_duration=5 # Max duration of malfunction
)
rail, rail_map = make_simple_rail2()
env = RailEnv(width=25,
......@@ -43,10 +39,10 @@ def test_malfanction_to_and_from_file():
-------
"""
stochastic_data = {'malfunction_rate': 1000, # Rate of malfunction occurence
'min_duration': 2, # Minimal duration of malfunction
'max_duration': 5 # Max duration of malfunction
}
stochastic_data = MalfunctionParameters(malfunction_rate=1000, # Rate of malfunction occurence
min_duration=2, # Minimal duration of malfunction
max_duration=5 # Max duration of malfunction
)
rail, rail_map = make_simple_rail2()
......@@ -62,17 +58,50 @@ def test_malfanction_to_and_from_file():
malfunction_generator, malfunction_process_data = malfunction_from_file("./malfunction_saving_loading_tests.pkl")
env2 = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(),
number_of_agents=10,
malfunction_generator_and_process_data=malfunction_from_params(stochastic_data)
)
height=30,
rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(),
number_of_agents=10,
malfunction_generator_and_process_data=malfunction_from_params(stochastic_data)
)
env2.reset()
assert env2.malfunction_process_data == env.malfunction_process_data
assert env2.malfunction_process_data == env.malfunction_process_data
assert env2.malfunction_process_data.malfunction_rate == 1000
assert env2.malfunction_process_data.min_duration == 2
assert env2.malfunction_process_data.max_duration == 5
def test_single_malfunction_generator():
    """
    Test single malfunction generator

    Runs several short episodes and checks after each one that the malfunction counters of all agents
    sum up to exactly one.

    Returns
    -------

    """
    rail, _ = make_simple_rail2()
    env = RailEnv(width=25,
                  height=30,
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(),
                  number_of_agents=10,
                  malfunction_generator_and_process_data=single_malfunction_generator(earlierst_malfunction=10,
                                                                                      malfunction_duration=5)
                  )
    for _episode in range(10):
        # Each reset starts a new episode
        env.reset()
        for _step in range(10):
            # Go forward all the time
            action_dict = {agent.handle: RailEnvActions(2) for agent in env.agents}
            env.step(action_dict)

        # Count the malfunctions recorded over all agents during this episode
        tot_malfunctions = sum(agent.malfunction_data['nr_malfunctions'] for agent in env.agents)
        assert tot_malfunctions == 1
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment