diff --git a/README.md b/README.md index 14102defb5f7508d9ceb04e3cd0f1a441c4ca172..0f621c1b0c29f3dfce2c0e6b892106dbc9ff6919 100644 --- a/README.md +++ b/README.md @@ -106,7 +106,7 @@ env = RailEnv( nr_extra=1, min_dist=8, max_dist=99999, - seed=0), + seed=1), schedule_generator=complex_schedule_generator(), number_of_agents=NUMBER_OF_AGENTS) diff --git a/docs/specifications/railway.md b/docs/specifications/railway.md index 04867f08f4b3a2948dd09983d552d5f33222cf4f..ec707f871b31a130f7a129e77042d02c31ddd300 100644 --- a/docs/specifications/railway.md +++ b/docs/specifications/railway.md @@ -380,7 +380,7 @@ ScheduleGenerator = Callable[[GridTransitionMap, int, Optional[Any]], ScheduleGe We can then produce `RailGenerator`s by currying: ```python def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2, min_node_dist=20, node_radius=2, - num_neighb=3, grid_mode=False, enhance_intersection=False, seed=0): + num_neighb=3, grid_mode=False, enhance_intersection=False, seed=1): def generator(width, height, num_agents, num_resets=0): diff --git a/docs/tutorials/01_gettingstarted.rst b/docs/tutorials/01_gettingstarted.rst index 2be0fee229c93f2aee215cf1e459dd8fc8f92ebb..e3a2f41aca4714150e96b725d1b01da0f5e20958 100644 --- a/docs/tutorials/01_gettingstarted.rst +++ b/docs/tutorials/01_gettingstarted.rst @@ -145,7 +145,7 @@ Next we configure the difficulty of our task by modifying the complex_rail_gener nr_extra=10, min_dist=10, max_dist=99999, - seed=0), + seed=1), number_of_agents=5) The difficulty of a railway network depends on the dimensions (`width` x `height`) and the number of agents in the network. diff --git a/docs/tutorials/02_observationbuilder.rst b/docs/tutorials/02_observationbuilder.rst index f6e718ab156a3972a57f4693367d5f191c8fcdd0..8afcf71043cdff6239c8944ce243d1f1407fc718 100644 --- a/docs/tutorials/02_observationbuilder.rst +++ b/docs/tutorials/02_observationbuilder.rst @@ -118,7 +118,7 @@ Note that this simple strategy fails when multiple agents are present, as each a env = RailEnv(width=7, height=7, rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, \ - min_dist=8, max_dist=99999, seed=0), + min_dist=8, max_dist=99999, seed=1), number_of_agents=2, obs_builder_object=SingleAgentNavigationObs()) @@ -267,7 +267,7 @@ We can then use this new observation builder and the renderer to visualize the o # Initiate Environment env = RailEnv(width=10, height=10, - rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=1, min_dist=8, max_dist=99999, seed=0), + rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=1, min_dist=8, max_dist=99999, seed=1), number_of_agents=3, obs_builder_object=CustomObsBuilder) diff --git a/docs/tutorials/05_multispeed.md b/docs/tutorials/05_multispeed.md index 99db7ee6a275fe317c891cdebfc8d6aaf99d0b8c..118d4c5957a9b0219cc4ec1f897e2a322d8a248e 100644 --- a/docs/tutorials/05_multispeed.md +++ b/docs/tutorials/05_multispeed.md @@ -64,7 +64,7 @@ ScheduleGenerator = Callable[[GridTransitionMap, int, Optional[Any]], ScheduleGe We can then produce `RailGenerator`s by currying: ```python def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2, min_node_dist=20, node_radius=2, - num_neighb=3, grid_mode=False, enhance_intersection=False, seed=0): + num_neighb=3, grid_mode=False, enhance_intersection=False, seed=1): def generator(width, height, num_agents, num_resets=0): diff --git a/examples/Simple_Realistic_Railway_Generator.py b/examples/Simple_Realistic_Railway_Generator.py index 9124edd1cae6e7d1f2f3b909c721a8f4ffd35ca3..d4f013d56e0cbac5462cef3fab57ad12450ab0ad 100644 --- a/examples/Simple_Realistic_Railway_Generator.py +++ b/examples/Simple_Realistic_Railway_Generator.py @@ -201,7 +201,7 @@ def realistic_rail_generator(max_num_cities=5, nbr_of_switches_per_station_track=2, connect_max_nbr_of_shortes_city=4, do_random_connect_stations=False, - seed=0, + seed=1, print_out_info=True) -> RailGenerator: """ This is a level generator which generates a realistic rail configurations diff --git a/examples/custom_observation_example_02_SingleAgentNavigationObs.py b/examples/custom_observation_example_02_SingleAgentNavigationObs.py index 3c10e415155429aa1464c554cdeb9ee1309780f2..aa9da8494b4a7376a3468c16618e5459344a37f3 100644 --- a/examples/custom_observation_example_02_SingleAgentNavigationObs.py +++ b/examples/custom_observation_example_02_SingleAgentNavigationObs.py @@ -75,7 +75,7 @@ def main(args): env = RailEnv(width=7, height=7, rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999, - seed=0), + seed=1), schedule_generator=complex_schedule_generator(), number_of_agents=1, obs_builder_object=SingleAgentNavigationObs()) diff --git a/examples/custom_observation_example_03_ObservePredictions.py b/examples/custom_observation_example_03_ObservePredictions.py index 97fac09c8d222b11cc8e8cc3ab4f76c75a37a95b..7af7499af8ffc5900a87d7e543310aab2a6df7f9 100644 --- a/examples/custom_observation_example_03_ObservePredictions.py +++ b/examples/custom_observation_example_03_ObservePredictions.py @@ -125,7 +125,7 @@ def main(args): env = RailEnv(width=10, height=10, rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=1, min_dist=8, max_dist=99999, - seed=0), + seed=1), schedule_generator=complex_schedule_generator(), number_of_agents=3, obs_builder_object=custom_obs_builder) diff --git a/examples/debugging_example_DELETE.py b/examples/debugging_example_DELETE.py index 1f094c09b6a929b7e862bc42afffd1da38e52e73..9de330c9bd17f8c07acb931b1ca4eff689c798a7 100644 --- a/examples/debugging_example_DELETE.py +++ b/examples/debugging_example_DELETE.py @@ -57,7 +57,7 @@ class SingleAgentNavigationObs(ObservationBuilder): env = RailEnv(width=14, height=14, - rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999, seed=0), + rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999, seed=1), schedule_generator=complex_schedule_generator(), number_of_agents=2, obs_builder_object=SingleAgentNavigationObs()) diff --git a/examples/simple_example_3.py b/examples/simple_example_3.py index 61b83c3b9ec2d5f63203897aa52c87d27c4460bd..f294279d33b4f5ac7bcdd67202d1b09e65fcab29 100644 --- a/examples/simple_example_3.py +++ b/examples/simple_example_3.py @@ -13,7 +13,7 @@ np.random.seed(1) env = RailEnv(width=7, height=7, - rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999, seed=0), + rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999, seed=1), schedule_generator=complex_schedule_generator(), number_of_agents=2, obs_builder_object=TreeObsForRailEnv(max_depth=2)) diff --git a/examples/training_example.py b/examples/training_example.py index 8b42586b06d90bef44fe94ca6817fb0a84824aac..9d6631719db74ceaa8ceef9352345b2ddbeb99fd 100644 --- a/examples/training_example.py +++ b/examples/training_example.py @@ -16,7 +16,7 @@ TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictor LocalGridObs = LocalObsForRailEnv(view_height=10, view_width=2, center=2) env = RailEnv(width=20, height=20, - rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=2, min_dist=8, max_dist=99999, seed=0), + rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=2, min_dist=8, max_dist=99999, seed=1), schedule_generator=complex_schedule_generator(), obs_builder_object=TreeObservation, number_of_agents=3) diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index aa8e48023694c96941b59d6f78b3fd93edf81e9e..ae3169f298ddd686d834d4a8bd571f4403dbbb79 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -118,7 +118,7 @@ class RailEnv(Environment): max_episode_steps=None, stochastic_data=None, remove_agents_at_target=False, - random_seed=None + random_seed=1 ): """ Environment init. @@ -160,7 +160,6 @@ class RailEnv(Environment): self.rail_generator: RailGenerator = rail_generator self.schedule_generator: ScheduleGenerator = schedule_generator - self.rail_generator = rail_generator self.rail: Optional[GridTransitionMap] = None self.width = width self.height = height @@ -188,7 +187,7 @@ class RailEnv(Environment): self.distance_map = DistanceMap(self.agents, self.height, self.width) self.action_space = [1] - + self._seed() self._seed() @@ -217,7 +216,7 @@ class RailEnv(Environment): # Uniform distribution parameters for malfunction duration self.min_number_of_steps_broken = malfunction_min_duration self.max_number_of_steps_broken = malfunction_max_duration - # Rest environment + # Reset environment self.reset() self.num_resets = 0 # yes, set it to zero again! @@ -354,18 +353,18 @@ class RailEnv(Environment): # If counter has come to zero --> Agent has malfunction # set next malfunction time and duration of current malfunction if agent.malfunction_data['malfunction_rate'] > 0 >= agent.malfunction_data['malfunction'] and \ - agent.malfunction_data['next_malfunction'] <= 0: + agent.malfunction_data['next_malfunction'] <= 0: # Increase number of malfunctions agent.malfunction_data['nr_malfunctions'] += 1 # Next malfunction in number of stops next_breakdown = int( - self.np_random.exponential(scale=agent.malfunction_data['malfunction_rate'])) + self._exp_distirbution_synced(rate=agent.malfunction_data['malfunction_rate'])) agent.malfunction_data['next_malfunction'] = next_breakdown # Duration of current malfunction num_broken_steps = self.np_random.randint(self.min_number_of_steps_broken, - self.max_number_of_steps_broken + 1) + 1 + self.max_number_of_steps_broken + 1) + 1 agent.malfunction_data['malfunction'] = num_broken_steps agent.malfunction_data['moving_before_malfunction'] = agent.moving @@ -500,7 +499,7 @@ class RailEnv(Environment): self.rewards_dict[i_agent] += self.stop_penalty if not agent.moving and not ( - action == RailEnvActions.DO_NOTHING or action == RailEnvActions.STOP_MOVING): + action == RailEnvActions.DO_NOTHING or action == RailEnvActions.STOP_MOVING): # Allow agent to start with any forward or direction action agent.moving = True self.rewards_dict[i_agent] += self.start_penalty @@ -726,7 +725,7 @@ class RailEnv(Environment): return msgpack.packb(msg_data, use_bin_type=True) - def save(self, filename,save_distance_maps=False): + def save(self, filename, save_distance_maps=False): if save_distance_maps == True: if self.distance_map.get() is not None: if len(self.distance_map.get()) > 0: @@ -754,3 +753,14 @@ class RailEnv(Environment): from importlib_resources import read_binary load_data = read_binary(package, resource) self.set_full_state_msg(load_data) + + def _exp_distirbution_synced(self, rate): + """ + Generates sample from exponential distribution + We need this to guarantee synchronity between different instances with same seed. + :param rate: + :return: + """ + u = self.np_random.rand() + x = - np.log(1 - u) * rate + return x diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index 242e11d73797b2dfdcea0efe7a079dcc066ff05e..ced33635a68948f5e91cf2147972caafc7d76334 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -40,7 +40,7 @@ def complex_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dist=99999, - seed=0) -> RailGenerator: + seed=1) -> RailGenerator: """ complex_rail_generator @@ -272,7 +272,7 @@ def rail_from_grid_transition_map(rail_map) -> RailGenerator: return generator -def random_rail_generator(cell_type_relative_proportion=[1.0] * 11, seed=0) -> RailGenerator: +def random_rail_generator(cell_type_relative_proportion=[1.0] * 11, seed=1) -> RailGenerator: """ Dummy random level generator: - fill in cells at random in [width-2, height-2] diff --git a/tests/test_flatland_envs_rail_env.py b/tests/test_flatland_envs_rail_env.py index e0281bb0d3f21a8e25fdff86ade5bed05f5dea13..7c8a65fd6501689a0911872a00150a87a082805d 100644 --- a/tests/test_flatland_envs_rail_env.py +++ b/tests/test_flatland_envs_rail_env.py @@ -28,7 +28,7 @@ def test_load_env(): def test_save_load(): env = RailEnv(width=10, height=10, - rail_generator=complex_rail_generator(nr_start_goal=2, nr_extra=5, min_dist=6, seed=0), + rail_generator=complex_rail_generator(nr_start_goal=2, nr_extra=5, min_dist=6, seed=1), schedule_generator=complex_schedule_generator(), number_of_agents=2) env.reset() diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index 73c831a426d55abf45217d66cba57481f6cc63ea..fe41c2b1461fcb86554ee7b1cefbd34ce408e8dd 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -170,7 +170,7 @@ def test_malfunction_before_entry(): """Tests that malfunctions are produced by stochastic_data!""" # Set fixed malfunction duration for this test stochastic_data = {'prop_malfunction': 1., - 'malfunction_rate': 2, + 'malfunction_rate': 1, 'min_duration': 10, 'max_duration': 10} @@ -187,15 +187,17 @@ def test_malfunction_before_entry(): # reset to initialize agents_static env.reset(False, False, False, random_seed=10) env.agents[0].target = (0, 0) + assert env.agents[1].malfunction_data['malfunction'] == 11 assert env.agents[2].malfunction_data['malfunction'] == 11 assert env.agents[3].malfunction_data['malfunction'] == 11 assert env.agents[4].malfunction_data['malfunction'] == 11 - assert env.agents[5].malfunction_data['malfunction'] == 0 + assert env.agents[5].malfunction_data['malfunction'] == 11 assert env.agents[6].malfunction_data['malfunction'] == 11 assert env.agents[7].malfunction_data['malfunction'] == 11 assert env.agents[8].malfunction_data['malfunction'] == 11 - assert env.agents[9].malfunction_data['malfunction'] == 0 + assert env.agents[9].malfunction_data['malfunction'] == 11 + for step in range(20): action_dict: Dict[int, RailEnvActions] = {} @@ -206,16 +208,16 @@ def test_malfunction_before_entry(): action_dict[agent.handle] = RailEnvActions(0) env.step(action_dict) + assert env.agents[1].malfunction_data['malfunction'] == 1 assert env.agents[2].malfunction_data['malfunction'] == 1 assert env.agents[3].malfunction_data['malfunction'] == 1 assert env.agents[4].malfunction_data['malfunction'] == 1 - assert env.agents[5].malfunction_data['malfunction'] == 2 + assert env.agents[5].malfunction_data['malfunction'] == 1 assert env.agents[6].malfunction_data['malfunction'] == 1 assert env.agents[7].malfunction_data['malfunction'] == 1 assert env.agents[8].malfunction_data['malfunction'] == 1 - assert env.agents[9].malfunction_data['malfunction'] == 3 - + assert env.agents[9].malfunction_data['malfunction'] == 1 # Print for test generation # for a in range(env.get_num_agents()): # print("assert env.agents[{}].malfunction_data['malfunction'] == {}".format(a, @@ -226,7 +228,7 @@ def test_malfunction_before_entry(): def test_initial_malfunction(): stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents - 'malfunction_rate': 70, # Rate of malfunction occurence + 'malfunction_rate': 100, # Rate of malfunction occurence 'min_duration': 2, # Minimal duration of malfunction 'max_duration': 5 # Max duration of malfunction } @@ -236,7 +238,7 @@ def test_initial_malfunction(): env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), + schedule_generator=random_schedule_generator(seed=10), number_of_agents=1, stochastic_data=stochastic_data, # Malfunction data generator obs_builder_object=SingleAgentNavigationObs() @@ -244,6 +246,7 @@ def test_initial_malfunction(): # reset to initialize agents_static env.reset(False, False, True, random_seed=10) + print(env.agents[0].malfunction_data) env.agents[0].target = (0, 5) set_penalties_for_replay(env) replay_config = ReplayConfig( diff --git a/tests/test_multi_speed.py b/tests/test_multi_speed.py index 3cd0a4c1d9812c728089255eb1150111b783a463..d190a5769415c65948b2d0e12c7f9b86b044a494 100644 --- a/tests/test_multi_speed.py +++ b/tests/test_multi_speed.py @@ -1,4 +1,5 @@ import numpy as np +from test_utils import ReplayConfig, Replay, run_replay_config, set_penalties_for_replay from flatland.core.grid.grid4 import Grid4TransitionsEnum from flatland.envs.observations import TreeObsForRailEnv @@ -7,7 +8,6 @@ from flatland.envs.rail_env import RailEnv, RailEnvActions from flatland.envs.rail_generators import complex_rail_generator, rail_from_grid_transition_map from flatland.envs.schedule_generators import complex_schedule_generator, random_schedule_generator from flatland.utils.simple_rail import make_simple_rail -from test_utils import ReplayConfig, Replay, run_replay_config, set_penalties_for_replay np.random.seed(1) @@ -52,7 +52,7 @@ def test_multi_speed_init(): env = RailEnv(width=50, height=50, rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, - seed=0), + seed=1), schedule_generator=complex_schedule_generator(), number_of_agents=5) # Initialize the agent with the parameters corresponding to the environment and observation_builder diff --git a/tests/test_random_seeding.py b/tests/test_random_seeding.py index 3a2fe8bc1775653e15842ac27c262aab3e2b9ded..3a03de00c27795ea4b4a3915a23fec3a48732bbd 100644 --- a/tests/test_random_seeding.py +++ b/tests/test_random_seeding.py @@ -1,3 +1,7 @@ +import numpy as np + +from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv +from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import rail_from_grid_transition_map from flatland.envs.schedule_generators import random_schedule_generator @@ -32,4 +36,136 @@ def test_random_seeding(): # Test generation print assert env.agents[0].position == (3, 6) # print("env.agents[0].initial_position == {}".format(env.agents[0].initial_position)) - #print("assert env.agents[0].position == {}".format(env.agents[0].position)) + # print("assert env.agents[0].position == {}".format(env.agents[0].position)) + + +def test_seeding_and_observations(): + # Test if two different instances diverge with different observations + rail, rail_map = make_simple_rail2() + + # Make two seperate envs with different observation builders + # Global Observation + env = RailEnv(width=25, + height=30, + rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(seed=12), + number_of_agents=10, + obs_builder_object=GlobalObsForRailEnv() + ) + # Tree Observation + env2 = RailEnv(width=25, + height=30, + rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(seed=12), + number_of_agents=10, + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()) + ) + + env.reset(False, False, False, random_seed=12) + env2.reset(False, False, False, random_seed=12) + + # Check that both environments produce the same initial start positions + assert env.agents[0].initial_position == env2.agents[0].initial_position + assert env.agents[1].initial_position == env2.agents[1].initial_position + assert env.agents[2].initial_position == env2.agents[2].initial_position + assert env.agents[3].initial_position == env2.agents[3].initial_position + assert env.agents[4].initial_position == env2.agents[4].initial_position + assert env.agents[5].initial_position == env2.agents[5].initial_position + assert env.agents[6].initial_position == env2.agents[6].initial_position + assert env.agents[7].initial_position == env2.agents[7].initial_position + assert env.agents[8].initial_position == env2.agents[8].initial_position + assert env.agents[9].initial_position == env2.agents[9].initial_position + + action_dict = {} + for step in range(10): + for a in range(env.get_num_agents()): + action = np.random.randint(4) + action_dict[a] = action + env.step(action_dict) + env2.step(action_dict) + + # Check that both environments end up in the same position + + assert env.agents[0].position == env2.agents[0].position + assert env.agents[1].position == env2.agents[1].position + assert env.agents[2].position == env2.agents[2].position + assert env.agents[3].position == env2.agents[3].position + assert env.agents[4].position == env2.agents[4].position + assert env.agents[5].position == env2.agents[5].position + assert env.agents[6].position == env2.agents[6].position + assert env.agents[7].position == env2.agents[7].position + assert env.agents[8].position == env2.agents[8].position + assert env.agents[9].position == env2.agents[9].position + for a in range(env.get_num_agents()): + print("assert env.agents[{}].position == env2.agents[{}].position".format(a, a)) + + +def test_seeding_and_malfunction(): + # Test if two different instances diverge with different observations + rail, rail_map = make_simple_rail2() + + stochastic_data = {'prop_malfunction': 1., + 'malfunction_rate': 2, + 'min_duration': 10, + 'max_duration': 10} + # Make two seperate envs with different and see if the exhibit the same malfunctions + # Global Observation + for tests in range(1, 100): + env = RailEnv(width=25, + height=30, + rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), + number_of_agents=10, + obs_builder_object=GlobalObsForRailEnv(), + stochastic_data=stochastic_data, # Malfunction data generator + ) + + # Tree Observation + env2 = RailEnv(width=25, + height=30, + rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), + number_of_agents=10, + obs_builder_object=GlobalObsForRailEnv(), + stochastic_data=stochastic_data, # Malfunction data generator + ) + + env.reset(True, False, True, random_seed=tests) + env2.reset(True, False, True, random_seed=tests) + + # Check that both environments produce the same initial start positions + assert env.agents[0].initial_position == env2.agents[0].initial_position + assert env.agents[1].initial_position == env2.agents[1].initial_position + assert env.agents[2].initial_position == env2.agents[2].initial_position + assert env.agents[3].initial_position == env2.agents[3].initial_position + assert env.agents[4].initial_position == env2.agents[4].initial_position + assert env.agents[5].initial_position == env2.agents[5].initial_position + assert env.agents[6].initial_position == env2.agents[6].initial_position + assert env.agents[7].initial_position == env2.agents[7].initial_position + assert env.agents[8].initial_position == env2.agents[8].initial_position + assert env.agents[9].initial_position == env2.agents[9].initial_position + + action_dict = {} + for step in range(10): + for a in range(env.get_num_agents()): + action = np.random.randint(4) + action_dict[a] = action + # print("----------------------") + # print(env.agents[a].malfunction_data, env.agents[a].status) + # print(env2.agents[a].malfunction_data, env2.agents[a].status) + + env.step(action_dict) + env2.step(action_dict) + + # Check that both environments end up in the same position + + assert env.agents[0].position == env2.agents[0].position + assert env.agents[1].position == env2.agents[1].position + assert env.agents[2].position == env2.agents[2].position + assert env.agents[3].position == env2.agents[3].position + assert env.agents[4].position == env2.agents[4].position + assert env.agents[5].position == env2.agents[5].position + assert env.agents[6].position == env2.agents[6].position + assert env.agents[7].position == env2.agents[7].position + assert env.agents[8].position == env2.agents[8].position + assert env.agents[9].position == env2.agents[9].position diff --git a/tests/test_speed_classes.py b/tests/test_speed_classes.py index ff5ee56a308ce19559d079b716bde90ad65baf11..1fcf3b3ef0b7cc176d345fa547e91eeeef0a05bd 100644 --- a/tests/test_speed_classes.py +++ b/tests/test_speed_classes.py @@ -21,7 +21,7 @@ def test_rail_env_speed_intializer(): env = RailEnv(width=50, height=50, rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, - seed=0), + seed=1), schedule_generator=complex_schedule_generator(), number_of_agents=10) env.reset()