Skip to content
Snippets Groups Projects
Commit c7f10073 authored by Erik Nygren's avatar Erik Nygren :bullettrain_front:
Browse files

Introducing malfunction_generators

This resolves issue #273

fixed tests
parent 117aa7d1
No related branches found
No related tags found
No related merge requests found
import numpy as np
from flatland.envs.agent_utils import EnvAgent, RailAgentStatus
from flatland.envs.malfunction_generators import malfunction_from_params
from flatland.envs.observations import GlobalObsForRailEnv
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import sparse_rail_generator
......@@ -28,7 +29,7 @@ def test_get_global_observation():
grid_mode=False
),
schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=number_of_agents,
obs_builder_object=GlobalObsForRailEnv(), malfunction_generator=stochastic_data)
obs_builder_object=GlobalObsForRailEnv(), malfunction_generator=malfunction_from_params(stochastic_data))
env.reset()
obs, all_rewards, done, _ = env.step({i: RailEnvActions.MOVE_FORWARD for i in range(number_of_agents)})
......
import numpy as np
from flatland.envs.malfunction_generators import malfunction_from_params
from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
......@@ -108,12 +109,12 @@ def test_seeding_and_malfunction():
for tests in range(1, 100):
env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(), number_of_agents=10,
obs_builder_object=GlobalObsForRailEnv(), malfunction_generator=stochastic_data)
obs_builder_object=GlobalObsForRailEnv(), malfunction_generator=malfunction_from_params(stochastic_data))
# Tree Observation
env2 = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(), number_of_agents=10,
obs_builder_object=GlobalObsForRailEnv(), malfunction_generator=stochastic_data)
obs_builder_object=GlobalObsForRailEnv(), malfunction_generator=malfunction_from_params(stochastic_data))
env.reset(True, False, True, random_seed=tests)
env2.reset(True, False, True, random_seed=tests)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment