Commit 52bfe623 authored by Erik Nygren's avatar Erik Nygren 🚅
Browse files

code cleanup and test update

parent a6153ee1
Pipeline #2536 passed with stages
in 43 minutes and 38 seconds
......@@ -31,7 +31,8 @@ def connect_rail_in_grid_map(grid_map: GridTransitionMap, start: IntVector2D, en
:param end: end position of rail
:param flip_start_node_trans: make valid start position by adding dead-end, empty start if False
:param flip_end_node_trans: make valid end position by adding dead-end, empty end if False
:param respect_transition_validity: Only draw rail maps if legal rail elements can be use, False, draw line without respecting rail transitions.
:param respect_transition_validity: Only draw rail maps if legal rail elements can be use, False, draw line without
respecting rail transitions.
:param a_star_distance_function: Define what distance function a-star should use
:param forbidden_cells: cells to avoid when drawing rail. Rail cannot go through this list of cells
:return: List of cells in the path
......
......@@ -579,7 +579,7 @@ class GlobalObsForRailEnv(ObservationBuilder):
# for r in range(self.env.height):
# for c in range(self.env.width):
# obs_agents_state[(r, c)][4] = 0
obs_agents_state[:,:,4] = 0
obs_agents_state[:, :, 4] = 0
obs_agents_state[agent_virtual_position][0] = agent.direction
obs_targets[agent.target][0] = 1
......
......@@ -269,9 +269,9 @@ class RailEnv(Environment):
self.rail = rail
self.height, self.width = self.rail.grid.shape
# Do a new set_env call on the obs_builder to ensure
# that obs_builder specific instantiations are made according to the
# that obs_builder specific instantiations are made according to the
# specifications of the current environment : like width, height, etc
self.obs_builder.set_env(self)
......@@ -750,7 +750,7 @@ class RailEnv(Environment):
print("[WARNING] Unable to save the distance map for this environment, as none was found !")
else:
with open(filename,"wb") as file_out:
with open(filename, "wb") as file_out:
file_out.write(self.get_full_state_msg())
def load(self, filename):
......
"""Rail generators (infrastructure manager, "Infrastrukturbetreiber")."""
import sys
import warnings
from typing import Callable, Tuple, Optional, Dict, List
import msgpack
import numpy as np
import sys
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.core.grid.grid4_utils import get_direction, mirror, direction_to_point
......@@ -569,14 +569,15 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
# Calculate the max number of cities allowed
# and reduce the number of cities to build to avoid problems
max_feasible_cities = min(max_num_cities, ((height - 2) // (2 * (city_radius + 1))) * ((width - 2) // (2 * (city_radius + 1))))
max_feasible_cities = min(max_num_cities,
((height - 2) // (2 * (city_radius + 1))) * ((width - 2) // (2 * (city_radius + 1))))
if max_feasible_cities < 2:
sys.exit("Cannot fit more than one city in this map, no feasible environment possible! Aborting.")
# Evenly distribute cities
if grid_mode:
city_positions = _generate_evenly_distr_city_positions(max_feasible_cities, city_radius, width,
height)
height)
else:
city_positions = _generate_random_city_positions(max_feasible_cities, city_radius, width, height)
......@@ -586,10 +587,9 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
# Try with evenly distributed cities
if num_cities < 2:
city_positions = _generate_evenly_distr_city_positions(max_feasible_cities, city_radius, width,
height)
height)
num_cities = len(city_positions)
# Set up connection points for all cities
inner_connection_points, outer_connection_points, connection_info, city_orientations, city_cells = \
_generate_city_connection_points(
......@@ -665,7 +665,8 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
city_positions.append((row, col))
return city_positions
def _generate_city_connection_points(city_positions: IntVector2DArray, city_radius: int, vector_field: IntVector2DArray, rails_between_cities: int,
def _generate_city_connection_points(city_positions: IntVector2DArray, city_radius: int,
vector_field: IntVector2DArray, rails_between_cities: int,
rails_in_city: int = 2) -> (List[List[List[IntVector2D]]],
List[List[List[IntVector2D]]],
List[np.ndarray],
......@@ -948,7 +949,8 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
# http://stackoverflow.com/questions/3071415/efficient-method-to-calculate-the-rank-vector-of-a-list-in-python
return sorted(range(len(seq)), key=seq.__getitem__)
def _get_cells_in_city(center: IntVector2D, radius: int, city_orientation: int, vector_field: IntVector2DArray) -> IntVector2DArray:
def _get_cells_in_city(center: IntVector2D, radius: int, city_orientation: int,
vector_field: IntVector2DArray) -> IntVector2DArray:
"""
Parameters
......
......@@ -65,7 +65,6 @@ def complex_schedule_generator(speed_ratio_map: Mapping[float, float] = None, se
def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None, seed: int = 1) -> ScheduleGenerator:
def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None, num_resets: int = 0):
_runtime_seed = seed + num_resets
......@@ -105,7 +104,8 @@ def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None, see
while target[1] % 2 != 1:
target_idx = np.random.choice(np.arange(len(train_stations[target_city])))
target = train_stations[target_city][target_idx]
possible_orientations = [agent_start_targets_cities[city_idx][2], (agent_start_targets_cities[city_idx][2] + 2) % 4 ]
possible_orientations = [agent_start_targets_cities[city_idx][2],
(agent_start_targets_cities[city_idx][2] + 2) % 4]
agent_orientation = np.random.choice(possible_orientations)
if not rail.check_path_exists(start[0], agent_orientation, target[0]):
agent_orientation = (agent_orientation + 2) % 4
......@@ -147,7 +147,7 @@ def random_schedule_generator(speed_ratio_map: Optional[Mapping[float, float]] =
"""
def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None,
num_resets: int = 0) -> ScheduleGeneratorProduct:
num_resets: int = 0) -> ScheduleGeneratorProduct:
_runtime_seed = seed + num_resets
np.random.seed(_runtime_seed)
......@@ -268,4 +268,3 @@ def schedule_from_file(filename, load_from_package=None) -> ScheduleGenerator:
return agents_position, agents_direction, agents_target, agents_speed, agents_malfunction
return generator
......@@ -547,7 +547,7 @@ def tests_random_interference_from_outside():
_, reward, _, _ = env.step(action_dict)
# Append the rewards of the first trial
env_data.append((reward[0],env.agents[0].position))
env_data.append((reward[0], env.agents[0].position))
assert reward[0] == env_data[step][0]
assert env.agents[0].position == env_data[step][1]
# Run the same test as above but with an external random generator running
......@@ -570,7 +570,6 @@ def tests_random_interference_from_outside():
env.agents[0].target = (3, 9)
env.reset(False, False, False)
# Print for test generation
dummy_list = [1, 2, 6, 7, 8, 9, 4, 5, 4]
for step in range(200):
......@@ -580,8 +579,8 @@ def tests_random_interference_from_outside():
action_dict[agent.handle] = RailEnvActions(2)
# Do dummy random number generations
a = random.shuffle(dummy_list)
b = np.random.rand()
random.shuffle(dummy_list)
np.random.rand()
_, reward, _, _ = env.step(action_dict)
assert reward[0] == env_data[step][0]
......
import numpy as np
import random
from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment