Skip to content
Snippets Groups Projects
Commit dece6c16 authored by u214892's avatar u214892
Browse files

#141 different agent classes

parent 7f351228
No related branches found
No related tags found
No related merge requests found
Showing
with 311 additions and 204 deletions
...@@ -3,6 +3,7 @@ import random ...@@ -3,6 +3,7 @@ import random
import numpy as np import numpy as np
from flatland.envs.agent_generators import complex_rail_generator_agents_placer
from flatland.envs.generators import complex_rail_generator from flatland.envs.generators import complex_rail_generator
from flatland.envs.rail_env import RailEnv from flatland.envs.rail_env import RailEnv
...@@ -15,6 +16,7 @@ def run_benchmark(): ...@@ -15,6 +16,7 @@ def run_benchmark():
# Example generate a random rail # Example generate a random rail
env = RailEnv(width=15, height=15, env = RailEnv(width=15, height=15,
rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=20, min_dist=12), rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=20, min_dist=12),
agent_generator=complex_rail_generator_agents_placer(),
number_of_agents=5) number_of_agents=5)
n_trials = 20 n_trials = 20
......
...@@ -5,6 +5,7 @@ import numpy as np ...@@ -5,6 +5,7 @@ import numpy as np
from flatland.core.env_observation_builder import ObservationBuilder from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.grid.grid_utils import coordinate_to_position from flatland.core.grid.grid_utils import coordinate_to_position
from flatland.envs.agent_generators import complex_rail_generator_agents_placer
from flatland.envs.generators import random_rail_generator, complex_rail_generator from flatland.envs.generators import random_rail_generator, complex_rail_generator
from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv
...@@ -20,6 +21,7 @@ class SimpleObs(ObservationBuilder): ...@@ -20,6 +21,7 @@ class SimpleObs(ObservationBuilder):
Simplest observation builder. The object returns observation vectors with 5 identical components, Simplest observation builder. The object returns observation vectors with 5 identical components,
all equal to the ID of the respective agent. all equal to the ID of the respective agent.
""" """
def __init__(self): def __init__(self):
self.observation_space = [5] self.observation_space = [5]
...@@ -53,6 +55,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv): ...@@ -53,6 +55,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv):
E.g., if taking the Left branch (if available) is the shortest route to the agent's target, the observation vector E.g., if taking the Left branch (if available) is the shortest route to the agent's target, the observation vector
will be [1, 0, 0]. will be [1, 0, 0].
""" """
def __init__(self): def __init__(self):
super().__init__(max_depth=0) super().__init__(max_depth=0)
self.observation_space = [3] self.observation_space = [3]
...@@ -90,6 +93,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv): ...@@ -90,6 +93,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv):
env = RailEnv(width=7, env = RailEnv(width=7,
height=7, height=7,
rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999, seed=0), rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999, seed=0),
agent_generator=complex_rail_generator_agents_placer(),
number_of_agents=1, number_of_agents=1,
obs_builder_object=SingleAgentNavigationObs()) obs_builder_object=SingleAgentNavigationObs())
...@@ -97,8 +101,8 @@ obs = env.reset() ...@@ -97,8 +101,8 @@ obs = env.reset()
env_renderer = RenderTool(env, gl="PILSVG") env_renderer = RenderTool(env, gl="PILSVG")
env_renderer.render_env(show=True, frames=True, show_observations=True) env_renderer.render_env(show=True, frames=True, show_observations=True)
for step in range(100): for step in range(100):
action = np.argmax(obs[0])+1 action = np.argmax(obs[0]) + 1
obs, all_rewards, done, _ = env.step({0:action}) obs, all_rewards, done, _ = env.step({0: action})
print("Rewards: ", all_rewards, " [done=", done, "]") print("Rewards: ", all_rewards, " [done=", done, "]")
env_renderer.render_env(show=True, frames=True, show_observations=True) env_renderer.render_env(show=True, frames=True, show_observations=True)
time.sleep(0.1) time.sleep(0.1)
...@@ -200,6 +204,7 @@ CustomObsBuilder = ObservePredictions(CustomPredictor) ...@@ -200,6 +204,7 @@ CustomObsBuilder = ObservePredictions(CustomPredictor)
env = RailEnv(width=10, env = RailEnv(width=10,
height=10, height=10,
rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=1, min_dist=8, max_dist=99999, seed=0), rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=1, min_dist=8, max_dist=99999, seed=0),
agent_generator=complex_rail_generator_agents_placer(),
number_of_agents=3, number_of_agents=3,
obs_builder_object=CustomObsBuilder) obs_builder_object=CustomObsBuilder)
......
...@@ -3,6 +3,7 @@ import time ...@@ -3,6 +3,7 @@ import time
import numpy as np import numpy as np
from flatland.envs.agent_generators import complex_rail_generator_agents_placer
from flatland.envs.generators import complex_rail_generator from flatland.envs.generators import complex_rail_generator
from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.rail_env import RailEnv from flatland.envs.rail_env import RailEnv
...@@ -11,6 +12,7 @@ from flatland.utils.rendertools import RenderTool ...@@ -11,6 +12,7 @@ from flatland.utils.rendertools import RenderTool
random.seed(1) random.seed(1)
np.random.seed(1) np.random.seed(1)
class SingleAgentNavigationObs(TreeObsForRailEnv): class SingleAgentNavigationObs(TreeObsForRailEnv):
""" """
We derive our bbservation builder from TreeObsForRailEnv, to exploit the existing implementation to compute We derive our bbservation builder from TreeObsForRailEnv, to exploit the existing implementation to compute
...@@ -21,6 +23,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv): ...@@ -21,6 +23,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv):
E.g., if taking the Left branch (if available) is the shortest route to the agent's target, the observation vector E.g., if taking the Left branch (if available) is the shortest route to the agent's target, the observation vector
will be [1, 0, 0]. will be [1, 0, 0].
""" """
def __init__(self): def __init__(self):
super().__init__(max_depth=0) super().__init__(max_depth=0)
self.observation_space = [3] self.observation_space = [3]
...@@ -58,6 +61,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv): ...@@ -58,6 +61,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv):
env = RailEnv(width=14, env = RailEnv(width=14,
height=14, height=14,
rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999, seed=0), rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999, seed=0),
agent_generator=complex_rail_generator_agents_placer(),
number_of_agents=2, number_of_agents=2,
obs_builder_object=SingleAgentNavigationObs()) obs_builder_object=SingleAgentNavigationObs())
...@@ -67,11 +71,11 @@ env_renderer.render_env(show=True, frames=True, show_observations=False) ...@@ -67,11 +71,11 @@ env_renderer.render_env(show=True, frames=True, show_observations=False)
for step in range(100): for step in range(100):
actions = {} actions = {}
for i in range(len(obs)): for i in range(len(obs)):
actions[i] = np.argmax(obs[i])+1 actions[i] = np.argmax(obs[i]) + 1
if step%5 == 0: if step % 5 == 0:
print("Agent halts") print("Agent halts")
actions[0] = 4 # Halt actions[0] = 4 # Halt
obs, all_rewards, done, _ = env.step(actions) obs, all_rewards, done, _ = env.step(actions)
if env.agents[0].malfunction_data['malfunction'] > 0: if env.agents[0].malfunction_data['malfunction'] > 0:
...@@ -82,4 +86,3 @@ for step in range(100): ...@@ -82,4 +86,3 @@ for step in range(100):
if done["__all__"]: if done["__all__"]:
break break
env_renderer.close_window() env_renderer.close_window()
...@@ -2,6 +2,7 @@ import random ...@@ -2,6 +2,7 @@ import random
import numpy as np import numpy as np
from flatland.envs.agent_generators import complex_rail_generator_agents_placer
from flatland.envs.generators import complex_rail_generator from flatland.envs.generators import complex_rail_generator
from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.rail_env import RailEnv from flatland.envs.rail_env import RailEnv
...@@ -13,6 +14,7 @@ np.random.seed(1) ...@@ -13,6 +14,7 @@ np.random.seed(1)
env = RailEnv(width=7, env = RailEnv(width=7,
height=7, height=7,
rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0), rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0),
agent_generator=complex_rail_generator_agents_placer(),
number_of_agents=2, number_of_agents=2,
obs_builder_object=TreeObsForRailEnv(max_depth=2)) obs_builder_object=TreeObsForRailEnv(max_depth=2))
......
import numpy as np import numpy as np
from flatland.envs.agent_generators import complex_rail_generator_agents_placer
from flatland.envs.generators import complex_rail_generator from flatland.envs.generators import complex_rail_generator
from flatland.envs.observations import TreeObsForRailEnv, LocalObsForRailEnv from flatland.envs.observations import TreeObsForRailEnv, LocalObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv
...@@ -16,11 +17,13 @@ LocalGridObs = LocalObsForRailEnv(view_height=10, view_width=2, center=2) ...@@ -16,11 +17,13 @@ LocalGridObs = LocalObsForRailEnv(view_height=10, view_width=2, center=2)
env = RailEnv(width=20, env = RailEnv(width=20,
height=20, height=20,
rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0), rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0),
agent_generator=complex_rail_generator_agents_placer(),
obs_builder_object=TreeObservation, obs_builder_object=TreeObservation,
number_of_agents=3) number_of_agents=3)
env_renderer = RenderTool(env, gl="PILSVG", ) env_renderer = RenderTool(env, gl="PILSVG", )
# Import your own Agent or use RLlib to train agents on Flatland # Import your own Agent or use RLlib to train agents on Flatland
# As an example we use a random agent here # As an example we use a random agent here
......
...@@ -2,29 +2,33 @@ ...@@ -2,29 +2,33 @@
"""Console script for flatland.""" """Console script for flatland."""
import sys import sys
import time
import click import click
import numpy as np import numpy as np
import time import redis
from flatland.envs.agent_generators import complex_rail_generator_agents_placer
from flatland.envs.generators import complex_rail_generator from flatland.envs.generators import complex_rail_generator
from flatland.envs.rail_env import RailEnv from flatland.envs.rail_env import RailEnv
from flatland.utils.rendertools import RenderTool
from flatland.evaluators.service import FlatlandRemoteEvaluationService from flatland.evaluators.service import FlatlandRemoteEvaluationService
import redis from flatland.utils.rendertools import RenderTool
@click.command() @click.command()
def demo(args=None): def demo(args=None):
"""Demo script to check installation""" """Demo script to check installation"""
env = RailEnv( env = RailEnv(
width=15, width=15,
height=15, height=15,
rail_generator=complex_rail_generator( rail_generator=complex_rail_generator(
nr_start_goal=10, nr_start_goal=10,
nr_extra=1, nr_extra=1,
min_dist=8, min_dist=8,
max_dist=99999), max_dist=99999),
number_of_agents=5) agent_generator=complex_rail_generator_agents_placer(),
number_of_agents=5)
env._max_episode_steps = int(15 * (env.width + env.height)) env._max_episode_steps = int(15 * (env.width + env.height))
env_renderer = RenderTool(env) env_renderer = RenderTool(env)
...@@ -52,12 +56,12 @@ def demo(args=None): ...@@ -52,12 +56,12 @@ def demo(args=None):
@click.command() @click.command()
@click.option('--tests', @click.option('--tests',
type=click.Path(exists=True), type=click.Path(exists=True),
help="Path to folder containing Flatland tests", help="Path to folder containing Flatland tests",
required=True required=True
) )
@click.option('--service_id', @click.option('--service_id',
default="FLATLAND_RL_SERVICE_ID", default="FLATLAND_RL_SERVICE_ID",
help="Evaluation Service ID. This has to match the service id on the client.", help="Evaluation Service ID. This has to match the service id on the client.",
required=False required=False
...@@ -70,14 +74,14 @@ def evaluator(tests, service_id): ...@@ -70,14 +74,14 @@ def evaluator(tests, service_id):
raise Exception( raise Exception(
"\nRedis server does not seem to be running on your localhost.\n" "\nRedis server does not seem to be running on your localhost.\n"
"Please ensure that you have a redis server running on your localhost" "Please ensure that you have a redis server running on your localhost"
) )
grader = FlatlandRemoteEvaluationService( grader = FlatlandRemoteEvaluationService(
test_env_folder=tests, test_env_folder=tests,
flatland_rl_service_id=service_id, flatland_rl_service_id=service_id,
visualize=False, visualize=False,
verbose=False verbose=False
) )
grader.run() grader.run()
......
"""Agent generators (railway undertaking, "EVU")."""
from typing import Tuple, List, Callable, Mapping, Optional, Any
import msgpack
import numpy as np
from flatland.core.grid.grid4_utils import get_new_position
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.agent_utils import EnvAgentStatic
AgentPosition = Tuple[int, int]
AgentGeneratorProduct = Tuple[List[AgentPosition], List[AgentPosition], List[AgentPosition], List[float]]
AgentGenerator = Callable[[GridTransitionMap, int, Optional[Any]], AgentGeneratorProduct]
def speed_initialization_helper(nb_agents: int, speed_ratio_map: Mapping[float, float] = None) -> List[float]:
"""
Parameters
-------
nb_agents : int
The number of agents to generate a speed for
speed_ratio_map : Mapping[float,float]
A map of speeds mappint to their ratio of appearance. The ratios must sum up to 1.
Returns
-------
List[float]
A list of size nb_agents of speeds with the corresponding probabilistic ratios.
"""
if speed_ratio_map is None:
return [1.0] * nb_agents
nb_classes = len(speed_ratio_map.keys())
speed_ratio_map_as_list: List[Tuple[float, float]] = list(speed_ratio_map.items())
speed_ratios = list(map(lambda t: t[1], speed_ratio_map_as_list))
speeds = list(map(lambda t: t[0], speed_ratio_map_as_list))
return list(map(lambda index: speeds[index], np.random.choice(nb_classes, nb_agents, p=speed_ratios)))
def complex_rail_generator_agents_placer(speed_ratio_map: Mapping[float, float] = None) -> AgentGenerator:
def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None):
start_goal = hints['start_goal']
start_dir = hints['start_dir']
agents_position = [sg[0] for sg in start_goal[:num_agents]]
agents_target = [sg[1] for sg in start_goal[:num_agents]]
agents_direction = start_dir[:num_agents]
if speed_ratio_map:
speeds = speed_initialization_helper(num_agents, speed_ratio_map)
else:
speeds = [1.0] * len(agents_position)
return agents_position, agents_direction, agents_target, speeds
return generator
def get_rnd_agents_pos_tgt_dir_on_rail(speed_ratio_map: Mapping[float, float] = None) -> AgentGenerator:
"""
Given a `rail' GridTransitionMap, return a random placement of agents (initial position, direction and target).
Parameters
-------
rail : GridTransitionMap
The railway to place agents on.
num_agents : int
The number of agents to generate a speed for
speed_ratio_map : Mapping[float,float]
A map of speeds mappint to their ratio of appearance. The ratios must sum up to 1.
Returns
-------
Tuple[List[Tuple[int,int]], List[Tuple[int,int]], List[Tuple[int,int]], List[float]]
initial positions, directions, targets speeds
"""
def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None):
def _path_exists(rail, start, direction, end):
# BFS - Check if a path exists between the 2 nodes
visited = set()
stack = [(start, direction)]
while stack:
node = stack.pop()
if node[0][0] == end[0] and node[0][1] == end[1]:
return 1
if node not in visited:
visited.add(node)
moves = rail.get_transitions(node[0][0], node[0][1], node[1])
for move_index in range(4):
if moves[move_index]:
stack.append((get_new_position(node[0], move_index),
move_index))
# If cell is a dead-end, append previous node with reversed
# orientation!
nbits = 0
tmp = rail.get_full_transitions(node[0][0], node[0][1])
while tmp > 0:
nbits += (tmp & 1)
tmp = tmp >> 1
if nbits == 1:
stack.append((node[0], (node[1] + 2) % 4))
return 0
valid_positions = []
for r in range(rail.height):
for c in range(rail.width):
if rail.get_full_transitions(r, c) > 0:
valid_positions.append((r, c))
if len(valid_positions) == 0:
return [], [], [], []
re_generate = True
while re_generate:
agents_position = [
valid_positions[i] for i in
np.random.choice(len(valid_positions), num_agents)]
agents_target = [
valid_positions[i] for i in
np.random.choice(len(valid_positions), num_agents)]
# agents_direction must be a direction for which a solution is
# guaranteed.
agents_direction = [0] * num_agents
re_generate = False
for i in range(num_agents):
valid_movements = []
for direction in range(4):
position = agents_position[i]
moves = rail.get_transitions(position[0], position[1], direction)
for move_index in range(4):
if moves[move_index]:
valid_movements.append((direction, move_index))
valid_starting_directions = []
for m in valid_movements:
new_position = get_new_position(agents_position[i], m[1])
if m[0] not in valid_starting_directions and _path_exists(rail, new_position, m[0],
agents_target[i]):
valid_starting_directions.append(m[0])
if len(valid_starting_directions) == 0:
re_generate = True
else:
agents_direction[i] = valid_starting_directions[
np.random.choice(len(valid_starting_directions), 1)[0]]
agents_speed = speed_initialization_helper(num_agents, speed_ratio_map)
return agents_position, agents_direction, agents_target, agents_speed
return generator
def agents_from_file(filename) -> AgentGenerator:
"""
Utility to load pickle file
Parameters
-------
input_file : Pickle file generated by env.save() or editor
Returns
-------
Tuple[List[Tuple[int,int]], List[Tuple[int,int]], List[Tuple[int,int]], List[float]]
initial positions, directions, targets speeds
"""
def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None):
with open(filename, "rb") as file_in:
load_data = file_in.read()
data = msgpack.unpackb(load_data, use_list=False)
# agents are always reset as not moving
agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data[b"agents_static"]]
# setup with loaded data
agents_position = [a.position for a in agents_static]
agents_direction = [a.direction for a in agents_static]
agents_target = [a.target for a in agents_static]
return agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
return generator
from typing import Mapping, Tuple, List, Callable """Rail generators (infrastructure manager, "Infrastrukturbetreiber")."""
from typing import Callable, Tuple, Any, Optional
import msgpack import msgpack
import numpy as np import numpy as np
...@@ -7,12 +8,12 @@ from flatland.core.grid.grid4_utils import get_direction, mirror ...@@ -7,12 +8,12 @@ from flatland.core.grid.grid4_utils import get_direction, mirror
from flatland.core.grid.grid_utils import distance_on_rail from flatland.core.grid.grid_utils import distance_on_rail
from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.core.grid.rail_env_grid import RailEnvTransitions
from flatland.core.transition_map import GridTransitionMap from flatland.core.transition_map import GridTransitionMap
from flatland.envs.agent_utils import EnvAgentStatic
from flatland.envs.grid4_generators_utils import connect_rail from flatland.envs.grid4_generators_utils import connect_rail
from flatland.envs.grid4_generators_utils import get_rnd_agents_pos_tgt_dir_on_rail
RailGenerator = Callable[[int, int, int, int], Tuple[GridTransitionMap, Optional[Any]]]
def empty_rail_generator():
def empty_rail_generator() -> RailGenerator:
""" """
Returns a generator which returns an empty rail mail with no agents. Returns a generator which returns an empty rail mail with no agents.
Primarily used by the editor Primarily used by the editor
...@@ -24,7 +25,7 @@ def empty_rail_generator(): ...@@ -24,7 +25,7 @@ def empty_rail_generator():
rail_array = grid_map.grid rail_array = grid_map.grid
rail_array.fill(0) rail_array.fill(0)
return grid_map, [], [], [], [] return [grid_map, None]
return generator return generator
...@@ -33,8 +34,7 @@ def complex_rail_generator(nr_start_goal=1, ...@@ -33,8 +34,7 @@ def complex_rail_generator(nr_start_goal=1,
nr_extra=100, nr_extra=100,
min_dist=20, min_dist=20,
max_dist=99999, max_dist=99999,
seed=0, seed=0) -> RailGenerator:
speed_initializer: Callable[[int], List[float]] = None):
""" """
Parameters Parameters
------- -------
...@@ -42,8 +42,6 @@ def complex_rail_generator(nr_start_goal=1, ...@@ -42,8 +42,6 @@ def complex_rail_generator(nr_start_goal=1,
The width (number of cells) of the grid to generate. The width (number of cells) of the grid to generate.
height : int height : int
The height (number of cells) of the grid to generate. The height (number of cells) of the grid to generate.
speed_initializer : Callable[[int], List[float]]
Function that returns a list of speeds for the numer of agents given as argument.
Returns Returns
------- -------
...@@ -56,8 +54,7 @@ def complex_rail_generator(nr_start_goal=1, ...@@ -56,8 +54,7 @@ def complex_rail_generator(nr_start_goal=1,
if num_agents > nr_start_goal: if num_agents > nr_start_goal:
num_agents = nr_start_goal num_agents = nr_start_goal
print("complex_rail_generator: num_agents > nr_start_goal, changing num_agents") print("complex_rail_generator: num_agents > nr_start_goal, changing num_agents")
rail_trans = RailEnvTransitions() grid_map = GridTransitionMap(width=width, height=height, transitions=RailEnvTransitions())
grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans)
rail_array = grid_map.grid rail_array = grid_map.grid
rail_array.fill(0) rail_array.fill(0)
...@@ -81,6 +78,7 @@ def complex_rail_generator(nr_start_goal=1, ...@@ -81,6 +78,7 @@ def complex_rail_generator(nr_start_goal=1,
# - return transition map + list of [start_pos, start_dir, goal_pos] points # - return transition map + list of [start_pos, start_dir, goal_pos] points
# #
rail_trans = grid_map.transitions
start_goal = [] start_goal = []
start_dir = [] start_dir = []
nr_created = 0 nr_created = 0
...@@ -150,15 +148,10 @@ def complex_rail_generator(nr_start_goal=1, ...@@ -150,15 +148,10 @@ def complex_rail_generator(nr_start_goal=1,
if len(new_path) >= 2: if len(new_path) >= 2:
nr_created += 1 nr_created += 1
agents_position = [sg[0] for sg in start_goal[:num_agents]] return grid_map, {'agents_hints': {
agents_target = [sg[1] for sg in start_goal[:num_agents]] 'start_goal': start_goal,
agents_direction = start_dir[:num_agents] 'start_dir': start_dir
}}
if speed_initializer:
speeds = speed_initializer(num_agents)
else:
speeds = [1.0] * len(agents_position)
return grid_map, agents_position, agents_direction, agents_target, speeds
return generator return generator
...@@ -202,22 +195,18 @@ def rail_from_manual_specifications_generator(rail_spec): ...@@ -202,22 +195,18 @@ def rail_from_manual_specifications_generator(rail_spec):
effective_transition_cell = rail_env_transitions.rotate_transition(basic_type_of_cell_, rotation_cell_) effective_transition_cell = rail_env_transitions.rotate_transition(basic_type_of_cell_, rotation_cell_)
rail.set_transitions((r, c), effective_transition_cell) rail.set_transitions((r, c), effective_transition_cell)
agents_position, agents_direction, agents_target = get_rnd_agents_pos_tgt_dir_on_rail( return [rail, None]
rail,
num_agents)
return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
return generator return generator
def rail_from_file(filename): def rail_from_file(filename) -> RailGenerator:
""" """
Utility to load pickle file Utility to load pickle file
Parameters Parameters
------- -------
input_file : Pickle file generated by env.save() or editor filename : Pickle file generated by env.save() or editor
Returns Returns
------- -------
...@@ -235,26 +224,16 @@ def rail_from_file(filename): ...@@ -235,26 +224,16 @@ def rail_from_file(filename):
grid = np.array(data[b"grid"]) grid = np.array(data[b"grid"])
rail = GridTransitionMap(width=np.shape(grid)[1], height=np.shape(grid)[0], transitions=rail_env_transitions) rail = GridTransitionMap(width=np.shape(grid)[1], height=np.shape(grid)[0], transitions=rail_env_transitions)
rail.grid = grid rail.grid = grid
# agents are always reset as not moving
agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data[b"agents_static"]]
# setup with loaded data
agents_position = [a.position for a in agents_static]
agents_direction = [a.direction for a in agents_static]
agents_target = [a.target for a in agents_static]
if b"distance_maps" in data.keys(): if b"distance_maps" in data.keys():
distance_maps = data[b"distance_maps"] distance_maps = data[b"distance_maps"]
if len(distance_maps) > 0: if len(distance_maps) > 0:
return rail, agents_position, agents_direction, agents_target, [1.0] * len( return rail, {'distance_maps': distance_maps}
agents_position), distance_maps return [rail, None]
else:
return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
else:
return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
return generator return generator
def rail_from_grid_transition_map(rail_map): def rail_from_grid_transition_map(rail_map) -> RailGenerator:
""" """
Utility to convert a rail given by a GridTransitionMap map with the correct Utility to convert a rail given by a GridTransitionMap map with the correct
16-bit transitions specifications. 16-bit transitions specifications.
...@@ -271,16 +250,12 @@ def rail_from_grid_transition_map(rail_map): ...@@ -271,16 +250,12 @@ def rail_from_grid_transition_map(rail_map):
""" """
def generator(width, height, num_agents, num_resets=0): def generator(width, height, num_agents, num_resets=0):
agents_position, agents_direction, agents_target = get_rnd_agents_pos_tgt_dir_on_rail( return [rail_map, None]
rail_map,
num_agents)
return rail_map, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
return generator return generator
def random_rail_generator(cell_type_relative_proportion=[1.0] * 11): def random_rail_generator(cell_type_relative_proportion=[1.0] * 11) -> RailGenerator:
""" """
Dummy random level generator: Dummy random level generator:
- fill in cells at random in [width-2, height-2] - fill in cells at random in [width-2, height-2]
...@@ -544,31 +519,6 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11): ...@@ -544,31 +519,6 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11):
return_rail = GridTransitionMap(width=width, height=height, transitions=t_utils) return_rail = GridTransitionMap(width=width, height=height, transitions=t_utils)
return_rail.grid = tmp_rail return_rail.grid = tmp_rail
agents_position, agents_direction, agents_target = get_rnd_agents_pos_tgt_dir_on_rail( return [return_rail, None]
return_rail,
num_agents)
return return_rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
return generator return generator
def speed_initialization_helper(nb_agents: int, speed_ratio_map: Mapping[float, float]) -> List[float]:
"""
Parameters
-------
nb_agents : int
The number of agents to generate a speed for
speed_ratio_map : Mapping[float,float]
A map of speeds mappint to their ratio of appearance. The ratios must sum up to 1.
Returns
-------
List[float]
A list of size nb_agents of speeds with the corresponding probabilistic ratios.
"""
nb_classes = len(speed_ratio_map.keys())
speed_ratio_map_as_list: List[Tuple[float, float]] = list(speed_ratio_map.items())
speed_ratios = list(map(lambda t: t[1], speed_ratio_map_as_list))
speeds = list(map(lambda t: t[0], speed_ratio_map_as_list))
return list(map(lambda index: speeds[index], np.random.choice(nb_classes, nb_agents, p=speed_ratios)))
...@@ -5,10 +5,8 @@ Generator functions are functions that take width, height and num_resets as argu ...@@ -5,10 +5,8 @@ Generator functions are functions that take width, height and num_resets as argu
a GridTransitionMap object. a GridTransitionMap object.
""" """
import numpy as np
from flatland.core.grid.grid4_astar import a_star from flatland.core.grid.grid4_astar import a_star
from flatland.core.grid.grid4_utils import get_direction, mirror, get_new_position from flatland.core.grid.grid4_utils import get_direction, mirror
def connect_rail(rail_trans, rail_array, start, end): def connect_rail(rail_trans, rail_array, start, end):
...@@ -55,81 +53,3 @@ def connect_rail(rail_trans, rail_array, start, end): ...@@ -55,81 +53,3 @@ def connect_rail(rail_trans, rail_array, start, end):
current_dir = new_dir current_dir = new_dir
return path return path
def get_rnd_agents_pos_tgt_dir_on_rail(rail, num_agents):
"""
Given a `rail' GridTransitionMap, return a random placement of agents (initial position, direction and target).
TODO: add extensive documentation, as users may need this function to simplify their custom level generators.
"""
def _path_exists(rail, start, direction, end):
# BFS - Check if a path exists between the 2 nodes
visited = set()
stack = [(start, direction)]
while stack:
node = stack.pop()
if node[0][0] == end[0] and node[0][1] == end[1]:
return 1
if node not in visited:
visited.add(node)
moves = rail.get_transitions(node[0][0], node[0][1], node[1])
for move_index in range(4):
if moves[move_index]:
stack.append((get_new_position(node[0], move_index),
move_index))
# If cell is a dead-end, append previous node with reversed
# orientation!
nbits = 0
tmp = rail.get_full_transitions(node[0][0], node[0][1])
while tmp > 0:
nbits += (tmp & 1)
tmp = tmp >> 1
if nbits == 1:
stack.append((node[0], (node[1] + 2) % 4))
return 0
valid_positions = []
for r in range(rail.height):
for c in range(rail.width):
if rail.get_full_transitions(r, c) > 0:
valid_positions.append((r, c))
re_generate = True
while re_generate:
agents_position = [
valid_positions[i] for i in
np.random.choice(len(valid_positions), num_agents)]
agents_target = [
valid_positions[i] for i in
np.random.choice(len(valid_positions), num_agents)]
# agents_direction must be a direction for which a solution is
# guaranteed.
agents_direction = [0] * num_agents
re_generate = False
for i in range(num_agents):
valid_movements = []
for direction in range(4):
position = agents_position[i]
moves = rail.get_transitions(position[0], position[1], direction)
for move_index in range(4):
if moves[move_index]:
valid_movements.append((direction, move_index))
valid_starting_directions = []
for m in valid_movements:
new_position = get_new_position(agents_position[i], m[1])
if m[0] not in valid_starting_directions and _path_exists(rail, new_position, m[0], agents_target[i]):
valid_starting_directions.append(m[0])
if len(valid_starting_directions) == 0:
re_generate = True
else:
agents_direction[i] = valid_starting_directions[np.random.choice(len(valid_starting_directions), 1)[0]]
return agents_position, agents_direction, agents_target
...@@ -11,8 +11,9 @@ import numpy as np ...@@ -11,8 +11,9 @@ import numpy as np
from flatland.core.env import Environment from flatland.core.env import Environment
from flatland.core.grid.grid4_utils import get_new_position from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.agent_generators import get_rnd_agents_pos_tgt_dir_on_rail, AgentGenerator
from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent
from flatland.envs.generators import random_rail_generator from flatland.envs.generators import random_rail_generator, RailGenerator
from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.observations import TreeObsForRailEnv
m.patch() m.patch()
...@@ -91,7 +92,8 @@ class RailEnv(Environment): ...@@ -91,7 +92,8 @@ class RailEnv(Environment):
def __init__(self, def __init__(self,
width, width,
height, height,
rail_generator=random_rail_generator(), rail_generator: RailGenerator = random_rail_generator(),
agent_generator: AgentGenerator = get_rnd_agents_pos_tgt_dir_on_rail(),
number_of_agents=1, number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2), obs_builder_object=TreeObsForRailEnv(max_depth=2),
max_episode_steps=None, max_episode_steps=None,
...@@ -107,13 +109,12 @@ class RailEnv(Environment): ...@@ -107,13 +109,12 @@ class RailEnv(Environment):
height and agents handles of a rail environment, along with the number of times height and agents handles of a rail environment, along with the number of times
the env has been reset, and returns a GridTransitionMap object and a list of the env has been reset, and returns a GridTransitionMap object and a list of
starting positions, targets, and initial orientations for agent handle. starting positions, targets, and initial orientations for agent handle.
Implemented functions are: The rail_generator can pass a distance map in the hints or information for specific agent_generators.
random_rail_generator : generate a random rail of given size Implementations can be found in flatland/envs/generators.py
rail_from_grid_transition_map(rail_map) : generate a rail from agent_generator : function
a GridTransitionMap object The agent_generator function is a function that takes the grid, the number of agents and optional hints
rail_from_manual_sp ecifications_generator(rail_spec) : generate a rail from and returns a list of starting positions, targets, initial orientations and speed for all agent handles.
a rail specifications array Implementations can be found in flatland/envs/agent_generators.py
TODO: generate_rail_from_saved_list or from list of ndarray bitmaps ---
width : int width : int
The width of the rail map. Potentially in the future, The width of the rail map. Potentially in the future,
a range of widths to sample from. a range of widths to sample from.
...@@ -131,7 +132,8 @@ class RailEnv(Environment): ...@@ -131,7 +132,8 @@ class RailEnv(Environment):
file_name: you can load a pickle file. file_name: you can load a pickle file.
""" """
self.rail_generator = rail_generator self.rail_generator: RailGenerator = rail_generator
self.agent_generator: AgentGenerator = agent_generator
self.rail = None self.rail = None
self.width = width self.width = width
self.height = height self.height = height
...@@ -213,18 +215,21 @@ class RailEnv(Environment): ...@@ -213,18 +215,21 @@ class RailEnv(Environment):
if replace_agents then regenerate the agents static. if replace_agents then regenerate the agents static.
Relies on the rail_generator returning agent_static lists (pos, dir, target) Relies on the rail_generator returning agent_static lists (pos, dir, target)
""" """
tRailAgents = self.rail_generator(self.width, self.height, self.get_num_agents(), self.num_resets) rail, optionals = self.rail_generator(self.width, self.height, self.get_num_agents(), self.num_resets)
# Check if generator provided a distance map TODO: Make this check safer! if optionals and 'distance_maps' in optionals:
if len(tRailAgents) > 5: self.obs_builder.distance_map = optionals['distance_maps']
self.obs_builder.distance_map = tRailAgents[-1]
if regen_rail or self.rail is None: if regen_rail or self.rail is None:
self.rail = tRailAgents[0] self.rail = rail
self.height, self.width = self.rail.grid.shape self.height, self.width = self.rail.grid.shape
if replace_agents: if replace_agents:
self.agents_static = EnvAgentStatic.from_lists(*tRailAgents[1:5]) agents_hints = None
if optionals and 'agents_hints' in optionals:
agents_hints = optionals['agents_hints']
self.agents_static = EnvAgentStatic.from_lists(
*self.agent_generator(self.rail, self.get_num_agents(), hints=agents_hints))
self.restart_agents() self.restart_agents()
......
...@@ -2,6 +2,7 @@ import numpy as np ...@@ -2,6 +2,7 @@ import numpy as np
from flatland.core.grid.grid4 import Grid4Transitions from flatland.core.grid.grid4 import Grid4Transitions
from flatland.core.transition_map import GridTransitionMap from flatland.core.transition_map import GridTransitionMap
from flatland.envs.agent_generators import get_rnd_agents_pos_tgt_dir_on_rail
from flatland.envs.generators import rail_from_grid_transition_map from flatland.envs.generators import rail_from_grid_transition_map
from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv
...@@ -34,6 +35,7 @@ def test_walker(): ...@@ -34,6 +35,7 @@ def test_walker():
env = RailEnv(width=rail_map.shape[1], env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0], height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail), rail_generator=rail_from_grid_transition_map(rail),
agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(),
number_of_agents=1, number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, obs_builder_object=TreeObsForRailEnv(max_depth=2,
predictor=ShortestPathPredictorForRailEnv(max_depth=10)), predictor=ShortestPathPredictorForRailEnv(max_depth=10)),
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
import numpy as np import numpy as np
from flatland.core.grid.grid4 import Grid4TransitionsEnum from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs.agent_generators import get_rnd_agents_pos_tgt_dir_on_rail
from flatland.envs.agent_utils import EnvAgent from flatland.envs.agent_utils import EnvAgent
from flatland.envs.generators import rail_from_grid_transition_map from flatland.envs.generators import rail_from_grid_transition_map
from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv
...@@ -21,6 +22,7 @@ def test_global_obs(): ...@@ -21,6 +22,7 @@ def test_global_obs():
env = RailEnv(width=rail_map.shape[1], env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0], height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail), rail_generator=rail_from_grid_transition_map(rail),
agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(),
number_of_agents=1, number_of_agents=1,
obs_builder_object=GlobalObsForRailEnv()) obs_builder_object=GlobalObsForRailEnv())
...@@ -90,6 +92,7 @@ def test_reward_function_conflict(rendering=False): ...@@ -90,6 +92,7 @@ def test_reward_function_conflict(rendering=False):
env = RailEnv(width=rail_map.shape[1], env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0], height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail), rail_generator=rail_from_grid_transition_map(rail),
agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(),
number_of_agents=2, number_of_agents=2,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
) )
...@@ -168,6 +171,7 @@ def test_reward_function_waiting(rendering=False): ...@@ -168,6 +171,7 @@ def test_reward_function_waiting(rendering=False):
env = RailEnv(width=rail_map.shape[1], env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0], height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail), rail_generator=rail_from_grid_transition_map(rail),
agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(),
number_of_agents=2, number_of_agents=2,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
) )
......
...@@ -5,6 +5,7 @@ import pprint ...@@ -5,6 +5,7 @@ import pprint
import numpy as np import numpy as np
from flatland.core.grid.grid4 import Grid4TransitionsEnum from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs.agent_generators import get_rnd_agents_pos_tgt_dir_on_rail
from flatland.envs.generators import rail_from_grid_transition_map from flatland.envs.generators import rail_from_grid_transition_map
from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv
...@@ -21,6 +22,7 @@ def test_dummy_predictor(rendering=False): ...@@ -21,6 +22,7 @@ def test_dummy_predictor(rendering=False):
env = RailEnv(width=rail_map.shape[1], env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0], height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail), rail_generator=rail_from_grid_transition_map(rail),
agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(),
number_of_agents=1, number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv(max_depth=10)), obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv(max_depth=10)),
) )
...@@ -111,6 +113,7 @@ def test_shortest_path_predictor(rendering=False): ...@@ -111,6 +113,7 @@ def test_shortest_path_predictor(rendering=False):
env = RailEnv(width=rail_map.shape[1], env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0], height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail), rail_generator=rail_from_grid_transition_map(rail),
agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(),
number_of_agents=1, number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
) )
...@@ -230,6 +233,7 @@ def test_shortest_path_predictor_conflicts(rendering=False): ...@@ -230,6 +233,7 @@ def test_shortest_path_predictor_conflicts(rendering=False):
env = RailEnv(width=rail_map.shape[1], env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0], height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail), rail_generator=rail_from_grid_transition_map(rail),
agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(),
number_of_agents=2, number_of_agents=2,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
) )
......
...@@ -5,6 +5,7 @@ import numpy as np ...@@ -5,6 +5,7 @@ import numpy as np
from flatland.core.grid.grid4 import Grid4Transitions from flatland.core.grid.grid4 import Grid4Transitions
from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.core.grid.rail_env_grid import RailEnvTransitions
from flatland.core.transition_map import GridTransitionMap from flatland.core.transition_map import GridTransitionMap
from flatland.envs.agent_generators import get_rnd_agents_pos_tgt_dir_on_rail, complex_rail_generator_agents_placer
from flatland.envs.agent_utils import EnvAgent from flatland.envs.agent_utils import EnvAgent
from flatland.envs.agent_utils import EnvAgentStatic from flatland.envs.agent_utils import EnvAgentStatic
from flatland.envs.generators import complex_rail_generator from flatland.envs.generators import complex_rail_generator
...@@ -27,6 +28,7 @@ def test_load_env(): ...@@ -27,6 +28,7 @@ def test_load_env():
def test_save_load(): def test_save_load():
env = RailEnv(width=10, height=10, env = RailEnv(width=10, height=10,
rail_generator=complex_rail_generator(nr_start_goal=2, nr_extra=5, min_dist=6, seed=0), rail_generator=complex_rail_generator(nr_start_goal=2, nr_extra=5, min_dist=6, seed=0),
agent_generator=complex_rail_generator_agents_placer(),
number_of_agents=2) number_of_agents=2)
env.reset() env.reset()
agent_1_pos = env.agents_static[0].position agent_1_pos = env.agents_static[0].position
...@@ -86,6 +88,7 @@ def test_rail_environment_single_agent(): ...@@ -86,6 +88,7 @@ def test_rail_environment_single_agent():
rail_env = RailEnv(width=3, rail_env = RailEnv(width=3,
height=3, height=3,
rail_generator=rail_from_grid_transition_map(rail), rail_generator=rail_from_grid_transition_map(rail),
agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(),
number_of_agents=1, number_of_agents=1,
obs_builder_object=GlobalObsForRailEnv()) obs_builder_object=GlobalObsForRailEnv())
...@@ -165,6 +168,7 @@ def test_dead_end(): ...@@ -165,6 +168,7 @@ def test_dead_end():
rail_env = RailEnv(width=rail_map.shape[1], rail_env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0], height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail), rail_generator=rail_from_grid_transition_map(rail),
agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(),
number_of_agents=1, number_of_agents=1,
obs_builder_object=GlobalObsForRailEnv()) obs_builder_object=GlobalObsForRailEnv())
...@@ -209,6 +213,7 @@ def test_dead_end(): ...@@ -209,6 +213,7 @@ def test_dead_end():
rail_env = RailEnv(width=rail_map.shape[1], rail_env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0], height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail), rail_generator=rail_from_grid_transition_map(rail),
agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(),
number_of_agents=1, number_of_agents=1,
obs_builder_object=GlobalObsForRailEnv()) obs_builder_object=GlobalObsForRailEnv())
......
import numpy as np import numpy as np
from flatland.envs.agent_generators import complex_rail_generator_agents_placer
from flatland.envs.generators import complex_rail_generator from flatland.envs.generators import complex_rail_generator
from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.rail_env import RailEnv from flatland.envs.rail_env import RailEnv
...@@ -62,6 +63,7 @@ def test_malfunction_process(): ...@@ -62,6 +63,7 @@ def test_malfunction_process():
height=20, height=20,
rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999, rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999,
seed=0), seed=0),
agent_generator=complex_rail_generator_agents_placer(),
number_of_agents=2, number_of_agents=2,
obs_builder_object=SingleAgentNavigationObs(), obs_builder_object=SingleAgentNavigationObs(),
stochastic_data=stochastic_data) stochastic_data=stochastic_data)
......
import numpy as np import numpy as np
from flatland.envs.agent_generators import complex_rail_generator_agents_placer
from flatland.envs.generators import complex_rail_generator from flatland.envs.generators import complex_rail_generator
from flatland.envs.rail_env import RailEnv from flatland.envs.rail_env import RailEnv
np.random.seed(1) np.random.seed(1)
# Use the complex_rail_generator to generate feasible network configurations with corresponding tasks # Use the complex_rail_generator to generate feasible network configurations with corresponding tasks
# Training on simple small tasks is the best way to get familiar with the environment # Training on simple small tasks is the best way to get familiar with the environment
# #
...@@ -46,6 +48,7 @@ def test_multi_speed_init(): ...@@ -46,6 +48,7 @@ def test_multi_speed_init():
height=50, height=50,
rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999,
seed=0), seed=0),
agent_generator=complex_rail_generator_agents_placer(),
number_of_agents=5) number_of_agents=5)
# Initialize the agent with the parameters corresponding to the environment and observation_builder # Initialize the agent with the parameters corresponding to the environment and observation_builder
agent = RandomAgent(218, 4) agent = RandomAgent(218, 4)
......
"""Test speed initialization by a map of speeds and their corresponding ratios.""" """Test speed initialization by a map of speeds and their corresponding ratios."""
import numpy as np import numpy as np
from flatland.envs.generators import speed_initialization_helper, complex_rail_generator from flatland.envs.agent_generators import speed_initialization_helper, complex_rail_generator_agents_placer
from flatland.envs.generators import complex_rail_generator
from flatland.envs.rail_env import RailEnv from flatland.envs.rail_env import RailEnv
...@@ -17,13 +18,11 @@ def test_speed_initialization_helper(): ...@@ -17,13 +18,11 @@ def test_speed_initialization_helper():
def test_rail_env_speed_intializer(): def test_rail_env_speed_intializer():
speed_ratio_map = {1: 0.3, 2: 0.4, 3: 0.1, 5: 0.2} speed_ratio_map = {1: 0.3, 2: 0.4, 3: 0.1, 5: 0.2}
def my_speed_initializer(nb_agents):
return speed_initialization_helper(nb_agents, speed_ratio_map)
env = RailEnv(width=50, env = RailEnv(width=50,
height=50, height=50,
rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999,
seed=0, speed_initializer=my_speed_initializer), seed=0),
agent_generator=complex_rail_generator_agents_placer(),
number_of_agents=10) number_of_agents=10)
env.reset() env.reset()
actual_speeds = list(map(lambda agent: agent.speed_data['speed'], env.agents)) actual_speeds = list(map(lambda agent: agent.speed_data['speed'], env.agents))
......
...@@ -3,6 +3,8 @@ ...@@ -3,6 +3,8 @@
import numpy as np import numpy as np
from flatland.envs.agent_generators import get_rnd_agents_pos_tgt_dir_on_rail, complex_rail_generator_agents_placer, \
agents_from_file
from flatland.envs.generators import rail_from_grid_transition_map, rail_from_file, complex_rail_generator, \ from flatland.envs.generators import rail_from_grid_transition_map, rail_from_file, complex_rail_generator, \
random_rail_generator, empty_rail_generator random_rail_generator, empty_rail_generator
from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv
...@@ -58,7 +60,8 @@ def test_complex_rail_generator(): ...@@ -58,7 +60,8 @@ def test_complex_rail_generator():
env = RailEnv(width=x_dim, env = RailEnv(width=x_dim,
height=y_dim, height=y_dim,
number_of_agents=n_agents, number_of_agents=n_agents,
rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist) rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist),
agent_generator=complex_rail_generator_agents_placer()
) )
assert env.get_num_agents() == 2 assert env.get_num_agents() == 2
assert env.rail.grid.shape == (y_dim, x_dim) assert env.rail.grid.shape == (y_dim, x_dim)
...@@ -69,7 +72,8 @@ def test_complex_rail_generator(): ...@@ -69,7 +72,8 @@ def test_complex_rail_generator():
env = RailEnv(width=x_dim, env = RailEnv(width=x_dim,
height=y_dim, height=y_dim,
number_of_agents=n_agents, number_of_agents=n_agents,
rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist) rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist),
agent_generator=complex_rail_generator_agents_placer()
) )
assert env.get_num_agents() == 0 assert env.get_num_agents() == 0
assert env.rail.grid.shape == (y_dim, x_dim) assert env.rail.grid.shape == (y_dim, x_dim)
...@@ -82,7 +86,8 @@ def test_complex_rail_generator(): ...@@ -82,7 +86,8 @@ def test_complex_rail_generator():
env = RailEnv(width=x_dim, env = RailEnv(width=x_dim,
height=y_dim, height=y_dim,
number_of_agents=n_agents, number_of_agents=n_agents,
rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist) rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist),
agent_generator=complex_rail_generator_agents_placer()
) )
assert env.get_num_agents() == n_agents assert env.get_num_agents() == n_agents
assert env.rail.grid.shape == (y_dim, x_dim) assert env.rail.grid.shape == (y_dim, x_dim)
...@@ -94,6 +99,7 @@ def test_rail_from_grid_transition_map(): ...@@ -94,6 +99,7 @@ def test_rail_from_grid_transition_map():
env = RailEnv(width=rail_map.shape[1], env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0], height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail), rail_generator=rail_from_grid_transition_map(rail),
agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(),
number_of_agents=n_agents number_of_agents=n_agents
) )
nr_rail_elements = np.count_nonzero(env.rail.grid) nr_rail_elements = np.count_nonzero(env.rail.grid)
...@@ -118,6 +124,7 @@ def tests_rail_from_file(): ...@@ -118,6 +124,7 @@ def tests_rail_from_file():
env = RailEnv(width=rail_map.shape[1], env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0], height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail), rail_generator=rail_from_grid_transition_map(rail),
agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(),
number_of_agents=3, number_of_agents=3,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
) )
...@@ -130,6 +137,7 @@ def tests_rail_from_file(): ...@@ -130,6 +137,7 @@ def tests_rail_from_file():
env = RailEnv(width=1, env = RailEnv(width=1,
height=1, height=1,
rail_generator=rail_from_file(file_name), rail_generator=rail_from_file(file_name),
agent_generator=agents_from_file(file_name),
number_of_agents=1, number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
) )
...@@ -151,6 +159,7 @@ def tests_rail_from_file(): ...@@ -151,6 +159,7 @@ def tests_rail_from_file():
env2 = RailEnv(width=rail_map.shape[1], env2 = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0], height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail), rail_generator=rail_from_grid_transition_map(rail),
agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(),
number_of_agents=3, number_of_agents=3,
obs_builder_object=GlobalObsForRailEnv(), obs_builder_object=GlobalObsForRailEnv(),
) )
...@@ -164,6 +173,7 @@ def tests_rail_from_file(): ...@@ -164,6 +173,7 @@ def tests_rail_from_file():
env2 = RailEnv(width=1, env2 = RailEnv(width=1,
height=1, height=1,
rail_generator=rail_from_file(file_name_2), rail_generator=rail_from_file(file_name_2),
agent_generator=agents_from_file(file_name_2),
number_of_agents=1, number_of_agents=1,
obs_builder_object=GlobalObsForRailEnv(), obs_builder_object=GlobalObsForRailEnv(),
) )
...@@ -180,6 +190,7 @@ def tests_rail_from_file(): ...@@ -180,6 +190,7 @@ def tests_rail_from_file():
env3 = RailEnv(width=1, env3 = RailEnv(width=1,
height=1, height=1,
rail_generator=rail_from_file(file_name), rail_generator=rail_from_file(file_name),
agent_generator=agents_from_file(file_name),
number_of_agents=1, number_of_agents=1,
obs_builder_object=GlobalObsForRailEnv(), obs_builder_object=GlobalObsForRailEnv(),
) )
...@@ -197,6 +208,7 @@ def tests_rail_from_file(): ...@@ -197,6 +208,7 @@ def tests_rail_from_file():
env4 = RailEnv(width=1, env4 = RailEnv(width=1,
height=1, height=1,
rail_generator=rail_from_file(file_name_2), rail_generator=rail_from_file(file_name_2),
agent_generator=agents_from_file(file_name_2),
number_of_agents=1, number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2), obs_builder_object=TreeObsForRailEnv(max_depth=2),
) )
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment