Commit e3165bb0 authored by Erik Nygren's avatar Erik Nygren 🚅
Browse files

updated malfunction generator

parent ae31a7b8
Pipeline #2711 failed with stages
in 7 minutes and 29 seconds
......@@ -39,8 +39,7 @@ env = RailEnv(width=100, height=100, rail_generator=sparse_rail_generator(max_nu
max_rails_in_city=8,
),
schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=100,
obs_builder_object=GlobalObsForRailEnv(), malfunction_generator=malfunction_from_params(stochastic_data),
remove_agents_at_target=True)
obs_builder_object=GlobalObsForRailEnv(), remove_agents_at_target=True)
# RailEnv.DEPOT_POSITION = lambda agent, agent_handle : (agent_handle % env.height,0)
......
......@@ -3,7 +3,7 @@ import numpy as np
# In Flatland you can use custom observation builders and predicitors
# Observation builders generate the observation needed by the controller
# Preditctors can be used to do short time prediction which can help in avoiding conflicts in the network
from flatland.envs.malfunction_generators import malfunction_from_params
from envs.malfunction_generators import malfunction_from_params
from flatland.envs.observations import GlobalObsForRailEnv
# First of all we import the Flatland rail environment
from flatland.envs.rail_env import RailEnv
......@@ -80,7 +80,7 @@ env = RailEnv(width=width,
schedule_generator=schedule_generator,
number_of_agents=nr_trains,
obs_builder_object=observation_builder,
malfunction_generator=malfunction_from_params(stochastic_data),
malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
remove_agents_at_target=True)
env.reset()
......
......@@ -30,10 +30,7 @@ def load_flatland_environment_from_file(file_name: str,
obs_builder_object = TreeObsForRailEnv(
max_depth=2,
predictor=ShortestPathPredictorForRailEnv(max_depth=10))
environment = RailEnv(width=1, # will be overridden when loading from file
height=1, # will be overridden when loading from file
rail_generator=rail_from_file(file_name, load_from_package),
number_of_agents=1, # will be overridden when loading from file
schedule_generator=schedule_from_file(file_name, load_from_package),
environment = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name, load_from_package),
schedule_generator=schedule_from_file(file_name, load_from_package), number_of_agents=1,
obs_builder_object=obs_builder_object)
return environment
"""Malfunction generators for rail systems"""
from typing import Tuple, Callable
from typing import Callable, NamedTuple, Optional, Tuple
import msgpack
import numpy as np
from numpy.random.mtrand import RandomState
MalfunctionGenerator = Callable[[], Tuple[float, int, int]]
from envs.agent_utils import EnvAgent
Malfunction = NamedTuple('Malfunction', [('num_broken_steps', int)])
MalfunctionGenerator = Callable[[EnvAgent], Optional[Malfunction]]
MalfunctionProcessData = NamedTuple('MalfunctionProcessData',
[('malfunction_rate', float), ('min_duration', int), ('max_duration', int)])
def malfunction_from_file(filename) -> MalfunctionGenerator:
def _malfunction_prob(rate: float) -> float:
"""
Probability of a single agent to break. According to Poisson process with given rate
:param rate:
:return:
"""
if rate <= 0:
return 0.
else:
return 1 - np.exp(- (1 / rate))
def malfunction_from_file(filename) -> Tuple[MalfunctionGenerator, MalfunctionProcessData]:
"""
Utility to load pickle file
......@@ -19,25 +38,48 @@ def malfunction_from_file(filename) -> MalfunctionGenerator:
-------
Tuple[float, int, int] with mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken
"""
def generator():
with open(filename, "rb") as file_in:
load_data = file_in.read()
data = msgpack.unpackb(load_data, use_list=False, encoding='utf-8')
if "malfunction" in data:
# Mean malfunction in number of time steps
mean_malfunction_rate = data["malfunction"]["malfunction_rate"]
# Uniform distribution parameters for malfunction duration
min_number_of_steps_broken = data["malfunction"]["min_duration"]
max_number_of_steps_broken = data["malfunction"]["max_duration"]
agents_speed = None
return mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken
return generator
def malfunction_from_params(parameters) -> MalfunctionGenerator:
with open(filename, "rb") as file_in:
load_data = file_in.read()
data = msgpack.unpackb(load_data, use_list=False, encoding='utf-8')
if "malfunction" in data:
# Mean malfunction in number of time steps
mean_malfunction_rate = data["malfunction"]["malfunction_rate"]
# Uniform distribution parameters for malfunction duration
min_number_of_steps_broken = data["malfunction"]["min_duration"]
max_number_of_steps_broken = data["malfunction"]["max_duration"]
else:
# Mean malfunction in number of time steps
mean_malfunction_rate = 0.
# Uniform distribution parameters for malfunction duration
min_number_of_steps_broken = 0
max_number_of_steps_broken = 0
def generator(agent: EnvAgent, np_random: RandomState) -> Optional[Malfunction]:
"""
Generate malfunctions for agents
Parameters
----------
agent
np_random
Returns
-------
int: Number of time steps an agent is broken
"""
if agent.malfunction_data['malfunction'] < 1:
if np_random.rand() < _malfunction_prob(mean_malfunction_rate):
num_broken_steps = np_random.randint(min_number_of_steps_broken,
max_number_of_steps_broken + 1) + 1
return Malfunction(num_broken_steps)
return Malfunction(0)
return generator, MalfunctionProcessData(mean_malfunction_rate, min_number_of_steps_broken,
max_number_of_steps_broken)
def malfunction_from_params(parameters: dict) -> Tuple[MalfunctionGenerator, MalfunctionProcessData]:
"""
Utility to load malfunction from parameters
......@@ -52,17 +94,34 @@ def malfunction_from_params(parameters) -> MalfunctionGenerator:
-------
Tuple[float, int, int] with mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken
"""
def generator():
mean_malfunction_rate = parameters['malfunction_rate']
min_number_of_steps_broken = parameters['min_duration']
max_number_of_steps_broken = parameters['max_duration']
return mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken
return generator
def no_malfunction_generator() -> MalfunctionGenerator:
mean_malfunction_rate = parameters['malfunction_rate']
min_number_of_steps_broken = parameters['min_duration']
max_number_of_steps_broken = parameters['max_duration']
def generator(agent: EnvAgent, np_random: RandomState) -> Optional[Malfunction]:
"""
Generate malfunctions for agents
Parameters
----------
agent
np_random
Returns
-------
int: Number of time steps an agent is broken
"""
if agent.malfunction_data['malfunction'] < 1:
if np_random.rand() < _malfunction_prob(mean_malfunction_rate):
num_broken_steps = np_random.randint(min_number_of_steps_broken,
max_number_of_steps_broken + 1) + 1
return Malfunction(num_broken_steps)
return Malfunction(0)
return generator, MalfunctionProcessData(mean_malfunction_rate, min_number_of_steps_broken,
max_number_of_steps_broken)
def no_malfunction_generator() -> Tuple[MalfunctionGenerator, MalfunctionProcessData]:
"""
Utility to load malfunction from parameters
......@@ -74,8 +133,15 @@ def no_malfunction_generator() -> MalfunctionGenerator:
-------
Tuple[float, int, int] with mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken
"""
# Mean malfunction in number of time steps
mean_malfunction_rate = 0.
# Uniform distribution parameters for malfunction duration
min_number_of_steps_broken = 0
max_number_of_steps_broken = 0
def generator():
return 0, 0, 0
def generator(agent: EnvAgent, np_random: RandomState) -> Optional[Malfunction]:
return Malfunction(0)
return generator
return generator, MalfunctionProcessData(mean_malfunction_rate, min_number_of_steps_broken,
max_number_of_steps_broken)
......@@ -19,7 +19,7 @@ from flatland.core.grid.grid_utils import IntVector2D
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent, RailAgentStatus
from flatland.envs.distance_map import DistanceMap
from flatland.envs.malfunction_generators import MalfunctionGenerator, no_malfunction_generator
from flatland.envs.malfunction_generators import no_malfunction_generator, Malfunction
from flatland.envs.observations import GlobalObsForRailEnv
from flatland.envs.rail_generators import random_rail_generator, RailGenerator
from flatland.envs.schedule_generators import random_schedule_generator, ScheduleGenerator
......@@ -119,7 +119,7 @@ class RailEnv(Environment):
schedule_generator: ScheduleGenerator = random_schedule_generator(),
number_of_agents=1,
obs_builder_object: ObservationBuilder = GlobalObsForRailEnv(),
malfunction_generator: MalfunctionGenerator = no_malfunction_generator(),
malfunction_generator_and_process_data=no_malfunction_generator(),
remove_agents_at_target=True,
random_seed=1):
"""
......@@ -159,9 +159,9 @@ class RailEnv(Environment):
"""
super().__init__()
self.malfunction_generator, self.malfunction_process_data = malfunction_generator_and_process_data
self.rail_generator: RailGenerator = rail_generator
self.schedule_generator: ScheduleGenerator = schedule_generator
self.malfunction_generator: MalfunctionGenerator = malfunction_generator
self.rail: Optional[GridTransitionMap] = None
self.width = width
self.height = height
......@@ -196,14 +196,6 @@ class RailEnv(Environment):
if self.random_seed:
self._seed(seed=random_seed)
# Stochastic train malfunctioning parameters
mean_malfunction_rate, malfunction_min_duration, malfunction_max_duration = self.malfunction_generator()
self.mean_malfunction_rate = mean_malfunction_rate
# Uniform distribution parameters for malfunction duration
self.min_number_of_steps_broken = malfunction_min_duration
self.max_number_of_steps_broken = malfunction_max_duration
self.valid_positions = None
# global numpy array of agents position, True means that there is an agent at that cell
......@@ -350,12 +342,6 @@ class RailEnv(Environment):
else:
self._max_episode_steps = self.compute_max_episode_steps(width=self.width, height=self.height)
# Stochastic train malfunctioning parameters
mean_malfunction_rate, malfunction_min_duration, malfunction_max_duration = self.malfunction_generator()
self.mean_malfunction_rate = mean_malfunction_rate
self.min_number_of_steps_broken = malfunction_min_duration
self.max_number_of_steps_broken = malfunction_max_duration
self.agent_positions = np.zeros((self.height, self.width), dtype=int) - 1
self.restart_agents()
......@@ -365,7 +351,7 @@ class RailEnv(Environment):
if activate_agents:
self.set_agent_active(agent)
self._break_agent(self.mean_malfunction_rate, agent)
self._break_agent(agent)
if agent.malfunction_data["malfunction"] > 0:
agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.DO_NOTHING
......@@ -419,7 +405,7 @@ class RailEnv(Environment):
agent.moving = agent.malfunction_data['moving_before_malfunction']
return
def _break_agent(self, rate: float, agent) -> bool:
def _break_agent(self, agent):
"""
Malfunction generator that breaks agents at a given rate.
......@@ -428,13 +414,13 @@ class RailEnv(Environment):
agent
"""
if agent.malfunction_data['malfunction'] < 1:
if self.np_random.rand() < self._malfunction_prob(rate):
num_broken_steps = self.np_random.randint(self.min_number_of_steps_broken,
self.max_number_of_steps_broken + 1) + 1
agent.malfunction_data['malfunction'] = num_broken_steps
agent.malfunction_data['moving_before_malfunction'] = agent.moving
agent.malfunction_data['nr_malfunctions'] += 1
malfunction: Malfunction = self.malfunction_generator(agent, self.np_random)
if malfunction.num_broken_steps > 0:
agent.malfunction_data['malfunction'] = malfunction.num_broken_steps
agent.malfunction_data['moving_before_malfunction'] = agent.moving
agent.malfunction_data['nr_malfunctions'] += 1
return
def step(self, action_dict_: Dict[int, RailEnvActions]):
......@@ -481,7 +467,7 @@ class RailEnv(Environment):
self.rewards_dict[i_agent] = 0
# Induce malfunction before we do a step, thus a broken agent can't move in this step
self._break_agent(self.mean_malfunction_rate, agent)
self._break_agent(agent)
# Perform step on the agent
self._step_agent(i_agent, action_dict_.get(i_agent))
......@@ -816,9 +802,7 @@ class RailEnv(Environment):
grid_data = self.rail.grid.tolist()
agent_static_data = [agent.to_list() for agent in self.agents_static]
agent_data = [agent.to_list() for agent in self.agents]
malfunction_data = {"malfunction_rate": self.mean_malfunction_rate,
"min_duration": self.min_number_of_steps_broken,
"max_duration": self.max_number_of_steps_broken}
malfunction_data = {"malfunction_process_data": self.malfunction_process_data}
msgpack.packb(grid_data, use_bin_type=True)
msgpack.packb(agent_data, use_bin_type=True)
......@@ -841,9 +825,7 @@ class RailEnv(Environment):
msgpack.packb(agent_data, use_bin_type=True)
msgpack.packb(agent_static_data, use_bin_type=True)
distance_map_data = self.distance_map.get()
malfunction_data = {"malfunction_rate": self.mean_malfunction_rate,
"min_duration": self.min_number_of_steps_broken,
"max_duration": self.max_number_of_steps_broken}
malfunction_data = {"malfunction_process_data": self.malfunction_process_data}
msgpack.packb(distance_map_data, use_bin_type=True)
msg_data = {
"grid": grid_data,
......@@ -871,12 +853,6 @@ class RailEnv(Environment):
self.rail.height = self.height
self.rail.width = self.width
self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)
if "malfunction" in data:
# Mean malfunction in number of time steps
self.mean_malfunction_rate = data["malfunction"]["malfunction_rate"]
# Uniform distribution parameters for malfunction duration
self.min_number_of_steps_broken = data["malfunction"]["min_duration"]
self.max_number_of_steps_broken = data["malfunction"]["max_duration"]
def set_full_state_dist_msg(self, msg_data):
"""
......@@ -898,12 +874,6 @@ class RailEnv(Environment):
self.rail.height = self.height
self.rail.width = self.width
self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)
if "malfunction" in data:
# Mean malfunction in number of time steps
self.mean_malfunction_rate = data["malfunction"]["malfunction_rate"]
# Uniform distribution parameters for malfunction duration
self.min_number_of_steps_broken = data["malfunction"]["min_duration"]
self.max_number_of_steps_broken = data["malfunction"]["max_duration"]
def save(self, filename, save_distance_maps=False):
"""
......@@ -970,17 +940,6 @@ class RailEnv(Environment):
x = - np.log(1 - u) * rate
return x
def _malfunction_prob(self, rate: float) -> float:
"""
Probability of a single agent to break. According to Poisson process with given rate
:param rate:
:return:
"""
if rate <= 0:
return 0.
else:
return 1 - np.exp(- (1 / rate))
def _is_agent_ok(self, agent: EnvAgent) -> bool:
"""
Check if an agent is ok, meaning it can move and is not malfuncitoinig
......
......@@ -30,10 +30,7 @@ def load_flatland_environment_from_file(file_name: str,
obs_builder_object = TreeObsForRailEnv(
max_depth=2,
predictor=ShortestPathPredictorForRailEnv(max_depth=10))
environment = RailEnv(width=1, # will be overridden when loading from file
height=1, # will be overridden when loading from file
rail_generator=rail_from_file(file_name, load_from_package),
number_of_agents=1, # will be overridden when loading from file
schedule_generator=schedule_from_file(file_name, load_from_package),
environment = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name, load_from_package),
schedule_generator=schedule_from_file(file_name, load_from_package), number_of_agents=1,
obs_builder_object=obs_builder_object)
return environment
import os
import uuid
import subprocess
import glob
import os
import random
import subprocess
import uuid
###############################################################
# Expected Env Variables
......
......@@ -557,8 +557,7 @@ def test_sparse_rail_generator_deterministic():
seed=215545, # Random seed
grid_mode=True
),
schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=1,
malfunction_generator=malfunction_from_params(stochastic_data))
schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=1)
env.reset()
# for r in range(env.height):
# for c in range(env.width):
......@@ -1406,8 +1405,7 @@ def test_rail_env_malfunction_speed_info():
grid_mode=False
),
schedule_generator=sparse_schedule_generator(), number_of_agents=10,
obs_builder_object=GlobalObsForRailEnv(),
malfunction_generator=malfunction_from_params(stochastic_data))
obs_builder_object=GlobalObsForRailEnv())
env.reset(False, False, True)
env_renderer = RenderTool(env, gl="PILSVG", )
......
......@@ -73,10 +73,14 @@ def test_malfunction_process():
rail, rail_map = make_simple_rail2()
env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(), number_of_agents=1,
obs_builder_object=SingleAgentNavigationObs(),
malfunction_generator=malfunction_from_params(stochastic_data))
env = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(),
number_of_agents=1,
malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
obs_builder_object=SingleAgentNavigationObs()
)
# reset to initialize agents_static
obs, info = env.reset(False, False, True, random_seed=10)
......@@ -123,10 +127,14 @@ def test_malfunction_process_statistically():
rail, rail_map = make_simple_rail2()
env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(), number_of_agents=10,
obs_builder_object=SingleAgentNavigationObs(),
malfunction_generator=malfunction_from_params(stochastic_data))
env = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(),
number_of_agents=10,
malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
obs_builder_object=SingleAgentNavigationObs()
)
# reset to initialize agents_static
env.reset(True, True, False, random_seed=10)
......@@ -135,15 +143,15 @@ def test_malfunction_process_statistically():
# Next line only for test generation
# agent_malfunction_list = [[] for i in range(10)]
agent_malfunction_list = [[0, 5, 4, 3, 2, 1, 0, 0, 0, 5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1],
[0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 5, 4, 3, 2, 1, 0, 0],
[5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1, 0, 0, 0, 5, 4, 3, 2, 1, 0],
[0, 5, 4, 3, 2, 1, 0, 0, 0, 5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1],
[0, 0, 0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1],
[0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 5, 4, 3, 2, 1, 0, 5, 4, 3, 2],
[5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[5, 4, 3, 2, 1, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0]]
[0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 5, 4, 3, 2, 1, 0, 0],
[5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1, 0, 0, 0, 5, 4, 3, 2, 1, 0],
[0, 5, 4, 3, 2, 1, 0, 0, 0, 5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1],
[0, 0, 0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1],
[0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 5, 4, 3, 2, 1, 0, 5, 4, 3, 2],
[5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[5, 4, 3, 2, 1, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0]]
for step in range(20):
action_dict: Dict[int, RailEnvActions] = {}
......@@ -166,9 +174,14 @@ def test_malfunction_before_entry():
rail, rail_map = make_simple_rail2()
env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(seed=1), number_of_agents=10,
malfunction_generator=malfunction_from_params(stochastic_data), random_seed=1)
env = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(),
number_of_agents=10,
malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
obs_builder_object=SingleAgentNavigationObs()
)
# reset to initialize agents_static
env.reset(False, False, False, random_seed=10)
env.agents[0].target = (0, 0)
......@@ -205,9 +218,14 @@ def test_malfunction_values_and_behavior():
stochastic_data = {'malfunction_rate': 0.001,
'min_duration': 10,
'max_duration': 10}
env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(seed=2), number_of_agents=1,
malfunction_generator=malfunction_from_params(stochastic_data), random_seed=1)
env = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(),
number_of_agents=1,
malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
obs_builder_object=SingleAgentNavigationObs()
)
# reset to initialize agents_static
env.reset(False, False, activate_agents=True, random_seed=10)
......@@ -232,8 +250,7 @@ def test_initial_malfunction():
env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(seed=10), number_of_agents=1,
obs_builder_object=SingleAgentNavigationObs(),
malfunction_generator=malfunction_from_params(stochastic_data))
obs_builder_object=SingleAgentNavigationObs())
# reset to initialize agents_static
env.reset(False, False, True, random_seed=10)
print(env.agents[0].malfunction_data)
......@@ -300,8 +317,7 @@ def test_initial_malfunction_stop_moving():
env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(), number_of_agents=1,
obs_builder_object=SingleAgentNavigationObs(),
malfunction_generator=malfunction_from_params(stochastic_data))
obs_builder_object=SingleAgentNavigationObs())
env.reset()
print(env.agents[0].initial_position, env.agents[0].direction, env.agents[0].position, env.agents[0].status)
......@@ -386,8 +402,7 @@ def test_initial_malfunction_do_nothing():
rail, rail_map = make_simple_rail2()
env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(), number_of_agents=1,
malfunction_generator=malfunction_from_params(stochastic_data))
schedule_generator=random_schedule_generator(), number_of_agents=1)
# reset to initialize agents_static
env.reset()
set_penalties_for_replay(env)
......@@ -465,8 +480,7 @@ def tests_random_interference_from_outside():
rail, rail_map = make_simple_rail2()
env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(seed=2), number_of_agents=1,
malfunction_generator=malfunction_from_params(stochastic_data), random_seed=1)
schedule_generator=random_schedule_generator(seed=2), number_of_agents=1, random_seed=1)
env.reset()
# reset to initialize agents_static
env.agents[0].speed_data['speed'] = 0.33
......@@ -491,8 +505,7 @@ def tests_random_interference_from_outside():
random.seed(47)
np.random.seed(1234)
env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(seed=2), number_of_agents=1,
malfunction_generator=malfunction_from_params(stochastic_data), random_seed=1)
schedule_generator=random_schedule_generator(seed=2), number_of_agents=1, random_seed=1)
env.reset()
# reset to initialize agents_static
env.agents[0].speed_data['speed'] = 0.33
......@@ -528,8 +541,7 @@ def test_last_malfunction_step():
rail, rail_map = make_simple_rail2()
env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(seed=2), number_of_agents=1,
malfunction_generator=malfunction_from_params(stochastic_data), random_seed=1)
schedule_generator=random_schedule_generator(seed=2), number_of_agents=1, random_seed=1)
env.reset()
# reset to initialize agents_static
env.agents[0].speed_data['speed'] = 1. / 3.
......
......@@ -29,8 +29,7 @@ def test_get_global_observation():
grid_mode=False
),
schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=number_of_agents,
obs_builder_object=GlobalObsForRailEnv(),
malfunction_generator=malfunction_from_params(stochastic_data))
obs_builder_object=GlobalObsForRailEnv())