From b36df2897785d4d83171c270eb8e5d2379fe0828 Mon Sep 17 00:00:00 2001 From: MLErik <baerenjesus@gmail.com> Date: Thu, 10 Oct 2019 17:06:11 -0400 Subject: [PATCH] updated tests --- flatland/utils/simple_rail.py | 2 +- tests/test_flatland_malfunction.py | 74 ++++++++++++++++++++++++++++++ 2 files changed, 75 insertions(+), 1 deletion(-) diff --git a/flatland/utils/simple_rail.py b/flatland/utils/simple_rail.py index 019c9b66..c9558330 100644 --- a/flatland/utils/simple_rail.py +++ b/flatland/utils/simple_rail.py @@ -12,7 +12,7 @@ def make_simple_rail() -> Tuple[GridTransitionMap, np.array]: # | # | # | - # _ _ _ _\ _ _ _ _ _ _ + # _ _ _ _\ _ _ _ _ _ _ # / # | # | diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index 8bbfefbb..8fdd907c 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -512,3 +512,77 @@ def test_initial_nextmalfunction_not_below_zero(): # was next_malfunction was -1 befor the bugfix https://gitlab.aicrowd.com/flatland/flatland/issues/186 assert agent.malfunction_data['next_malfunction'] >= 0, \ "next_malfunction should be >=0, found {}".format(agent.malfunction_data['next_malfunction']) + + +def tests_random_interference_from_outside(): + """Tests that malfunctions are produced by stochastic_data!""" + # Set fixed malfunction duration for this test + stochastic_data = {'prop_malfunction': 1., + 'malfunction_rate': 1, + 'min_duration': 10, + 'max_duration': 10} + + rail, rail_map = make_simple_rail2() + + env = RailEnv(width=25, + height=30, + rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(seed=2), # seed 12 + number_of_agents=1, + random_seed=1, + stochastic_data=stochastic_data, # Malfunction data generator + ) + # reset to initialize agents_static + env.agents[0].speed_data['speed'] = 0.33 + env.agents[0].initial_position = (3, 0) + env.agents[0].target = (3, 9) + env.reset(False, False, False) + env_data = [] + + for step in range(200): + action_dict: Dict[int, RailEnvActions] = {} + for agent in env.agents: + # We randomly select an action + action_dict[agent.handle] = RailEnvActions(2) + + _, reward, _, _ = env.step(action_dict) + # Append the rewards of the first trial + env_data.append((reward[0],env.agents[0].position)) + assert reward[0] == env_data[step][0] + assert env.agents[0].position == env_data[step][1] + # Run the same test as above but with an external random generator running + # Check that the reward stays the same + + rail, rail_map = make_simple_rail2() + random.seed(47) + np.random.seed(1234) + env = RailEnv(width=25, + height=30, + rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(seed=2), # seed 12 + number_of_agents=1, + random_seed=1, + stochastic_data=stochastic_data, # Malfunction data generator + ) + # reset to initialize agents_static + env.agents[0].speed_data['speed'] = 0.33 + env.agents[0].initial_position = (3, 0) + env.agents[0].target = (3, 9) + env.reset(False, False, False) + + + # Print for test generation + dummy_list = [1, 2, 6, 7, 8, 9, 4, 5, 4] + for step in range(200): + action_dict: Dict[int, RailEnvActions] = {} + for agent in env.agents: + # We randomly select an action + action_dict[agent.handle] = RailEnvActions(2) + + # Do dummy random number generations + a = random.shuffle(dummy_list) + b = np.random.rand() + + _, reward, _, _ = env.step(action_dict) + assert reward[0] == env_data[step][0] + assert env.agents[0].position == env_data[step][1] -- GitLab