From e5606f1edab6f6ceb459ffcd78194a664e9c74ce Mon Sep 17 00:00:00 2001 From: u214892 <u214892@sbb.ch> Date: Tue, 27 Aug 2019 12:37:52 +0200 Subject: [PATCH] #141 different agent classes --- examples/custom_railmap_example.py | 21 ++++++++++++++++----- flatland/envs/agent_generators.py | 4 ++-- flatland/envs/generators.py | 15 ++++++++------- 3 files changed, 26 insertions(+), 14 deletions(-) diff --git a/examples/custom_railmap_example.py b/examples/custom_railmap_example.py index 515d6c1b..f6bd2bda 100644 --- a/examples/custom_railmap_example.py +++ b/examples/custom_railmap_example.py @@ -1,9 +1,12 @@ import random +from typing import Any import numpy as np from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.core.transition_map import GridTransitionMap +from flatland.envs.agent_generators import AgentGenerator, AgentGeneratorProduct +from flatland.envs.generators import RailGenerator, RailGeneratorProduct from flatland.envs.rail_env import RailEnv from flatland.utils.rendertools import RenderTool @@ -11,20 +14,28 @@ random.seed(100) np.random.seed(100) -def custom_rail_generator(): - def generator(width, height, num_agents=0, num_resets=0): +def custom_rail_generator() -> RailGenerator: + def generator(width: int, height: int, num_agents: int = 0, num_resets: int = 0) -> RailGeneratorProduct: rail_trans = RailEnvTransitions() grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans) rail_array = grid_map.grid rail_array.fill(0) new_tran = rail_trans.set_transition(1, 1, 1, 1) print(new_tran) + rail_array[0, 0] = new_tran + rail_array[0, 1] = new_tran + return grid_map, None + + return generator + + +def custom_agent_generator() -> AgentGenerator: + def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None) -> AgentGeneratorProduct: agents_positions = [] agents_direction = [] agents_target = [] - rail_array[0, 0] = new_tran - rail_array[0, 1] = new_tran - return grid_map, agents_positions, agents_direction, agents_target + speeds = [] + return agents_positions, agents_direction, agents_target, speeds return generator diff --git a/flatland/envs/agent_generators.py b/flatland/envs/agent_generators.py index 1f769b7d..c03511bc 100644 --- a/flatland/envs/agent_generators.py +++ b/flatland/envs/agent_generators.py @@ -73,7 +73,7 @@ def get_rnd_agents_pos_tgt_dir_on_rail(speed_ratio_map: Mapping[float, float] = initial positions, directions, targets speeds """ - def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None): + def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None) -> AgentGeneratorProduct: def _path_exists(rail, start, direction, end): # BFS - Check if a path exists between the 2 nodes @@ -165,7 +165,7 @@ def agents_from_file(filename) -> AgentGenerator: initial positions, directions, targets speeds """ - def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None): + def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None) -> AgentGeneratorProduct: with open(filename, "rb") as file_in: load_data = file_in.read() data = msgpack.unpackb(load_data, use_list=False) diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py index 380bf37f..5e97f158 100644 --- a/flatland/envs/generators.py +++ b/flatland/envs/generators.py @@ -10,7 +10,8 @@ from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.core.transition_map import GridTransitionMap from flatland.envs.grid4_generators_utils import connect_rail -RailGenerator = Callable[[int, int, int, int], Tuple[GridTransitionMap, Optional[Any]]] +RailGeneratorProduct = Tuple[GridTransitionMap, Optional[Any]] +RailGenerator = Callable[[int, int, int, int], RailGeneratorProduct] def empty_rail_generator() -> RailGenerator: @@ -19,13 +20,13 @@ def empty_rail_generator() -> RailGenerator: Primarily used by the editor """ - def generator(width, height, num_agents=0, num_resets=0): + def generator(width: int, height: int, num_agents: int = 0, num_resets: int = 0) -> RailGeneratorProduct: rail_trans = RailEnvTransitions() grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans) rail_array = grid_map.grid rail_array.fill(0) - return [grid_map, None] + return grid_map, None return generator @@ -249,8 +250,8 @@ def rail_from_grid_transition_map(rail_map) -> RailGenerator: Generator function that always returns the given `rail_map' object. """ - def generator(width, height, num_agents, num_resets=0): - return [rail_map, None] + def generator(width: int, height: int, num_agents: int, num_resets: int = 0) -> RailGeneratorProduct: + return rail_map, None return generator @@ -287,7 +288,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11) -> RailGener The matrix with the correct 16-bit bitmaps for each cell. """ - def generator(width, height, num_agents, num_resets=0): + def generator(width: int, height: int, num_agents: int, num_resets: int = 0) -> RailGeneratorProduct: t_utils = RailEnvTransitions() transition_probability = cell_type_relative_proportion @@ -519,6 +520,6 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11) -> RailGener return_rail = GridTransitionMap(width=width, height=height, transitions=t_utils) return_rail.grid = tmp_rail - return [return_rail, None] + return return_rail, None return generator -- GitLab