Commit 215595bb authored by Dipam Chakraborty's avatar Dipam Chakraborty
Browse files

Merge branch 'test-fixes' into env-step-facelift

parents 6df9e4d3 a186d58b
......@@ -84,11 +84,6 @@ class SparseLineGen(BaseLineGen):
train_stations = hints['train_stations']
city_positions = hints['city_positions']
city_orientation = hints['city_orientations']
max_num_agents = hints['num_agents']
city_orientations = hints['city_orientations']
if num_agents > max_num_agents:
num_agents = max_num_agents
warnings.warn("Too many agents! Changes number of agents.")
# Place agents and targets within available train stations
agents_position = []
agents_target = []
......
......@@ -48,8 +48,7 @@ def make_simple_rail() -> Tuple[GridTransitionMap, np.array]:
[( (6, 6), 0 ) ],
]
city_orientations = [0, 2]
agents_hints = {'num_agents': 2,
'city_positions': city_positions,
agents_hints = {'city_positions': city_positions,
'train_stations': train_stations,
'city_orientations': city_orientations
}
......@@ -100,8 +99,7 @@ def make_disconnected_simple_rail() -> Tuple[GridTransitionMap, np.array]:
[( (6, 6), 0 ) ],
]
city_orientations = [0, 2]
agents_hints = {'num_agents': 2,
'city_positions': city_positions,
agents_hints = {'city_positions': city_positions,
'train_stations': train_stations,
'city_orientations': city_orientations
}
......@@ -149,8 +147,7 @@ def make_simple_rail2() -> Tuple[GridTransitionMap, np.array]:
[( (6, 6), 0 ) ],
]
city_orientations = [0, 2]
agents_hints = {'num_agents': 2,
'city_positions': city_positions,
agents_hints = {'city_positions': city_positions,
'train_stations': train_stations,
'city_orientations': city_orientations
}
......@@ -199,8 +196,7 @@ def make_simple_rail_unconnected() -> Tuple[GridTransitionMap, np.array]:
[( (6, 6), 0 ) ],
]
city_orientations = [0, 2]
agents_hints = {'num_agents': 2,
'city_positions': city_positions,
agents_hints = {'city_positions': city_positions,
'train_stations': train_stations,
'city_orientations': city_orientations
}
......@@ -255,8 +251,7 @@ def make_simple_rail_with_alternatives() -> Tuple[GridTransitionMap, np.array]:
[( (6, 6), 0 ) ],
]
city_orientations = [0, 2]
agents_hints = {'num_agents': 2,
'city_positions': city_positions,
agents_hints = {'city_positions': city_positions,
'train_stations': train_stations,
'city_orientations': city_orientations
}
......@@ -306,8 +301,43 @@ def make_invalid_simple_rail() -> Tuple[GridTransitionMap, np.array]:
[( (6, 6), 0 ) ],
]
city_orientations = [0, 2]
agents_hints = {'num_agents': 2,
'city_positions': city_positions,
agents_hints = {'city_positions': city_positions,
'train_stations': train_stations,
'city_orientations': city_orientations
}
optionals = {'agents_hints': agents_hints}
return rail, rail_map, optionals
def make_oval_rail() -> Tuple[GridTransitionMap, np.array]:
transitions = RailEnvTransitions()
cells = transitions.transition_list
empty = cells[0]
vertical_straight = cells[1]
horizontal_straight = transitions.rotate_transition(vertical_straight, 90)
right_turn_from_south = cells[8]
right_turn_from_west = transitions.rotate_transition(right_turn_from_south, 90)
right_turn_from_north = transitions.rotate_transition(right_turn_from_south, 180)
right_turn_from_east = transitions.rotate_transition(right_turn_from_south, 270)
rail_map = np.array(
[[empty] * 9] +
[[empty] + [right_turn_from_south] + [horizontal_straight] * 5 + [right_turn_from_west] + [empty]] +
[[empty] + [vertical_straight] + [empty] * 5 + [vertical_straight] + [empty]]+
[[empty] + [vertical_straight] + [empty] * 5 + [vertical_straight] + [empty]] +
[[empty] + [right_turn_from_east] + [horizontal_straight] * 5 + [right_turn_from_north] + [empty]] +
[[empty] * 9], dtype=np.uint16)
rail = GridTransitionMap(width=rail_map.shape[1],
height=rail_map.shape[0], transitions=transitions)
rail.grid = rail_map
city_positions = [(1, 4), (4, 4)]
train_stations = [
[((1, 4), 0)],
[((4, 4), 0)],
]
city_orientations = [1, 3]
agents_hints = {'city_positions': city_positions,
'train_stations': train_stations,
'city_orientations': city_orientations
}
......
import pytest
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.line_generators import sparse_line_generator
from flatland.utils.simple_rail import make_oval_rail
def test_shortest_paths():
rail, rail_map, optionals = make_oval_rail()
speed_ratio_map = {1.: 1.0}
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(speed_ratio_map),
number_of_agents=2)
env.reset()
agent0_shortest_path = env.agents[0].get_shortest_path(env.distance_map)
agent1_shortest_path = env.agents[1].get_shortest_path(env.distance_map)
assert len(agent0_shortest_path) == 10
assert len(agent1_shortest_path) == 10
def test_travel_time_on_shortest_paths():
rail, rail_map, optionals = make_oval_rail()
speed_ratio_map = {1.: 1.0}
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(speed_ratio_map),
number_of_agents=2)
env.reset()
agent0_travel_time = env.agents[0].get_travel_time_on_shortest_path(env.distance_map)
agent1_travel_time = env.agents[1].get_travel_time_on_shortest_path(env.distance_map)
assert agent0_travel_time == 10
assert agent1_travel_time == 10
speed_ratio_map = {1/2: 1.0}
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(speed_ratio_map),
number_of_agents=2)
env.reset()
agent0_travel_time = env.agents[0].get_travel_time_on_shortest_path(env.distance_map)
agent1_travel_time = env.agents[1].get_travel_time_on_shortest_path(env.distance_map)
assert agent0_travel_time == 20
assert agent1_travel_time == 20
speed_ratio_map = {1/3: 1.0}
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(speed_ratio_map),
number_of_agents=2)
env.reset()
agent0_travel_time = env.agents[0].get_travel_time_on_shortest_path(env.distance_map)
agent1_travel_time = env.agents[1].get_travel_time_on_shortest_path(env.distance_map)
assert agent0_travel_time == 30
assert agent1_travel_time == 30
speed_ratio_map = {1/4: 1.0}
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(speed_ratio_map),
number_of_agents=2)
env.reset()
agent0_travel_time = env.agents[0].get_travel_time_on_shortest_path(env.distance_map)
agent1_travel_time = env.agents[1].get_travel_time_on_shortest_path(env.distance_map)
assert agent0_travel_time == 40
assert agent1_travel_time == 40
# def test_latest_arrival_validity():
# pass
# def test_time_remaining_until_latest_arrival():
# pass
def main():
pass
if __name__ == "__main__":
main()
import numpy as np
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.line_generators import sparse_line_generator
from flatland.utils.simple_rail import make_simple_rail
from flatland.envs.persistence import RailEnvPersister
def test_load_new():
filename = "test_load_new.pkl"
rail, rail_map, optionals = make_simple_rail()
n_agents = 2
env_initial = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(), number_of_agents=n_agents)
env_initial.reset(False, False)
rails_initial = env_initial.rail.grid
agents_initial = env_initial.agents
RailEnvPersister.save(env_initial, filename)
env_loaded, _ = RailEnvPersister.load_new(filename)
rails_loaded = env_loaded.rail.grid
agents_loaded = env_loaded.agents
assert np.all(np.array_equal(rails_initial, rails_loaded))
assert agents_initial == agents_loaded
def main():
pass
if __name__ == "__main__":
main()
......@@ -373,9 +373,13 @@ def test_rail_env_reset():
env3 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name),
line_generator=line_from_file(file_name), number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
env3.reset(False, True, False)
env3.reset(False, True)
rails_loaded = env3.rail.grid
agents_loaded = env3.agents
# override `earliest_departure` & `latest_arrival` since they aren't expected to be the same
for agent_initial, agent_loaded in zip(agents_initial, agents_loaded):
agent_loaded.earliest_departure = agent_initial.earliest_departure
agent_loaded.latest_arrival = agent_initial.latest_arrival
assert np.all(np.array_equal(rails_initial, rails_loaded))
assert agents_initial == agents_loaded
......@@ -383,16 +387,21 @@ def test_rail_env_reset():
env4 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name),
line_generator=line_from_file(file_name), number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
env4.reset(True, False, False)
env4.reset(True, False)
rails_loaded = env4.rail.grid
agents_loaded = env4.agents
# override `earliest_departure` & `latest_arrival` since they aren't expected to be the same
for agent_initial, agent_loaded in zip(agents_initial, agents_loaded):
agent_loaded.earliest_departure = agent_initial.earliest_departure
agent_loaded.latest_arrival = agent_initial.latest_arrival
assert np.all(np.array_equal(rails_initial, rails_loaded))
assert agents_initial == agents_loaded
def main():
test_rail_environment_single_agent(show=True)
# test_rail_environment_single_agent(show=True)
test_rail_env_reset()
if __name__=="__main__":
main()
......@@ -72,6 +72,10 @@ def tests_rail_from_file():
env.reset()
rails_loaded = env.rail.grid
agents_loaded = env.agents
# override `earliest_departure` & `latest_arrival` since they aren't expected to be the same
for agent_initial, agent_loaded in zip(agents_initial, agents_loaded):
agent_loaded.earliest_departure = agent_initial.earliest_departure
agent_loaded.latest_arrival = agent_initial.latest_arrival
assert np.all(np.array_equal(rails_initial, rails_loaded))
assert agents_initial == agents_loaded
......@@ -85,7 +89,7 @@ def tests_rail_from_file():
file_name_2 = "test_without_distance_map.pkl"
env2 = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail), line_generator=sparse_line_generator(),
rail_generator=rail_from_grid_transition_map(rail, optionals), line_generator=sparse_line_generator(),
number_of_agents=3, obs_builder_object=GlobalObsForRailEnv())
env2.reset()
#env2.save(file_name_2)
......@@ -100,6 +104,10 @@ def tests_rail_from_file():
env2.reset()
rails_loaded_2 = env2.rail.grid
agents_loaded_2 = env2.agents
# override `earliest_departure` & `latest_arrival` since they aren't expected to be the same
for agent_initial, agent_loaded in zip(agents_initial_2, agents_loaded_2):
agent_loaded.earliest_departure = agent_initial.earliest_departure
agent_loaded.latest_arrival = agent_initial.latest_arrival
assert np.all(np.array_equal(rails_initial_2, rails_loaded_2))
assert agents_initial_2 == agents_loaded_2
......@@ -113,6 +121,10 @@ def tests_rail_from_file():
env3.reset()
rails_loaded_3 = env3.rail.grid
agents_loaded_3 = env3.agents
# override `earliest_departure` & `latest_arrival` since they aren't expected to be the same
for agent_initial, agent_loaded in zip(agents_initial, agents_loaded_3):
agent_loaded.earliest_departure = agent_initial.earliest_departure
agent_loaded.latest_arrival = agent_initial.latest_arrival
assert np.all(np.array_equal(rails_initial, rails_loaded_3))
assert agents_initial == agents_loaded_3
......@@ -130,6 +142,10 @@ def tests_rail_from_file():
env4.reset()
rails_loaded_4 = env4.rail.grid
agents_loaded_4 = env4.agents
# override `earliest_departure` & `latest_arrival` since they aren't expected to be the same
for agent_initial, agent_loaded in zip(agents_initial_2, agents_loaded_4):
agent_loaded.earliest_departure = agent_initial.earliest_departure
agent_loaded.latest_arrival = agent_initial.latest_arrival
# Check that no distance map was saved
assert not hasattr(env2.obs_builder, "distance_map")
......@@ -139,3 +155,10 @@ def tests_rail_from_file():
# Check that distance map was generated with correct shape
assert env4.distance_map.get() is not None
assert np.shape(env4.distance_map.get()) == dist_map_shape
def main():
tests_rail_from_file()
if __name__ == "__main__":
main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment