Commit 3dd3a53e authored by Erik Nygren's avatar Erik Nygren
Browse files

updated to include comments by christian

parent ba0ae4f5
Pipeline #1493 passed with stage
in 9 minutes and 7 seconds
......@@ -232,7 +232,7 @@ def rail_from_file(filename):
return generator
def rail_from_GridTransitionMap_generator(rail_map):
def rail_from_grid_transition_map(rail_map):
"""
Utility to convert a rail given by a GridTransitionMap map with the correct
16-bit transitions specifications.
......
......@@ -93,7 +93,7 @@ class RailEnv(Environment):
starting positions, targets, and initial orientations for agent handle.
Implemented functions are:
random_rail_generator : generate a random rail of given size
rail_from_GridTransitionMap_generator(rail_map) : generate a rail from
rail_from_grid_transition_map(rail_map) : generate a rail from
a GridTransitionMap object
rail_from_manual_sp ecifications_generator(rail_spec) : generate a rail from
a rail specifications array
......
......@@ -2,7 +2,7 @@ import numpy as np
from flatland.core.grid.grid4 import Grid4Transitions
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.generators import rail_from_GridTransitionMap_generator
from flatland.envs.generators import rail_from_grid_transition_map
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
......@@ -33,7 +33,7 @@ def test_walker():
rail.grid = rail_map
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_GridTransitionMap_generator(rail),
rail_generator=rail_from_grid_transition_map(rail),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2,
predictor=ShortestPathPredictorForRailEnv(max_depth=10)),
......
......@@ -5,7 +5,7 @@ import numpy as np
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.generators import rail_from_GridTransitionMap_generator
from flatland.envs.generators import rail_from_grid_transition_map
from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv, RailEnvActions
......@@ -20,7 +20,7 @@ def test_global_obs():
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_GridTransitionMap_generator(rail),
rail_generator=rail_from_grid_transition_map(rail),
number_of_agents=1,
obs_builder_object=GlobalObsForRailEnv())
......@@ -89,7 +89,7 @@ def test_reward_function_conflict(rendering=False):
rail, rail_map = make_simple_rail()
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_GridTransitionMap_generator(rail),
rail_generator=rail_from_grid_transition_map(rail),
number_of_agents=2,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
)
......@@ -167,7 +167,7 @@ def test_reward_function_waiting(rendering=False):
rail, rail_map = make_simple_rail()
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_GridTransitionMap_generator(rail),
rail_generator=rail_from_grid_transition_map(rail),
number_of_agents=2,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
)
......
......@@ -5,7 +5,7 @@ import pprint
import numpy as np
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs.generators import rail_from_GridTransitionMap_generator
from flatland.envs.generators import rail_from_grid_transition_map
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
......@@ -20,7 +20,7 @@ def test_dummy_predictor(rendering=False):
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_GridTransitionMap_generator(rail),
rail_generator=rail_from_grid_transition_map(rail),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv(max_depth=10)),
)
......@@ -110,7 +110,7 @@ def test_shortest_path_predictor(rendering=False):
rail, rail_map = make_simple_rail()
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_GridTransitionMap_generator(rail),
rail_generator=rail_from_grid_transition_map(rail),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
)
......@@ -229,7 +229,7 @@ def test_shortest_path_predictor_conflicts(rendering=False):
rail, rail_map = make_simple_rail()
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_GridTransitionMap_generator(rail),
rail_generator=rail_from_grid_transition_map(rail),
number_of_agents=2,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
)
......
......@@ -8,7 +8,7 @@ from flatland.core.transition_map import GridTransitionMap
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.agent_utils import EnvAgentStatic
from flatland.envs.generators import complex_rail_generator
from flatland.envs.generators import rail_from_GridTransitionMap_generator
from flatland.envs.generators import rail_from_grid_transition_map
from flatland.envs.observations import GlobalObsForRailEnv
from flatland.envs.rail_env import RailEnv
......@@ -85,7 +85,7 @@ def test_rail_environment_single_agent():
rail.grid = rail_map
rail_env = RailEnv(width=3,
height=3,
rail_generator=rail_from_GridTransitionMap_generator(rail),
rail_generator=rail_from_grid_transition_map(rail),
number_of_agents=1,
obs_builder_object=GlobalObsForRailEnv())
......@@ -164,7 +164,7 @@ def test_dead_end():
rail.grid = rail_map
rail_env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_GridTransitionMap_generator(rail),
rail_generator=rail_from_grid_transition_map(rail),
number_of_agents=1,
obs_builder_object=GlobalObsForRailEnv())
......@@ -208,7 +208,7 @@ def test_dead_end():
rail.grid = rail_map
rail_env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_GridTransitionMap_generator(rail),
rail_generator=rail_from_grid_transition_map(rail),
number_of_agents=1,
obs_builder_object=GlobalObsForRailEnv())
......
......@@ -3,7 +3,7 @@
import numpy as np
from flatland.envs.generators import rail_from_GridTransitionMap_generator, rail_from_file, complex_rail_generator, \
from flatland.envs.generators import rail_from_grid_transition_map, rail_from_file, complex_rail_generator, \
random_rail_generator, empty_rail_generator
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
......@@ -30,8 +30,6 @@ def test_empty_rail_generator():
# Check that no agents where placed
assert env.get_num_agents() == 0
return
def test_random_rail_generator():
np.random.seed(0)
......@@ -48,8 +46,6 @@ def test_random_rail_generator():
assert env.rail.grid.shape == (y_dim, x_dim)
assert env.get_num_agents() == n_agents
return
def test_complex_rail_generator():
n_agents = 10
......@@ -65,6 +61,7 @@ def test_complex_rail_generator():
rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist)
)
assert env.get_num_agents() == 2
assert env.rail.grid.shape == (y_dim, x_dim)
min_dist = 2 * x_dim
......@@ -75,6 +72,7 @@ def test_complex_rail_generator():
rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist)
)
assert env.get_num_agents() == 0
assert env.rail.grid.shape == (y_dim, x_dim)
# Check that everything stays the same when correct parameters are given
min_dist = 2
......@@ -87,15 +85,16 @@ def test_complex_rail_generator():
rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist)
)
assert env.get_num_agents() == n_agents
return
assert env.rail.grid.shape == (y_dim, x_dim)
def test_rail_from_GridTransitionMap_generator():
def test_rail_from_grid_transition_map():
rail, rail_map = make_simple_rail()
n_agents = 3
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_GridTransitionMap_generator(rail),
rail_generator=rail_from_grid_transition_map(rail),
number_of_agents=n_agents
)
nr_rail_elements = np.count_nonzero(env.rail.grid)
......@@ -105,7 +104,8 @@ def test_rail_from_GridTransitionMap_generator():
# Check that agents are placed on a rail
for a in env.agents:
assert env.rail.grid[a.position] != 0
return
assert env.get_num_agents() == n_agents
def tests_rail_from_file():
......@@ -113,7 +113,7 @@ def tests_rail_from_file():
rail, rail_map = make_simple_rail()
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_GridTransitionMap_generator(rail),
rail_generator=rail_from_grid_transition_map(rail),
number_of_agents=3,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
)
......@@ -134,4 +134,3 @@ def tests_rail_from_file():
assert np.all(np.array_equal(rails_initial, rails_loaded))
assert agents_initial == agents_loaded
return
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment