Skip to content
Snippets Groups Projects
Commit e5606f1e authored by u214892's avatar u214892
Browse files

#141 different agent classes

parent dece6c16
No related branches found
No related tags found
No related merge requests found
import random
from typing import Any
import numpy as np
from flatland.core.grid.rail_env_grid import RailEnvTransitions
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.agent_generators import AgentGenerator, AgentGeneratorProduct
from flatland.envs.generators import RailGenerator, RailGeneratorProduct
from flatland.envs.rail_env import RailEnv
from flatland.utils.rendertools import RenderTool
......@@ -11,20 +14,28 @@ random.seed(100)
np.random.seed(100)
def custom_rail_generator():
def generator(width, height, num_agents=0, num_resets=0):
def custom_rail_generator() -> RailGenerator:
def generator(width: int, height: int, num_agents: int = 0, num_resets: int = 0) -> RailGeneratorProduct:
rail_trans = RailEnvTransitions()
grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans)
rail_array = grid_map.grid
rail_array.fill(0)
new_tran = rail_trans.set_transition(1, 1, 1, 1)
print(new_tran)
rail_array[0, 0] = new_tran
rail_array[0, 1] = new_tran
return grid_map, None
return generator
def custom_agent_generator() -> AgentGenerator:
def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None) -> AgentGeneratorProduct:
agents_positions = []
agents_direction = []
agents_target = []
rail_array[0, 0] = new_tran
rail_array[0, 1] = new_tran
return grid_map, agents_positions, agents_direction, agents_target
speeds = []
return agents_positions, agents_direction, agents_target, speeds
return generator
......
......@@ -73,7 +73,7 @@ def get_rnd_agents_pos_tgt_dir_on_rail(speed_ratio_map: Mapping[float, float] =
initial positions, directions, targets speeds
"""
def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None):
def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None) -> AgentGeneratorProduct:
def _path_exists(rail, start, direction, end):
# BFS - Check if a path exists between the 2 nodes
......@@ -165,7 +165,7 @@ def agents_from_file(filename) -> AgentGenerator:
initial positions, directions, targets speeds
"""
def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None):
def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None) -> AgentGeneratorProduct:
with open(filename, "rb") as file_in:
load_data = file_in.read()
data = msgpack.unpackb(load_data, use_list=False)
......
......@@ -10,7 +10,8 @@ from flatland.core.grid.rail_env_grid import RailEnvTransitions
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.grid4_generators_utils import connect_rail
RailGenerator = Callable[[int, int, int, int], Tuple[GridTransitionMap, Optional[Any]]]
RailGeneratorProduct = Tuple[GridTransitionMap, Optional[Any]]
RailGenerator = Callable[[int, int, int, int], RailGeneratorProduct]
def empty_rail_generator() -> RailGenerator:
......@@ -19,13 +20,13 @@ def empty_rail_generator() -> RailGenerator:
Primarily used by the editor
"""
def generator(width, height, num_agents=0, num_resets=0):
def generator(width: int, height: int, num_agents: int = 0, num_resets: int = 0) -> RailGeneratorProduct:
rail_trans = RailEnvTransitions()
grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans)
rail_array = grid_map.grid
rail_array.fill(0)
return [grid_map, None]
return grid_map, None
return generator
......@@ -249,8 +250,8 @@ def rail_from_grid_transition_map(rail_map) -> RailGenerator:
Generator function that always returns the given `rail_map' object.
"""
def generator(width, height, num_agents, num_resets=0):
return [rail_map, None]
def generator(width: int, height: int, num_agents: int, num_resets: int = 0) -> RailGeneratorProduct:
return rail_map, None
return generator
......@@ -287,7 +288,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11) -> RailGener
The matrix with the correct 16-bit bitmaps for each cell.
"""
def generator(width, height, num_agents, num_resets=0):
def generator(width: int, height: int, num_agents: int, num_resets: int = 0) -> RailGeneratorProduct:
t_utils = RailEnvTransitions()
transition_probability = cell_type_relative_proportion
......@@ -519,6 +520,6 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11) -> RailGener
return_rail = GridTransitionMap(width=width, height=height, transitions=t_utils)
return_rail.grid = tmp_rail
return [return_rail, None]
return return_rail, None
return generator
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment