From e05151d5ff91a6cb2243c7f03ab2a6dd9b1f2feb Mon Sep 17 00:00:00 2001 From: Dipam Chakraborty <dipam@aicrowd.com> Date: Fri, 30 Jul 2021 21:40:32 +0530 Subject: [PATCH] tests for env step transitions WIP --- tests/test_env_step_utils.py | 61 +++++++++++++++++++ tests/test_state_machine.py | 115 +++++++++++++++++++++++++++++++++++ 2 files changed, 176 insertions(+) create mode 100644 tests/test_env_step_utils.py create mode 100644 tests/test_state_machine.py diff --git a/tests/test_env_step_utils.py b/tests/test_env_step_utils.py new file mode 100644 index 00000000..739d3d06 --- /dev/null +++ b/tests/test_env_step_utils.py @@ -0,0 +1,61 @@ +import numpy as np +import numpy as np +import os + +from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters, ParamMalfunctionGen + +from flatland.envs.observations import GlobalObsForRailEnv +# First of all we import the Flatland rail environment +from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_env import RailEnvActions +from flatland.envs.rail_generators import sparse_rail_generator +#from flatland.envs.sparse_rail_gen import SparseRailGen +from flatland.envs.schedule_generators import sparse_schedule_generator + + +def get_small_two_agent_env(): + """Generates a simple 2 city 2 train env returns it after reset""" + width = 30 # With of map + height = 15 # Height of map + nr_trains = 2 # Number of trains that have an assigned task in the env + cities_in_map = 2 # Number of cities where agents can start or end + seed = 42 # Random seed + grid_distribution_of_cities = False # Type of city distribution, if False cities are randomly placed + max_rails_between_cities = 2 # Max number of tracks allowed between cities. This is number of entry point to a city + max_rail_in_cities = 6 # Max number of parallel tracks within a city, representing a realistic trainstation + + rail_generator = sparse_rail_generator(max_num_cities=cities_in_map, + seed=seed, + grid_mode=grid_distribution_of_cities, + max_rails_between_cities=max_rails_between_cities, + max_rail_pairs_in_city=max_rail_in_cities//2, + ) + speed_ration_map = {1.: 0.25, # Fast passenger train + 1. / 2.: 0.25, # Fast freight train + 1. / 3.: 0.25, # Slow commuter train + 1. / 4.: 0.25} # Slow freight train + + schedule_generator = sparse_schedule_generator(speed_ration_map) + + + stochastic_data = MalfunctionParameters(malfunction_rate=1/10000, # Rate of malfunction occurence + min_duration=15, # Minimal duration of malfunction + max_duration=50 # Max duration of malfunction + ) + + observation_builder = GlobalObsForRailEnv() + + env = RailEnv(width=width, + height=height, + rail_generator=rail_generator, + schedule_generator=schedule_generator, + number_of_agents=nr_trains, + obs_builder_object=observation_builder, + #malfunction_generator_and_process_data=malfunction_from_params(stochastic_data), + malfunction_generator=ParamMalfunctionGen(stochastic_data), + remove_agents_at_target=True, + random_seed=seed) + + env.reset() + + return env \ No newline at end of file diff --git a/tests/test_state_machine.py b/tests/test_state_machine.py new file mode 100644 index 00000000..266a8f86 --- /dev/null +++ b/tests/test_state_machine.py @@ -0,0 +1,115 @@ +from test_env_step_utils import get_small_two_agent_env +from flatland.envs.rail_env_action import RailEnvActions +from flatland.envs.step_utils.states import TrainState +from flatland.envs.malfunction_generators import Malfunction + +class NoMalfunctionGenerator: + def generate(self, np_random): + return Malfunction(0) + +class AlwaysThreeStepMalfunction: + def generate(self, np_random): + return Malfunction(3) + +def test_waiting_no_transition(): + env = get_small_two_agent_env() + env.malfunction_generator = NoMalfunctionGenerator() + i_agent = 0 + ed = env.agents[i_agent].earliest_departure + for _ in range(ed-1): + env.step({i_agent: RailEnvActions.MOVE_FORWARD}) + assert env.agents[i_agent].state == TrainState.WAITING + + +def test_waiting_to_ready_to_depart(): + env = get_small_two_agent_env() + env.malfunction_generator = NoMalfunctionGenerator() + i_agent = 0 + ed = env.agents[i_agent].earliest_departure + for _ in range(ed): + env.step({i_agent: RailEnvActions.DO_NOTHING}) + assert env.agents[i_agent].state == TrainState.READY_TO_DEPART + + +def test_ready_to_depart_to_moving(): + env = get_small_two_agent_env() + env.malfunction_generator = NoMalfunctionGenerator() + i_agent = 0 + ed = env.agents[i_agent].earliest_departure + for _ in range(ed): + env.step({i_agent: RailEnvActions.DO_NOTHING}) + + env.step({i_agent: RailEnvActions.MOVE_FORWARD}) + assert env.agents[i_agent].state == TrainState.MOVING + +def test_moving_to_stopped(): + env = get_small_two_agent_env() + env.malfunction_generator = NoMalfunctionGenerator() + i_agent = 0 + ed = env.agents[i_agent].earliest_departure + for _ in range(ed): + env.step({i_agent: RailEnvActions.DO_NOTHING}) + + env.step({i_agent: RailEnvActions.MOVE_FORWARD}) + env.step({i_agent: RailEnvActions.STOP_MOVING}) + assert env.agents[i_agent].state == TrainState.STOPPED + +def test_stopped_to_moving(): + env = get_small_two_agent_env() + env.malfunction_generator = NoMalfunctionGenerator() + i_agent = 0 + ed = env.agents[i_agent].earliest_departure + for _ in range(ed): + env.step({i_agent: RailEnvActions.DO_NOTHING}) + + env.step({i_agent: RailEnvActions.MOVE_FORWARD}) + env.step({i_agent: RailEnvActions.STOP_MOVING}) + env.step({i_agent: RailEnvActions.MOVE_FORWARD}) + assert env.agents[i_agent].state == TrainState.MOVING + +def test_moving_to_done(): + env = get_small_two_agent_env() + env.malfunction_generator = NoMalfunctionGenerator() + i_agent = 1 + ed = env.agents[i_agent].earliest_departure + for _ in range(ed): + env.step({i_agent: RailEnvActions.DO_NOTHING}) + + for _ in range(50): + env.step({i_agent: RailEnvActions.MOVE_FORWARD}) + assert env.agents[i_agent].state == TrainState.DONE + +def test_waiting_to_malfunction(): + env = get_small_two_agent_env() + env.malfunction_generator = AlwaysThreeStepMalfunction() + i_agent = 1 + env.step({i_agent: RailEnvActions.DO_NOTHING}) + assert env.agents[i_agent].state == TrainState.MALFUNCTION_OFF_MAP + + +def test_ready_to_depart_to_malfunction_off_map(): + env = get_small_two_agent_env() + env.malfunction_generator = NoMalfunctionGenerator() + i_agent = 1 + env.step({i_agent: RailEnvActions.DO_NOTHING}) + ed = env.agents[i_agent].earliest_departure + for _ in range(ed): + env.step({i_agent: RailEnvActions.DO_NOTHING}) # This should get into ready to depart + + env.malfunction_generator = AlwaysThreeStepMalfunction() + env.step({i_agent: RailEnvActions.DO_NOTHING}) + assert env.agents[i_agent].state == TrainState.MALFUNCTION_OFF_MAP + + +def test_malfunction_off_map_to_waiting(): + env = get_small_two_agent_env() + env.malfunction_generator = NoMalfunctionGenerator() + i_agent = 1 + env.step({i_agent: RailEnvActions.DO_NOTHING}) + ed = env.agents[i_agent].earliest_departure + for _ in range(ed): + env.step({i_agent: RailEnvActions.DO_NOTHING}) # This should get into ready to depart + + env.malfunction_generator = AlwaysThreeStepMalfunction() + env.step({i_agent: RailEnvActions.DO_NOTHING}) + assert env.agents[i_agent].state == TrainState.MALFUNCTION_OFF_MAP \ No newline at end of file -- GitLab