Commit 297b65c5 authored by Dipam Chakraborty's avatar Dipam Chakraborty
Browse files

some tests working

parent 8f561eb9
......@@ -109,7 +109,11 @@ class EnvAgent:
return get_shortest_paths(distance_map=distance_map, agent_handle=self.handle)[self.handle]
def get_travel_time_on_shortest_path(self, distance_map) -> int:
    """Return the number of time steps this agent needs to traverse its shortest path.

    Parameters
    ----------
    distance_map : the environment's distance map, forwarded to
        ``get_shortest_path`` (schema not visible here — presumably a
        ``DistanceMap`` instance; confirm against caller).

    Returns
    -------
    int
        ``ceil(path_length / speed)``; 0 when no path to the target exists
        (``get_shortest_path`` returned ``None``).
    """
    # Guard against an unreachable target: get_shortest_path may return None,
    # in which case len() would raise TypeError — treat it as zero distance.
    shortest_path = self.get_shortest_path(distance_map)
    distance = len(shortest_path) if shortest_path is not None else 0
    # Fractional speeds (< 1.0) stretch travel time; round up to whole steps.
    speed = self.speed_data['speed']
    return int(np.ceil(distance / speed))
......
......@@ -21,7 +21,7 @@ from flatland.envs.distance_map import DistanceMap
# cannot import objects / classes directly because of circular import
from flatland.envs import malfunction_generators as mal_gen
from flatland.envs import rail_generators as rail_gen
from flatland.envs import schedule_generators as sched_gen
from flatland.envs import line_generators as line_gen
msgpack_numpy.patch()
......@@ -122,7 +122,7 @@ class RailEnvPersister(object):
width=width, height=height,
rail_generator=rail_gen.rail_from_file(filename,
load_from_package=load_from_package),
schedule_generator=sched_gen.schedule_from_file(filename,
line_generator=line_gen.line_from_file(filename,
load_from_package=load_from_package),
#malfunction_generator_and_process_data=mal_gen.malfunction_from_file(filename,
# load_from_package=load_from_package),
......
......@@ -34,25 +34,29 @@ def test_action_plan(rendering: bool = False):
for handle, agent in enumerate(env.agents):
print("[{}] {} -> {}".format(handle, agent.initial_position, agent.target))
chosen_path_dict = {0: [TrainrunWaypoint(lined_at=0, waypoint=Waypoint(position=(3, 0), direction=3)),
TrainrunWaypoint(lined_at=2, waypoint=Waypoint(position=(3, 1), direction=1)),
TrainrunWaypoint(lined_at=3, waypoint=Waypoint(position=(3, 2), direction=1)),
TrainrunWaypoint(lined_at=14, waypoint=Waypoint(position=(3, 3), direction=1)),
TrainrunWaypoint(lined_at=15, waypoint=Waypoint(position=(3, 4), direction=1)),
TrainrunWaypoint(lined_at=16, waypoint=Waypoint(position=(3, 5), direction=1)),
TrainrunWaypoint(lined_at=17, waypoint=Waypoint(position=(3, 6), direction=1)),
TrainrunWaypoint(lined_at=18, waypoint=Waypoint(position=(3, 7), direction=1)),
TrainrunWaypoint(lined_at=19, waypoint=Waypoint(position=(3, 8), direction=1)),
TrainrunWaypoint(lined_at=20, waypoint=Waypoint(position=(3, 8), direction=5))],
1: [TrainrunWaypoint(lined_at=0, waypoint=Waypoint(position=(3, 8), direction=3)),
TrainrunWaypoint(lined_at=3, waypoint=Waypoint(position=(3, 7), direction=3)),
TrainrunWaypoint(lined_at=5, waypoint=Waypoint(position=(3, 6), direction=3)),
TrainrunWaypoint(lined_at=7, waypoint=Waypoint(position=(3, 5), direction=3)),
TrainrunWaypoint(lined_at=9, waypoint=Waypoint(position=(3, 4), direction=3)),
TrainrunWaypoint(lined_at=11, waypoint=Waypoint(position=(3, 3), direction=3)),
TrainrunWaypoint(lined_at=13, waypoint=Waypoint(position=(2, 3), direction=0)),
TrainrunWaypoint(lined_at=15, waypoint=Waypoint(position=(1, 3), direction=0)),
TrainrunWaypoint(lined_at=17, waypoint=Waypoint(position=(0, 3), direction=0))]}
# Perform DO_NOTHING actions until all trains get to READY_TO_DEPART
for _ in range(max([agent.earliest_departure for agent in env.agents])):
env.step({}) # DO_NOTHING for all agents
chosen_path_dict = {0: [TrainrunWaypoint(scheduled_at=0, waypoint=Waypoint(position=(3, 0), direction=3)),
TrainrunWaypoint(scheduled_at=2, waypoint=Waypoint(position=(3, 1), direction=1)),
TrainrunWaypoint(scheduled_at=3, waypoint=Waypoint(position=(3, 2), direction=1)),
TrainrunWaypoint(scheduled_at=14, waypoint=Waypoint(position=(3, 3), direction=1)),
TrainrunWaypoint(scheduled_at=15, waypoint=Waypoint(position=(3, 4), direction=1)),
TrainrunWaypoint(scheduled_at=16, waypoint=Waypoint(position=(3, 5), direction=1)),
TrainrunWaypoint(scheduled_at=17, waypoint=Waypoint(position=(3, 6), direction=1)),
TrainrunWaypoint(scheduled_at=18, waypoint=Waypoint(position=(3, 7), direction=1)),
TrainrunWaypoint(scheduled_at=19, waypoint=Waypoint(position=(3, 8), direction=1)),
TrainrunWaypoint(scheduled_at=20, waypoint=Waypoint(position=(3, 8), direction=5))],
1: [TrainrunWaypoint(scheduled_at=0, waypoint=Waypoint(position=(3, 8), direction=3)),
TrainrunWaypoint(scheduled_at=3, waypoint=Waypoint(position=(3, 7), direction=3)),
TrainrunWaypoint(scheduled_at=5, waypoint=Waypoint(position=(3, 6), direction=3)),
TrainrunWaypoint(scheduled_at=7, waypoint=Waypoint(position=(3, 5), direction=3)),
TrainrunWaypoint(scheduled_at=9, waypoint=Waypoint(position=(3, 4), direction=3)),
TrainrunWaypoint(scheduled_at=11, waypoint=Waypoint(position=(3, 3), direction=3)),
TrainrunWaypoint(scheduled_at=13, waypoint=Waypoint(position=(2, 3), direction=0)),
TrainrunWaypoint(scheduled_at=15, waypoint=Waypoint(position=(1, 3), direction=0)),
TrainrunWaypoint(scheduled_at=17, waypoint=Waypoint(position=(0, 3), direction=0))]}
expected_action_plan = [[
# take action to enter the grid
ActionPlanElement(0, RailEnvActions.MOVE_FORWARD),
......
......@@ -17,6 +17,11 @@ def test_initial_status():
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
remove_agents_at_target=False)
env.reset()
# Perform DO_NOTHING actions until all trains get to READY_TO_DEPART
for _ in range(max([agent.earliest_departure for agent in env.agents])):
env.step({}) # DO_NOTHING for all agents
set_penalties_for_replay(env)
test_config = ReplayConfig(
replay=[
......@@ -126,6 +131,10 @@ def test_status_done_remove():
remove_agents_at_target=True)
env.reset()
# Perform DO_NOTHING actions until all trains get to READY_TO_DEPART
for _ in range(max([agent.earliest_departure for agent in env.agents])):
env.step({}) # DO_NOTHING for all agents
set_penalties_for_replay(env)
test_config = ReplayConfig(
replay=[
......
......@@ -9,9 +9,9 @@ from flatland.envs.agent_utils import EnvAgent
from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import complex_rail_generator, rail_from_file
from flatland.envs.rail_generators import sparse_rail_generator, rail_from_file
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.line_generators import random_line_generator, complex_line_generator, line_from_file
from flatland.envs.line_generators import random_line_generator, sparse_line_generator, line_from_file
from flatland.utils.simple_rail import make_simple_rail
from flatland.envs.persistence import RailEnvPersister
from flatland.utils.rendertools import RenderTool
......@@ -37,9 +37,10 @@ def test_load_env():
def test_save_load():
env = RailEnv(width=10, height=10,
rail_generator=complex_rail_generator(nr_start_goal=2, nr_extra=5, min_dist=6, seed=1),
line_generator=complex_line_generator(), number_of_agents=2)
rail_generator=sparse_rail_generator(nr_start_goal=2, nr_extra=5, min_dist=6, seed=1),
line_generator=sparse_line_generator(), number_of_agents=2)
env.reset()
agent_1_pos = env.agents[0].position
agent_1_dir = env.agents[0].direction
agent_1_tar = env.agents[0].target
......
......@@ -23,6 +23,10 @@ def test_get_shortest_paths_unreachable():
obs_builder_object=GlobalObsForRailEnv())
env.reset()
# Perform DO_NOTHING actions until all trains get to READY_TO_DEPART
for _ in range(max([agent.earliest_departure for agent in env.agents])):
env.step({}) # DO_NOTHING for all agents
# set the initial position
agent = env.agents[0]
agent.position = (3, 1) # west dead-end
......@@ -36,7 +40,7 @@ def test_get_shortest_paths_unreachable():
actual = get_shortest_paths(env.distance_map)
expected = {0: None}
assert actual == expected, "actual={},expected={}".format(actual, expected)
assert actual[0] == expected[0], "actual={},expected={}".format(actual[0], expected[0])
# todo file test_002.pkl has to be generated automatically
......
......@@ -1512,7 +1512,7 @@ def test_sparse_generator_changes_to_grid_mode():
rail_env = RailEnv(width=10, height=20, rail_generator=sparse_rail_generator(
max_num_cities=100,
max_rails_between_cities=2,
max_rails_in_city=2,
max_rail_pairs_in_city=1,
seed=15,
grid_mode=False
), line_generator=sparse_line_generator(), number_of_agents=10,
......
......@@ -25,12 +25,14 @@ def test_line_from_file_sparse():
seed=1,
grid_mode=False,
max_rails_between_cities=3,
max_rails_in_city=6,
max_rail_pairs_in_city=3,
)
line_generator = sparse_line_generator(speed_ration_map)
create_and_save_env(file_name="./sparse_env_test.pkl", rail_generator=rail_generator,
env = create_and_save_env(file_name="./sparse_env_test.pkl", rail_generator=rail_generator,
line_generator=line_generator)
old_num_steps = env._max_episode_steps
old_num_agents = len(env.agents)
# Sparse generator
......@@ -41,10 +43,10 @@ def test_line_from_file_sparse():
sparse_env_from_file.reset(True, True)
# Assert loaded agent number is correct
assert sparse_env_from_file.get_num_agents() == 10
assert sparse_env_from_file.get_num_agents() == old_num_agents
# Assert max steps is correct
assert sparse_env_from_file._max_episode_steps == 500
assert sparse_env_from_file._max_episode_steps == old_num_steps
......@@ -65,8 +67,10 @@ def test_line_from_file_random():
rail_generator = random_rail_generator()
line_generator = random_line_generator(speed_ration_map)
create_and_save_env(file_name="./random_env_test.pkl", rail_generator=rail_generator,
env = create_and_save_env(file_name="./random_env_test.pkl", rail_generator=rail_generator,
line_generator=line_generator)
old_num_steps = env._max_episode_steps
old_num_agents = len(env.agents)
# Random generator
......@@ -77,10 +81,10 @@ def test_line_from_file_random():
random_env_from_file.reset(True, True)
# Assert loaded agent number is correct
assert random_env_from_file.get_num_agents() == 10
assert random_env_from_file.get_num_agents() == old_num_agents
# Assert max steps is correct
assert random_env_from_file._max_episode_steps == 1350
assert random_env_from_file._max_episode_steps == old_num_steps
......@@ -105,8 +109,10 @@ def test_line_from_file_complex():
max_dist=99999)
line_generator = complex_line_generator(speed_ration_map)
create_and_save_env(file_name="./complex_env_test.pkl", rail_generator=rail_generator,
env = create_and_save_env(file_name="./complex_env_test.pkl", rail_generator=rail_generator,
line_generator=line_generator)
old_num_steps = env._max_episode_steps
old_num_agents = len(env.agents)
# Load the different envs and check the parameters
......@@ -119,7 +125,7 @@ def test_line_from_file_complex():
complex_env_from_file.reset(True, True)
# Assert loaded agent number is correct
assert complex_env_from_file.get_num_agents() == 10
assert complex_env_from_file.get_num_agents() == old_num_agents
# Assert max steps is correct
assert complex_env_from_file._max_episode_steps == 1350
assert complex_env_from_file._max_episode_steps == old_num_steps
......@@ -97,6 +97,8 @@ def test_malfunction_process():
actions[i] = np.argmax(obs[i]) + 1
obs, all_rewards, done, _ = env.step(actions)
if done["__all__"]:
break
if env.agents[0].malfunction_data['malfunction'] > 0:
agent_malfunctioning = True
......
......@@ -30,6 +30,10 @@ def test_get_global_observation():
obs_builder_object=GlobalObsForRailEnv())
env.reset()
# Perform DO_NOTHING actions until all trains get to READY_TO_DEPART
for _ in range(max([agent.earliest_departure for agent in env.agents])):
env.step({}) # DO_NOTHING for all agents
obs, all_rewards, done, _ = env.step({i: RailEnvActions.MOVE_FORWARD for i in range(number_of_agents)})
for i in range(len(env.agents)):
agent: EnvAgent = env.agents[i]
......
......@@ -51,6 +51,7 @@ def test_multi_speed_init():
rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999,
seed=1), line_generator=complex_line_generator(),
number_of_agents=5)
# Initialize the agent with the parameters corresponding to the environment and observation_builder
agent = RandomAgent(218, 4)
......@@ -197,6 +198,12 @@ def test_multispeed_actions_no_malfunction_blocking():
line_generator=random_line_generator(), number_of_agents=2,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
env.reset()
# Perform DO_NOTHING actions until all trains get to READY_TO_DEPART
for _ in range(max([agent.earliest_departure for agent in env.agents])):
env.step({}) # DO_NOTHING for all agents
set_penalties_for_replay(env)
test_configs = [
ReplayConfig(
......@@ -381,7 +388,11 @@ def test_multispeed_actions_malfunction_no_blocking():
line_generator=random_line_generator(), number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
env.reset()
# Perform DO_NOTHING actions until all trains get to READY_TO_DEPART
for _ in range(max([agent.earliest_departure for agent in env.agents])):
env.step({}) # DO_NOTHING for all agents
set_penalties_for_replay(env)
test_config = ReplayConfig(
replay=[
......@@ -515,6 +526,10 @@ def test_multispeed_actions_no_malfunction_invalid_actions():
line_generator=random_line_generator(), number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
env.reset()
# Perform DO_NOTHING actions until all trains get to READY_TO_DEPART
for _ in range(max([agent.earliest_departure for agent in env.agents])):
env.step({}) # DO_NOTHING for all agents
set_penalties_for_replay(env)
test_config = ReplayConfig(
......
......@@ -152,4 +152,4 @@ def create_and_save_env(file_name: str, line_generator: LineGenerator, rail_gene
env.reset(True, True)
#env.save(file_name)
RailEnvPersister.save(env, file_name)
return env
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment