Commit d7bc190d authored by u229589
Browse files

Remove the call to reset() in the RailEnv constructor and fix the unit tests (callers must now call env.reset() explicitly after constructing the environment)

parent f3f49ca6
......@@ -217,9 +217,6 @@ class RailEnv(Environment):
self.max_number_of_steps_broken = malfunction_max_duration
# Reset environment
self.reset()
self.num_resets = 0 # yes, set it to zero again!
self.valid_positions = None
def _seed(self, seed=None):
......
......@@ -23,6 +23,7 @@ def test_initial_status():
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
)
env.reset()
set_penalties_for_replay(env)
test_config = ReplayConfig(
replay=[
......@@ -133,6 +134,7 @@ def test_status_done_remove():
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
remove_agents_at_target=True
)
env.reset()
set_penalties_for_replay(env)
test_config = ReplayConfig(
......
......@@ -19,6 +19,7 @@ from flatland.utils.simple_rail import make_simple_rail
def test_load_env():
env = RailEnv(10, 10)
env.reset()
env.load_resource('env_data.tests', 'test-10x10.mpk')
agent_static = EnvAgentStatic((0, 0), 2, (5, 5), False)
......@@ -83,6 +84,7 @@ def test_rail_environment_single_agent():
schedule_generator=random_schedule_generator(),
number_of_agents=1,
obs_builder_object=GlobalObsForRailEnv())
rail_env.reset()
for _ in range(200):
_ = rail_env.reset(False, False, True)
......@@ -204,6 +206,7 @@ def test_get_entry_directions():
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
)
env.reset()
def _assert(position, expected):
actual = env.get_valid_directions_on_grid(*position)
......
......@@ -21,6 +21,7 @@ def test_get_shortest_paths_unreachable():
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv(max_depth=10)),
)
env.reset()
# set the initial position
agent = env.agents_static[0]
......@@ -41,6 +42,7 @@ def test_get_shortest_paths_unreachable():
def test_get_shortest_paths():
env = load_flatland_environment_from_file('test_002.pkl', 'env_data.tests')
env.reset()
actual = get_shortest_paths(env.distance_map)
expected = {
......@@ -169,6 +171,7 @@ def test_get_shortest_paths():
def test_get_shortest_paths_max_depth():
env = load_flatland_environment_from_file('test_002.pkl', 'env_data.tests')
env.reset()
actual = get_shortest_paths(env.distance_map, max_depth=2)
expected = {
......
......@@ -24,6 +24,7 @@ def test_sparse_rail_generator():
number_of_agents=10,
obs_builder_object=GlobalObsForRailEnv()
)
env.reset()
env.reset(False, False, True)
# for r in range(env.height):
# for c in range (env.width):
......@@ -535,7 +536,8 @@ def test_sparse_rail_generator_deterministic():
number_of_agents=1,
stochastic_data=stochastic_data, # Malfunction data generator
)
# for r in range(env.height):
env.reset()
# for r in range(env.height):
# for c in range(env.width):
# print("assert env.rail.get_full_transitions({}, {}) == {}, \"[{}][{}]\"".format(r, c,
# env.rail.get_full_transitions(
......@@ -1311,6 +1313,7 @@ def test_rail_env_action_required_info():
schedule_generator=sparse_schedule_generator(speed_ration_map),
number_of_agents=10,
obs_builder_object=GlobalObsForRailEnv())
env_always_action.reset()
np.random.seed(0)
random.seed(0)
env_only_if_action_required = RailEnv(width=50,
......@@ -1326,6 +1329,7 @@ def test_rail_env_action_required_info():
number_of_agents=10,
obs_builder_object=GlobalObsForRailEnv())
env_renderer = RenderTool(env_always_action, gl="PILSVG", )
env_only_if_action_required.reset()
env_always_action.reset(False, False, True)
env_only_if_action_required.reset(False, False, True)
......@@ -1395,6 +1399,7 @@ def test_rail_env_malfunction_speed_info():
number_of_agents=10,
obs_builder_object=GlobalObsForRailEnv(),
stochastic_data=stochastic_data)
env.reset()
env.reset(False, False, True)
env_renderer = RenderTool(env, gl="PILSVG", )
......
......@@ -81,6 +81,7 @@ def test_malfunction_process():
stochastic_data=stochastic_data, # Malfunction data generator
obs_builder_object=SingleAgentNavigationObs()
)
env.reset()
# reset to initialize agents_static
obs, info = env.reset(False, False, True, random_seed=10)
......@@ -150,21 +151,21 @@ def test_malfunction_process_statistically():
stochastic_data=stochastic_data, # Malfunction data generator
obs_builder_object=SingleAgentNavigationObs()
)
env.reset()
# reset to initialize agents_static
env.reset(True, True, False, random_seed=10)
env.agents[0].target = (0, 0)
nb_malfunction = 0
agent_malfunction_list = [[0, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0],
agent_malfunction_list = [[0, 0, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 0, 6, 5, 4, 3, 2],
[0, 0, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1],
[0, 0, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 6, 5],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 0, 6, 5, 4],
[0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0, 0],
[6, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[6, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 6, 5, 4],
[6, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0],
[6, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0, 0, 6, 5, 4, 3],
[0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5],
[0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 0, 6, 5, 4, 3, 2, 1, 0],
[0, 0, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1],
[6, 6, 6, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0]]
[0, 0, 0, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0],
[6, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
for step in range(20):
action_dict: Dict[int, RailEnvActions] = {}
......@@ -175,7 +176,6 @@ def test_malfunction_process_statistically():
# agent_malfunction_list[agent_idx].append(env.agents[agent_idx].malfunction_data['malfunction'])
assert env.agents[agent_idx].malfunction_data['malfunction'] == agent_malfunction_list[agent_idx][step]
env.step(action_dict)
# print(agent_malfunction_list)
def test_malfunction_before_entry():
......@@ -196,6 +196,7 @@ def test_malfunction_before_entry():
random_seed=1,
stochastic_data=stochastic_data, # Malfunction data generator
)
env.reset()
# reset to initialize agents_static
env.reset(False, False, False, random_seed=10)
env.agents[0].target = (0, 0)
......@@ -254,6 +255,7 @@ def test_initial_malfunction():
stochastic_data=stochastic_data, # Malfunction data generator
obs_builder_object=SingleAgentNavigationObs()
)
env.reset()
# reset to initialize agents_static
env.reset(False, False, True, random_seed=10)
......@@ -327,7 +329,7 @@ def test_initial_malfunction_stop_moving():
stochastic_data=stochastic_data, # Malfunction data generator
obs_builder_object=SingleAgentNavigationObs()
)
# reset to initialize agents_static
env.reset()
print(env.agents[0].initial_position, env.agents[0].direction, env.agents[0].position, env.agents[0].status)
......@@ -532,6 +534,7 @@ def tests_random_interference_from_outside():
random_seed=1,
stochastic_data=stochastic_data, # Malfunction data generator
)
env.reset()
# reset to initialize agents_static
env.agents[0].speed_data['speed'] = 0.33
env.agents[0].initial_position = (3, 0)
......@@ -564,6 +567,7 @@ def tests_random_interference_from_outside():
random_seed=1,
stochastic_data=stochastic_data, # Malfunction data generator
)
env.reset()
# reset to initialize agents_static
env.agents[0].speed_data['speed'] = 0.33
env.agents[0].initial_position = (3, 0)
......
......@@ -42,6 +42,7 @@ def test_render_env(save_new_images=False):
number_of_agents=0,
obs_builder_object=TreeObsForRailEnv(max_depth=2)
)
oEnv.reset()
oEnv.rail.load_transition_map('env_data.tests', "test1.npy")
oRT = rt.RenderTool(oEnv, gl="PILSVG")
oRT.render_env(show=False)
......@@ -50,7 +51,7 @@ def test_render_env(save_new_images=False):
oRT = rt.RenderTool(oEnv, gl="PIL")
oRT.render_env()
checkFrozenImage(oRT, "basic-env-PIL.npz", resave=save_new_images)
def main():
if len(sys.argv) == 2 and sys.argv[1] == "save":
test_render_env(save_new_images=True)
......
......@@ -32,6 +32,7 @@ def test_get_global_observation():
schedule_generator=sparse_schedule_generator(speed_ration_map),
number_of_agents=number_of_agents, stochastic_data=stochastic_data, # Malfunction data generator
obs_builder_object=GlobalObsForRailEnv())
env.reset()
obs, all_rewards, done, _ = env.step({i: RailEnvActions.MOVE_FORWARD for i in range(number_of_agents)})
for i in range(len(env.agents)):
......
......@@ -55,6 +55,7 @@ def test_multi_speed_init():
seed=1),
schedule_generator=complex_schedule_generator(),
number_of_agents=5)
env.reset()
# Initialize the agent with the parameters corresponding to the environment and observation_builder
agent = RandomAgent(218, 4)
......@@ -104,6 +105,7 @@ def test_multispeed_actions_no_malfunction_no_blocking():
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
)
env.reset()
set_penalties_for_replay(env)
test_config = ReplayConfig(
......@@ -207,6 +209,7 @@ def test_multispeed_actions_no_malfunction_blocking():
number_of_agents=2,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
)
env.reset()
set_penalties_for_replay(env)
test_configs = [
ReplayConfig(
......@@ -394,6 +397,7 @@ def test_multispeed_actions_malfunction_no_blocking():
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
)
env.reset()
set_penalties_for_replay(env)
test_config = ReplayConfig(
......@@ -531,6 +535,7 @@ def test_multispeed_actions_no_malfunction_invalid_actions():
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
)
env.reset()
set_penalties_for_replay(env)
test_config = ReplayConfig(
......
......@@ -21,6 +21,7 @@ def test_random_seeding():
schedule_generator=random_schedule_generator(seed=12),
number_of_agents=10
)
env.reset()
env.reset(True, True, False, random_seed=1)
env.agents[0].target = (0, 0)
......@@ -60,6 +61,7 @@ def test_seeding_and_observations():
number_of_agents=10,
obs_builder_object=GlobalObsForRailEnv()
)
env.reset()
# Tree Observation
env2 = RailEnv(width=25,
height=30,
......@@ -68,6 +70,7 @@ def test_seeding_and_observations():
number_of_agents=10,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
)
env2.reset()
env.reset(False, False, False, random_seed=12)
env2.reset(False, False, False, random_seed=12)
......@@ -127,6 +130,7 @@ def test_seeding_and_malfunction():
obs_builder_object=GlobalObsForRailEnv(),
stochastic_data=stochastic_data, # Malfunction data generator
)
env.reset()
# Tree Observation
env2 = RailEnv(width=25,
......@@ -137,6 +141,7 @@ def test_seeding_and_malfunction():
obs_builder_object=GlobalObsForRailEnv(),
stochastic_data=stochastic_data, # Malfunction data generator
)
env2.reset()
env.reset(True, False, True, random_seed=tests)
env2.reset(True, False, True, random_seed=tests)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment