diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index 5556a2a0ed9c5b67e2708c8bf222304603a131ad..3ee9ccf80d89b7fbda859197b489bc9918a43d96 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -2,6 +2,7 @@ import time import numpy as np +from flatland.envs.malfunction_generators import malfunction_from_params from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv @@ -38,7 +39,7 @@ env = RailEnv(width=100, height=100, rail_generator=sparse_rail_generator(max_nu max_rails_in_city=8, ), schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=100, - obs_builder_object=GlobalObsForRailEnv(), malfunction_generator=stochastic_data, + obs_builder_object=GlobalObsForRailEnv(), malfunction_generator=malfunction_from_params(stochastic_data), remove_agents_at_target=True) # RailEnv.DEPOT_POSITION = lambda agent, agent_handle : (agent_handle % env.height,0) diff --git a/examples/introduction_flatland_2_1.py b/examples/introduction_flatland_2_1.py index de7c77faebc1f9d8fb8bb9b18fec142601babcb7..dd17c285519cb35ee5d11a3ed1731f3e33a45c33 100644 --- a/examples/introduction_flatland_2_1.py +++ b/examples/introduction_flatland_2_1.py @@ -3,6 +3,7 @@ import numpy as np # In Flatland you can use custom observation builders and predicitors # Observation builders generate the observation needed by the controller # Preditctors can be used to do short time prediction which can help in avoiding conflicts in the network +from flatland.envs.malfunction_generators import malfunction_from_params from flatland.envs.observations import GlobalObsForRailEnv # First of all we import the Flatland rail environment from flatland.envs.rail_env import RailEnv @@ -73,7 +74,7 @@ observation_builder = GlobalObsForRailEnv() # Construct the enviornment with the given observation, generataors, predictors, and stochastic data env = RailEnv(width=width, height=height, rail_generator=rail_generator, schedule_generator=schedule_generator, - number_of_agents=nr_trains, obs_builder_object=observation_builder, malfunction_generator=stochastic_data, + number_of_agents=nr_trains, obs_builder_object=observation_builder, malfunction_generator=malfunction_from_params(stochastic_data), remove_agents_at_target=True) env.reset()