Commit 117aa7d1 authored by Erik Nygren's avatar Erik Nygren 🚅
Browse files

Introducing malfunction_generators

This resolves issue #273
parent 91a67e87
Pipeline #2676 failed with stages
in 10 minutes and 19 seconds
......@@ -149,7 +149,7 @@ env = RailEnv(width=width,
rail_generator=rail_generator,
schedule_generator=schedule_generator,
number_of_agents=nr_trains,
stochastic_data=stochastic_data, # Malfunction data generator
malfunction_generator=stochastic_data, # Malfunction data generator
obs_builder_object=observation_builder,
remove_agents_at_target=True # Removes agents at the end of their journey to make space for others
)
......
......@@ -14,10 +14,8 @@ def run_benchmark():
np.random.seed(1)
# Example generate a random rail
env = RailEnv(width=15, height=15,
rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=20, min_dist=12),
schedule_generator=complex_schedule_generator(),
number_of_agents=5)
env = RailEnv(width=15, height=15, rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=20, min_dist=12),
schedule_generator=complex_schedule_generator(), number_of_agents=5)
env.reset()
n_trials = 20
......
......@@ -28,10 +28,7 @@ class SimpleObs(ObservationBuilder):
def main():
env = RailEnv(width=7,
height=7,
rail_generator=random_rail_generator(),
number_of_agents=3,
env = RailEnv(width=7, height=7, rail_generator=random_rail_generator(), number_of_agents=3,
obs_builder_object=SimpleObs())
env.reset()
......
......@@ -76,13 +76,10 @@ def main(args):
else:
assert False, "unhandled option"
env = RailEnv(width=7,
height=7,
env = RailEnv(width=7, height=7,
rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999,
seed=1),
schedule_generator=complex_schedule_generator(),
number_of_agents=1,
obs_builder_object=SingleAgentNavigationObs())
seed=1), schedule_generator=complex_schedule_generator(),
number_of_agents=1, obs_builder_object=SingleAgentNavigationObs())
obs, info = env.reset()
env_renderer = RenderTool(env, gl="PILSVG")
......
......@@ -122,13 +122,10 @@ def main(args):
custom_obs_builder = ObservePredictions(custom_predictor)
# Initiate Environment
env = RailEnv(width=10,
height=10,
env = RailEnv(width=10, height=10,
rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=1, min_dist=8, max_dist=99999,
seed=1),
schedule_generator=complex_schedule_generator(),
number_of_agents=3,
obs_builder_object=custom_obs_builder)
seed=1), schedule_generator=complex_schedule_generator(),
number_of_agents=3, obs_builder_object=custom_obs_builder)
obs, info = env.reset()
env_renderer = RenderTool(env, gl="PILSVG")
......
......@@ -43,10 +43,7 @@ def custom_schedule_generator() -> ScheduleGenerator:
return generator
env = RailEnv(width=6,
height=4,
rail_generator=custom_rail_generator(),
schedule_generator=custom_schedule_generator(),
env = RailEnv(width=6, height=4, rail_generator=custom_rail_generator(), schedule_generator=custom_schedule_generator(),
number_of_agents=1)
env.reset()
......
......@@ -30,21 +30,16 @@ speed_ration_map = {1.: 0.25, # Fast passenger train
1. / 3.: 0.25, # Slow commuter train
1. / 4.: 0.25} # Slow freight train
env = RailEnv(width=100,
height=100,
rail_generator=sparse_rail_generator(max_num_cities=30,
# Number of cities in map (where train stations are)
seed=14, # Random seed
grid_mode=False,
max_rails_between_cities=2,
max_rails_in_city=8,
),
schedule_generator=sparse_schedule_generator(speed_ration_map),
number_of_agents=100,
stochastic_data=stochastic_data, # Malfunction data generator
obs_builder_object=GlobalObsForRailEnv(),
remove_agents_at_target=True
)
env = RailEnv(width=100, height=100, rail_generator=sparse_rail_generator(max_num_cities=30,
# Number of cities in map (where train stations are)
seed=14, # Random seed
grid_mode=False,
max_rails_between_cities=2,
max_rails_in_city=8,
),
schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=100,
obs_builder_object=GlobalObsForRailEnv(), malfunction_generator=stochastic_data,
remove_agents_at_target=True)
# RailEnv.DEPOT_POSITION = lambda agent, agent_handle : (agent_handle % env.height,0)
......
......@@ -72,15 +72,9 @@ observation_builder = GlobalObsForRailEnv()
# observation_builder = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
# Construct the enviornment with the given observation, generataors, predictors, and stochastic data
env = RailEnv(width=width,
height=height,
rail_generator=rail_generator,
schedule_generator=schedule_generator,
number_of_agents=nr_trains,
stochastic_data=stochastic_data, # Malfunction data generator
obs_builder_object=observation_builder,
remove_agents_at_target=True # Removes agents at the end of their journey to make space for others
)
env = RailEnv(width=width, height=height, rail_generator=rail_generator, schedule_generator=schedule_generator,
number_of_agents=nr_trains, obs_builder_object=observation_builder, malfunction_generator=stochastic_data,
remove_agents_at_target=True)
env.reset()
# Initiate the renderer
......
......@@ -9,10 +9,7 @@ specs = [[(0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0)],
[(7, 270), (1, 90), (1, 90), (1, 90), (2, 90), (7, 90)],
[(0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0)]]
env = RailEnv(width=6,
height=4,
rail_generator=rail_from_manual_specifications_generator(specs),
number_of_agents=1)
env = RailEnv(width=6, height=4, rail_generator=rail_from_manual_specifications_generator(specs), number_of_agents=1)
env.reset()
......
......@@ -23,8 +23,7 @@ transition_probability = [1.0, # empty cell - Case 0
1.0] # Case 10 - mirrored switch
# Example generate a random rail
env = RailEnv(width=10,
height=10,
env = RailEnv(width=10, height=10,
rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
number_of_agents=3)
......
......@@ -11,11 +11,9 @@ from flatland.utils.rendertools import RenderTool
random.seed(1)
np.random.seed(1)
env = RailEnv(width=7,
height=7,
env = RailEnv(width=7, height=7,
rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999, seed=1),
schedule_generator=complex_schedule_generator(),
number_of_agents=2,
schedule_generator=complex_schedule_generator(), number_of_agents=2,
obs_builder_object=TreeObsForRailEnv(max_depth=2))
env.reset()
......
......@@ -14,12 +14,9 @@ np.random.seed(1)
TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
LocalGridObs = LocalObsForRailEnv(view_height=10, view_width=2, center=2)
env = RailEnv(width=20,
height=20,
env = RailEnv(width=20, height=20,
rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=2, min_dist=8, max_dist=99999, seed=1),
schedule_generator=complex_schedule_generator(),
obs_builder_object=TreeObservation,
number_of_agents=3)
schedule_generator=complex_schedule_generator(), number_of_agents=3, obs_builder_object=TreeObservation)
env.reset()
env_renderer = RenderTool(env, gl="PILSVG", )
......
......@@ -18,16 +18,11 @@ from flatland.utils.rendertools import RenderTool
@click.command()
def demo(args=None):
"""Demo script to check installation"""
env = RailEnv(
width=15,
height=15,
rail_generator=complex_rail_generator(
nr_start_goal=10,
nr_extra=1,
min_dist=8,
max_dist=99999),
schedule_generator=complex_schedule_generator(),
number_of_agents=5)
env = RailEnv(width=15, height=15, rail_generator=complex_rail_generator(
nr_start_goal=10,
nr_extra=1,
min_dist=8,
max_dist=99999), schedule_generator=complex_schedule_generator(), number_of_agents=5)
env._max_episode_steps = int(15 * (env.width + env.height))
env_renderer = RenderTool(env)
......
"""Malfunction generators for rail systems"""
from typing import Tuple, List, Callable
import msgpack
MalfunctionGenerator = Callable[[], Tuple[float, int, int]]
def malfunction_from_file(filename) -> MalfunctionGenerator:
"""
Utility to load pickle file
Parameters
----------
input_file : Pickle file generated by env.save() or editor
Returns
-------
Tuple[float, int, int] with mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken
"""
def generator():
with open(filename, "rb") as file_in:
load_data = file_in.read()
data = msgpack.unpackb(load_data, use_list=False, encoding='utf-8')
if "malfunction" in data:
# Mean malfunction in number of time steps
mean_malfunction_rate = data["malfunction"]["malfunction_rate"]
# Uniform distribution parameters for malfunction duration
min_number_of_steps_broken = data["malfunction"]["min_duration"]
max_number_of_steps_broken = data["malfunction"]["max_duration"]
agents_speed = None
return mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken
return generator
def malfunction_from_params(parameters) -> MalfunctionGenerator:
"""
Utility to load malfunction from parameters
Parameters
----------
parameters containing
malfunction_rate : float how many time steps it takes for a sinlge agent befor it breaks
min_duration : int minimal duration of a failure
max_number_of_steps_broken : int maximal duration of a failure
Returns
-------
Tuple[float, int, int] with mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken
"""
def generator():
mean_malfunction_rate = parameters['malfunction_rate']
min_number_of_steps_broken = parameters['min_duration']
max_number_of_steps_broken = parameters['max_duration']
return mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken
return generator
def no_malfunction_generator() -> MalfunctionGenerator:
"""
Utility to load malfunction from parameters
Parameters
----------
input_file : Pickle file generated by env.save() or editor
Returns
-------
Tuple[float, int, int] with mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken
"""
def generator():
return 0, 0, 0
return generator
......@@ -19,6 +19,7 @@ from flatland.core.grid.grid_utils import IntVector2D
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent, RailAgentStatus
from flatland.envs.distance_map import DistanceMap
from flatland.envs.malfunction_generators import MalfunctionGenerator, no_malfunction_generator
from flatland.envs.observations import GlobalObsForRailEnv
from flatland.envs.rail_generators import random_rail_generator, RailGenerator
from flatland.envs.schedule_generators import random_schedule_generator, ScheduleGenerator
......@@ -111,17 +112,15 @@ class RailEnv(Environment):
stop_penalty = 0 # penalty for stopping a moving agent
start_penalty = 0 # penalty for starting a stopped agent
def __init__(self,
width,
def __init__(self, width,
height,
rail_generator: RailGenerator = random_rail_generator(),
schedule_generator: ScheduleGenerator = random_schedule_generator(),
number_of_agents=1,
obs_builder_object: ObservationBuilder = GlobalObsForRailEnv(),
stochastic_data=None,
malfunction_generator: MalfunctionGenerator = no_malfunction_generator(),
remove_agents_at_target=True,
random_seed=1
):
random_seed=1):
"""
Environment init.
......@@ -161,6 +160,7 @@ class RailEnv(Environment):
self.rail_generator: RailGenerator = rail_generator
self.schedule_generator: ScheduleGenerator = schedule_generator
self.malfunction_generator: MalfunctionGenerator = malfunction_generator
self.rail: Optional[GridTransitionMap] = None
self.width = width
self.height = height
......@@ -196,19 +196,8 @@ class RailEnv(Environment):
self._seed(seed=random_seed)
# Stochastic train malfunctioning parameters
if stochastic_data is not None:
mean_malfunction_rate = stochastic_data['malfunction_rate']
malfunction_min_duration = stochastic_data['min_duration']
malfunction_max_duration = stochastic_data['max_duration']
else:
mean_malfunction_rate = 0.
malfunction_min_duration = 0.
malfunction_max_duration = 0.
# Mean malfunction in number of time steps
mean_malfunction_rate, malfunction_min_duration, malfunction_max_duration = self.malfunction_generator()
self.mean_malfunction_rate = mean_malfunction_rate
# Uniform distribution parameters for malfunction duration
self.min_number_of_steps_broken = malfunction_min_duration
self.max_number_of_steps_broken = malfunction_max_duration
......@@ -359,6 +348,12 @@ class RailEnv(Environment):
else:
self._max_episode_steps = self.compute_max_episode_steps(width=self.width, height=self.height)
# Stochastic train malfunctioning parameters
mean_malfunction_rate, malfunction_min_duration, malfunction_max_duration = self.malfunction_generator()
self.mean_malfunction_rate = mean_malfunction_rate
self.min_number_of_steps_broken = malfunction_min_duration
self.max_number_of_steps_broken = malfunction_max_duration
self.agent_positions = np.full((self.height, self.width), False)
self.restart_agents()
......@@ -837,22 +832,41 @@ class RailEnv(Environment):
grid_data = self.rail.grid.tolist()
agent_static_data = [agent.to_list() for agent in self.agents_static]
agent_data = [agent.to_list() for agent in self.agents]
malfunction_data = {"malfunction_rate": self.mean_malfunction_rate,
"min_duration": self.min_number_of_steps_broken,
"max_duration": self.max_number_of_steps_broken}
msgpack.packb(grid_data, use_bin_type=True)
msgpack.packb(agent_data, use_bin_type=True)
msgpack.packb(agent_static_data, use_bin_type=True)
msg_data = {
"grid": grid_data,
"agents_static": agent_static_data,
"agents": agent_data}
"agents": agent_data,
"malfunction": malfunction_data}
return msgpack.packb(msg_data, use_bin_type=True)
def get_agent_state_msg(self):
def get_full_state_dist_msg(self):
"""
Returns agents information in msgpack object
Returns environment information with distance map information as msgpack object
"""
grid_data = self.rail.grid.tolist()
agent_static_data = [agent.to_list() for agent in self.agents_static]
agent_data = [agent.to_list() for agent in self.agents]
msgpack.packb(grid_data, use_bin_type=True)
msgpack.packb(agent_data, use_bin_type=True)
msgpack.packb(agent_static_data, use_bin_type=True)
distance_map_data = self.distance_map.get()
malfunction_data = {"malfunction_rate": self.mean_malfunction_rate,
"min_duration": self.min_number_of_steps_broken,
"max_duration": self.max_number_of_steps_broken}
msgpack.packb(distance_map_data, use_bin_type=True)
msg_data = {
"agents": agent_data}
"grid": grid_data,
"agents_static": agent_static_data,
"agents": agent_data,
"distance_map": distance_map_data,
"malfunction": malfunction_data}
return msgpack.packb(msg_data, use_bin_type=True)
def set_full_state_msg(self, msg_data):
......@@ -873,6 +887,12 @@ class RailEnv(Environment):
self.rail.height = self.height
self.rail.width = self.width
self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)
if "malfunction" in data:
# Mean malfunction in number of time steps
self.mean_malfunction_rate = data["malfunction"]["malfunction_rate"]
# Uniform distribution parameters for malfunction duration
self.min_number_of_steps_broken = data["malfunction"]["min_duration"]
self.max_number_of_steps_broken = data["malfunction"]["max_duration"]
def set_full_state_dist_msg(self, msg_data):
"""
......@@ -894,26 +914,12 @@ class RailEnv(Environment):
self.rail.height = self.height
self.rail.width = self.width
self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)
def get_full_state_dist_msg(self):
"""
Returns environment information with distance map information as msgpack object
"""
grid_data = self.rail.grid.tolist()
agent_static_data = [agent.to_list() for agent in self.agents_static]
agent_data = [agent.to_list() for agent in self.agents]
msgpack.packb(grid_data, use_bin_type=True)
msgpack.packb(agent_data, use_bin_type=True)
msgpack.packb(agent_static_data, use_bin_type=True)
distance_map_data = self.distance_map.get()
msgpack.packb(distance_map_data, use_bin_type=True)
msg_data = {
"grid": grid_data,
"agents_static": agent_static_data,
"agents": agent_data,
"distance_map": distance_map_data}
return msgpack.packb(msg_data, use_bin_type=True)
if "malfunction" in data:
# Mean malfunction in number of time steps
self.mean_malfunction_rate = data["malfunction"]["malfunction_rate"]
# Uniform distribution parameters for malfunction duration
self.min_number_of_steps_broken = data["malfunction"]["min_duration"]
self.max_number_of_steps_broken = data["malfunction"]["max_duration"]
def save(self, filename, save_distance_maps=False):
"""
......
......@@ -10,10 +10,7 @@ def load_flatland_environment_from_file(file_name, load_from_package=None, obs_b
obs_builder_object = TreeObsForRailEnv(
max_depth=2,
predictor=ShortestPathPredictorForRailEnv(max_depth=10))
environment = RailEnv(width=1,
height=1,
rail_generator=rail_from_file(file_name, load_from_package),
number_of_agents=1,
schedule_generator=schedule_from_file(file_name, load_from_package),
environment = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name, load_from_package),
schedule_generator=schedule_from_file(file_name, load_from_package), number_of_agents=1,
obs_builder_object=obs_builder_object)
return environment
......@@ -217,13 +217,9 @@ class FlatlandRemoteClient(object):
if self.verbose:
print("Current env path : ", test_env_file_path)
self.current_env_path = test_env_file_path
self.env = RailEnv(
width=1,
height=1,
rail_generator=rail_from_file(test_env_file_path),
schedule_generator=schedule_from_file(test_env_file_path),
obs_builder_object=obs_builder_object
)
self.env = RailEnv(width=1, height=1, rail_generator=rail_from_file(test_env_file_path),
schedule_generator=schedule_from_file(test_env_file_path),
obs_builder_object=obs_builder_object)
time_start = time.time()
local_observation, info = self.env.reset(
......@@ -246,8 +242,8 @@ class FlatlandRemoteClient(object):
_request['type'] = messages.FLATLAND_RL.ENV_STEP
_request['payload'] = {}
_request['payload']['action'] = action
# Relay the action in a non-blocking way to the server
# Relay the action in a non-blocking way to the server
# so that it can start doing an env.step on it in ~ parallel
self._remote_request(_request, blocking=False)
......
......@@ -273,7 +273,7 @@ class FlatlandRemoteEvaluationService:
)
if self.verbose:
print("Received Request : ", command)
message_queue_latency = time.time() - command["timestamp"]
self.update_running_mean_stats("message_queue_latency", message_queue_latency)
return command
......@@ -335,13 +335,9 @@ class FlatlandRemoteEvaluationService:
test_env_file_path
)
del self.env
self.env = RailEnv(
width=1,
height=1,
rail_generator=rail_from_file(test_env_file_path),
schedule_generator=schedule_from_file(test_env_file_path),
obs_builder_object=DummyObservationBuilder()
)
self.env = RailEnv(width=1, height=1, rail_generator=rail_from_file(test_env_file_path),
schedule_generator=schedule_from_file(test_env_file_path),
obs_builder_object=DummyObservationBuilder())
if self.begin_simulation:
# If begin simulation has already been initialized
......
......@@ -24,10 +24,7 @@ class EditorMVC(object):
""" Create an Editor MVC assembly around a railenv, or create one if None.
"""
if env is None:
env = RailEnv(width=10,
height=10,
rail_generator=empty_rail_generator(),
number_of_agents=0,
env = RailEnv(width=10, height=10, rail_generator=empty_rail_generator(), number_of_agents=0,
obs_builder_object=TreeObsForRailEnv(max_depth=2))
env.reset()
......@@ -669,11 +666,8 @@ class EditorModel(object):
fnMethod = complex_rail_generator(nr_start_goal=nAgents, nr_extra=20, min_dist=12, seed=int(time.time()))
if env is None:
self.env = RailEnv(width=self.regen_size_width,
height=self.regen_size_height,
rail_generator=fnMethod,
number_of_agents=nAgents,
obs_builder_object=TreeObsForRailEnv(max_depth=2))
self.env = RailEnv(width=self.regen_size_width, height=self.regen_size_height, rail_generator=fnMethod,
number_of_agents=nAgents, obs_builder_object=TreeObsForRailEnv(max_depth=2))
else:
self.env = env
self.env.reset(regenerate_rail=True)
......
......@@ -25,14 +25,10 @@ def test_walker():
rail = GridTransitionMap(width=rail_map.shape[1],
height=rail_map.shape[0], transitions=transitions)
rail.grid = rail_map
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(),
number_of_agents=1,
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(), number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2,
predictor=ShortestPathPredictorForRailEnv(max_depth=10)),
)
predictor=ShortestPathPredictorForRailEnv(max_depth=10)))
# reset to initialize agents_static
env.reset()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment