Skip to content
Snippets Groups Projects
Commit 5bfcb17d authored by u229589's avatar u229589 Committed by spmohanty
Browse files

add test for reset() (False, False), (False, True), (True, False) (issue #250)

parent 4d2ea844
No related branches found
No related tags found
No related merge requests found
......@@ -9,9 +9,9 @@ from flatland.envs.agent_utils import EnvAgentStatic
from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import complex_rail_generator
from flatland.envs.rail_generators import complex_rail_generator, rail_from_file
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.schedule_generators import random_schedule_generator, complex_schedule_generator
from flatland.envs.schedule_generators import random_schedule_generator, complex_schedule_generator, schedule_from_file
from flatland.utils.simple_rail import make_simple_rail
"""Tests for `flatland` package."""
......@@ -228,3 +228,78 @@ def test_get_entry_directions():
# nowhere
_assert((0, 0), [False, False, False, False])
def test_rail_env_reset():
file_name = "test_rail_env_reset.pkl"
# Test to save and load file with distance map.
rail, rail_map = make_simple_rail()
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(),
number_of_agents=3,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
)
env.reset()
env.save(file_name)
dist_map_shape = np.shape(env.distance_map.get())
# initialize agents_static
rails_initial = env.rail.grid
agents_initial = env.agents
env2 = RailEnv(width=1,
height=1,
rail_generator=rail_from_file(file_name),
schedule_generator=schedule_from_file(file_name),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
)
env2.reset(False, False, False)
rails_loaded = env2.rail.grid
agents_loaded = env2.agents
assert np.all(np.array_equal(rails_initial, rails_loaded))
assert agents_initial == agents_loaded
# Check that distance map was not recomputed
assert np.shape(env2.distance_map.get()) == dist_map_shape
assert env2.distance_map.get() is not None
env3 = RailEnv(width=1,
height=1,
rail_generator=rail_from_file(file_name),
schedule_generator=schedule_from_file(file_name),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
)
env3.reset(False, True, False)
rails_loaded = env3.rail.grid
agents_loaded = env3.agents
assert np.all(np.array_equal(rails_initial, rails_loaded))
assert agents_initial == agents_loaded
# Check that distance map was not recomputed
assert np.shape(env3.distance_map.get()) == dist_map_shape
assert env3.distance_map.get() is not None
env4 = RailEnv(width=1,
height=1,
rail_generator=rail_from_file(file_name),
schedule_generator=schedule_from_file(file_name),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
)
env4.reset(True, False, False)
rails_loaded = env4.rail.grid
agents_loaded = env4.agents
assert np.all(np.array_equal(rails_initial, rails_loaded))
assert agents_initial == agents_loaded
# Check that distance map was not recomputed
assert np.shape(env4.distance_map.get()) == dist_map_shape
assert env4.distance_map.get() is not None
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment