Skip to content
Snippets Groups Projects
Commit 2c17b423 authored by Erik Nygren's avatar Erik Nygren :bullettrain_front:
Browse files

Merge branch '102_tests_for_generators' into 'master'

102 tests for generators

Closes #102

See merge request flatland/flatland!104
parents da6cc6b5 3dd3a53e
No related branches found
No related tags found
No related merge requests found
......@@ -43,6 +43,7 @@ def complex_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dist=
"""
def generator(width, height, num_agents, num_resets=0):
if num_agents > nr_start_goal:
num_agents = nr_start_goal
print("complex_rail_generator: num_agents > nr_start_goal, changing num_agents")
......@@ -108,7 +109,7 @@ def complex_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dist=
break
if not all_ok:
# we can might as well give up at this point
# we might as well give up at this point
break
new_path = connect_rail(rail_trans, rail_array, start, goal)
......@@ -231,7 +232,7 @@ def rail_from_file(filename):
return generator
def rail_from_GridTransitionMap_generator(rail_map):
def rail_from_grid_transition_map(rail_map):
"""
Utility to convert a rail given by a GridTransitionMap map with the correct
16-bit transitions specifications.
......
......@@ -93,7 +93,7 @@ class RailEnv(Environment):
starting positions, targets, and initial orientations for agent handle.
Implemented functions are:
random_rail_generator : generate a random rail of given size
rail_from_GridTransitionMap_generator(rail_map) : generate a rail from
rail_from_grid_transition_map(rail_map) : generate a rail from
a GridTransitionMap object
rail_from_manual_sp ecifications_generator(rail_spec) : generate a rail from
a rail specifications array
......
......@@ -2,7 +2,7 @@ import numpy as np
from flatland.core.grid.grid4 import Grid4Transitions
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.generators import rail_from_GridTransitionMap_generator
from flatland.envs.generators import rail_from_grid_transition_map
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
......@@ -33,7 +33,7 @@ def test_walker():
rail.grid = rail_map
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_GridTransitionMap_generator(rail),
rail_generator=rail_from_grid_transition_map(rail),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2,
predictor=ShortestPathPredictorForRailEnv(max_depth=10)),
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import numpy as np
from flatland.envs.generators import rail_from_GridTransitionMap_generator, rail_from_file
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from tests.simple_rail import make_simple_rail
def test_load_pkl():
file_name = "test_pkl.pkl"
rail, rail_map = make_simple_rail()
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_GridTransitionMap_generator(rail),
number_of_agents=3,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
)
env.save(file_name)
# initialize agents_static
rails_initial = env.rail.grid
agents_initial = env.agents
env = RailEnv(width=1,
height=1,
rail_generator=rail_from_file(file_name),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
)
rails_loaded = env.rail.grid
agents_loaded = env.agents
assert np.all(np.array_equal(rails_initial, rails_loaded))
assert agents_initial == agents_loaded
return
......@@ -5,7 +5,7 @@ import numpy as np
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.generators import rail_from_GridTransitionMap_generator
from flatland.envs.generators import rail_from_grid_transition_map
from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv, RailEnvActions
......@@ -20,7 +20,7 @@ def test_global_obs():
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_GridTransitionMap_generator(rail),
rail_generator=rail_from_grid_transition_map(rail),
number_of_agents=1,
obs_builder_object=GlobalObsForRailEnv())
......@@ -89,7 +89,7 @@ def test_reward_function_conflict(rendering=False):
rail, rail_map = make_simple_rail()
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_GridTransitionMap_generator(rail),
rail_generator=rail_from_grid_transition_map(rail),
number_of_agents=2,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
)
......@@ -167,7 +167,7 @@ def test_reward_function_waiting(rendering=False):
rail, rail_map = make_simple_rail()
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_GridTransitionMap_generator(rail),
rail_generator=rail_from_grid_transition_map(rail),
number_of_agents=2,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
)
......
......@@ -5,7 +5,7 @@ import pprint
import numpy as np
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs.generators import rail_from_GridTransitionMap_generator
from flatland.envs.generators import rail_from_grid_transition_map
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
......@@ -20,7 +20,7 @@ def test_dummy_predictor(rendering=False):
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_GridTransitionMap_generator(rail),
rail_generator=rail_from_grid_transition_map(rail),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv(max_depth=10)),
)
......@@ -110,7 +110,7 @@ def test_shortest_path_predictor(rendering=False):
rail, rail_map = make_simple_rail()
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_GridTransitionMap_generator(rail),
rail_generator=rail_from_grid_transition_map(rail),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
)
......@@ -229,7 +229,7 @@ def test_shortest_path_predictor_conflicts(rendering=False):
rail, rail_map = make_simple_rail()
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_GridTransitionMap_generator(rail),
rail_generator=rail_from_grid_transition_map(rail),
number_of_agents=2,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
)
......
......@@ -8,7 +8,7 @@ from flatland.core.transition_map import GridTransitionMap
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.agent_utils import EnvAgentStatic
from flatland.envs.generators import complex_rail_generator
from flatland.envs.generators import rail_from_GridTransitionMap_generator
from flatland.envs.generators import rail_from_grid_transition_map
from flatland.envs.observations import GlobalObsForRailEnv
from flatland.envs.rail_env import RailEnv
......@@ -85,7 +85,7 @@ def test_rail_environment_single_agent():
rail.grid = rail_map
rail_env = RailEnv(width=3,
height=3,
rail_generator=rail_from_GridTransitionMap_generator(rail),
rail_generator=rail_from_grid_transition_map(rail),
number_of_agents=1,
obs_builder_object=GlobalObsForRailEnv())
......@@ -164,7 +164,7 @@ def test_dead_end():
rail.grid = rail_map
rail_env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_GridTransitionMap_generator(rail),
rail_generator=rail_from_grid_transition_map(rail),
number_of_agents=1,
obs_builder_object=GlobalObsForRailEnv())
......@@ -208,7 +208,7 @@ def test_dead_end():
rail.grid = rail_map
rail_env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_GridTransitionMap_generator(rail),
rail_generator=rail_from_grid_transition_map(rail),
number_of_agents=1,
obs_builder_object=GlobalObsForRailEnv())
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import numpy as np
from flatland.envs.generators import rail_from_grid_transition_map, rail_from_file, complex_rail_generator, \
random_rail_generator, empty_rail_generator
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from tests.simple_rail import make_simple_rail
def test_empty_rail_generator():
np.random.seed(0)
n_agents = 1
x_dim = 5
y_dim = 10
# Check that a random level at with correct parameters is generated
env = RailEnv(width=x_dim,
height=y_dim,
number_of_agents=n_agents,
rail_generator=empty_rail_generator()
)
# Check the dimensions
assert env.rail.grid.shape == (y_dim, x_dim)
# Check that no grid was generated
assert np.count_nonzero(env.rail.grid) == 0
# Check that no agents where placed
assert env.get_num_agents() == 0
def test_random_rail_generator():
np.random.seed(0)
n_agents = 1
x_dim = 5
y_dim = 10
# Check that a random level at with correct parameters is generated
env = RailEnv(width=x_dim,
height=y_dim,
number_of_agents=n_agents,
rail_generator=random_rail_generator()
)
assert env.rail.grid.shape == (y_dim, x_dim)
assert env.get_num_agents() == n_agents
def test_complex_rail_generator():
n_agents = 10
n_start = 2
x_dim = 10
y_dim = 10
min_dist = 4
# Check that agent number is changed to fit generated level
env = RailEnv(width=x_dim,
height=y_dim,
number_of_agents=n_agents,
rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist)
)
assert env.get_num_agents() == 2
assert env.rail.grid.shape == (y_dim, x_dim)
min_dist = 2 * x_dim
# Check that no agents are generated when level cannot be generated
env = RailEnv(width=x_dim,
height=y_dim,
number_of_agents=n_agents,
rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist)
)
assert env.get_num_agents() == 0
assert env.rail.grid.shape == (y_dim, x_dim)
# Check that everything stays the same when correct parameters are given
min_dist = 2
n_start = 5
n_agents = 5
env = RailEnv(width=x_dim,
height=y_dim,
number_of_agents=n_agents,
rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist)
)
assert env.get_num_agents() == n_agents
assert env.rail.grid.shape == (y_dim, x_dim)
def test_rail_from_grid_transition_map():
rail, rail_map = make_simple_rail()
n_agents = 3
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail),
number_of_agents=n_agents
)
nr_rail_elements = np.count_nonzero(env.rail.grid)
# Check if the number of non-empty rail cells is ok
assert nr_rail_elements == 16
# Check that agents are placed on a rail
for a in env.agents:
assert env.rail.grid[a.position] != 0
assert env.get_num_agents() == n_agents
def tests_rail_from_file():
file_name = "test_pkl.pkl"
rail, rail_map = make_simple_rail()
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail),
number_of_agents=3,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
)
env.save(file_name)
# initialize agents_static
rails_initial = env.rail.grid
agents_initial = env.agents
env = RailEnv(width=1,
height=1,
rail_generator=rail_from_file(file_name),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
)
rails_loaded = env.rail.grid
agents_loaded = env.agents
assert np.all(np.array_equal(rails_initial, rails_loaded))
assert agents_initial == agents_loaded
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment