Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • flatland/flatland
  • stefan_otte/flatland
  • jiaodaxiaozi/flatland
  • sfwatergit/flatland
  • utozx126/flatland
  • ChenKuanSun/flatland
  • ashivani/flatland
  • minhhoa/flatland
  • pranjal_dhole/flatland
  • darthgera123/flatland
  • rivesunder/flatland
  • thomaslecat/flatland
  • joel_joseph/flatland
  • kchour/flatland
  • alex_zharichenko/flatland
  • yoogottamk/flatland
  • troye_fang/flatland
  • elrichgro/flatland
  • jun_jin/flatland
  • nimishsantosh107/flatland
20 results
Show changes
Showing
with 2033 additions and 1685 deletions
from typing import Tuple
# Adrian Egli / Michel Marti performance fix (the fast methods brings more than 50%)
def fast_isclose(a, b, rtol):
return (a < (b + rtol)) or (a < (b - rtol))
def fast_clip(position: Tuple[int, int], min_value: Tuple[int, int], max_value: Tuple[int, int]) -> bool:
return (
max(min_value[0], min(position[0], max_value[0])),
max(min_value[1], min(position[1], max_value[1]))
)
def fast_argmax(possible_transitions: (int, int, int, int)) -> bool:
if possible_transitions[0] == 1:
return 0
if possible_transitions[1] == 1:
return 1
if possible_transitions[2] == 1:
return 2
return 3
def fast_position_equal(pos_1: (int, int), pos_2: (int, int)) -> bool:
if pos_1 is None and pos_2 is None:
return True
if pos_1 is None or pos_2 is None:
return False
return pos_1[0] == pos_2[0] and pos_1[1] == pos_2[1]
def fast_count_nonzero(possible_transitions: (int, int, int, int)):
return possible_transitions[0] + possible_transitions[1] + possible_transitions[2] + possible_transitions[3]
def fast_delete(lis: list, index) -> list:
new_list = lis.copy()
new_list.pop(index)
return new_list
def fast_where(binary_iterable):
return [index for index, element in enumerate(binary_iterable) if element != 0]
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import rail_from_file
from flatland.envs.schedule_generators import schedule_from_file
def load_flatland_environment_from_file(file_name: str,
load_from_package: str = None,
obs_builder_object: ObservationBuilder = None) -> RailEnv:
"""
Parameters
----------
file_name : str
The pickle file.
load_from_package : str
The python module to import from. Example: 'env_data.tests'
This requires that there are `__init__.py` files in the folder structure we load the file from.
obs_builder_object: ObservationBuilder
The obs builder for the `RailEnv` that is created.
Returns
-------
RailEnv
The environment loaded from the pickle file.
"""
if obs_builder_object is None:
obs_builder_object = TreeObsForRailEnv(
max_depth=2,
predictor=ShortestPathPredictorForRailEnv(max_depth=10))
environment = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name, load_from_package),
schedule_generator=schedule_from_file(file_name, load_from_package), number_of_agents=1,
obs_builder_object=obs_builder_object)
return environment
......@@ -160,6 +160,7 @@ def fix_inner_nodes(grid_map: GridTransitionMap, inner_node_pos: IntVector2D, ra
grid_map.grid[tmp_pos] = transition
return
def align_cell_to_city(city_center, city_orientation, cell):
"""
Alig all cells to face the city center along the city orientation
......@@ -171,4 +172,4 @@ def align_cell_to_city(city_center, city_orientation, cell):
if city_orientation % 2 == 0:
return int(2 * np.clip(cell[0] - city_center[0], 0, 1))
else:
return int(2 * np.clip(city_center[1] - cell[1], 0, 1)) + 1
return int(2 * np.clip(city_center[1] - cell[1], 0, 1)) + 1
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
......@@ -5,11 +5,12 @@ Collection of environment-specific PredictionBuilder.
import numpy as np
from flatland.core.env_prediction_builder import PredictionBuilder
from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.distance_map import DistanceMap
from flatland.envs.rail_env import RailEnvActions
from flatland.envs.rail_env_action import RailEnvActions
from flatland.envs.rail_env_shortest_paths import get_shortest_paths
from flatland.utils.ordered_set import OrderedSet
from flatland.envs.step_utils.states import TrainState
from flatland.envs.step_utils import transition_utils
class DummyPredictorForRailEnv(PredictionBuilder):
......@@ -48,7 +49,7 @@ class DummyPredictorForRailEnv(PredictionBuilder):
prediction_dict = {}
for agent in agents:
if agent.status != RailAgentStatus.ACTIVE:
if not agent.state.is_on_map_state():
# TODO make this generic
continue
action_priorities = [RailEnvActions.MOVE_FORWARD, RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT]
......@@ -64,8 +65,8 @@ class DummyPredictorForRailEnv(PredictionBuilder):
continue
for action in action_priorities:
cell_is_free, new_cell_isValid, new_direction, new_position, transition_isValid = \
self.env._check_action_on_agent(action, agent)
new_cell_isValid, new_direction, new_position, transition_isValid = \
transition_utils.check_action_on_agent(action, self.env.rail, agent.position, agent.direction)
if all([new_cell_isValid, transition_isValid]):
# move and change direction to face the new_direction that was
# performed
......@@ -126,12 +127,11 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
prediction_dict = {}
for agent in agents:
if agent.status == RailAgentStatus.READY_TO_DEPART:
if agent.state.is_off_map_state():
agent_virtual_position = agent.initial_position
elif agent.status == RailAgentStatus.ACTIVE:
elif agent.state.is_on_map_state():
agent_virtual_position = agent.position
elif agent.status == RailAgentStatus.DONE:
elif agent.state == TrainState.DONE:
agent_virtual_position = agent.target
else:
......@@ -142,7 +142,7 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
continue
agent_virtual_direction = agent.direction
agent_speed = agent.speed_data["speed"]
agent_speed = agent.speed_counter.speed
times_per_cell = int(np.reciprocal(agent_speed))
prediction = np.zeros(shape=(self.max_depth + 1, 5))
prediction[0] = [0, *agent_virtual_position, agent_virtual_direction, 0]
......
This diff is collapsed.
from enum import IntEnum
from typing import NamedTuple
from flatland.core.grid.grid4 import Grid4TransitionsEnum
class RailEnvActions(IntEnum):
DO_NOTHING = 0 # implies change of direction in a dead-end!
MOVE_LEFT = 1
MOVE_FORWARD = 2
MOVE_RIGHT = 3
STOP_MOVING = 4
@staticmethod
def to_char(a: int):
return {
0: 'B',
1: 'L',
2: 'F',
3: 'R',
4: 'S',
}[a]
@classmethod
def is_action_valid(cls, action):
return action in cls._value2member_map_
def is_moving_action(self):
return self.value in [self.MOVE_RIGHT, self.MOVE_LEFT, self.MOVE_FORWARD]
RailEnvGridPos = NamedTuple('RailEnvGridPos', [('r', int), ('c', int)])
RailEnvNextAction = NamedTuple('RailEnvNextAction', [('action', RailEnvActions), ('next_position', RailEnvGridPos),
('next_direction', Grid4TransitionsEnum)])
This diff is collapsed.
......@@ -3,12 +3,14 @@ from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import rail_from_file
from flatland.envs.schedule_generators import schedule_from_file
from flatland.envs.line_generators import line_from_file
def load_flatland_environment_from_file(file_name: str,
load_from_package: str = None,
obs_builder_object: ObservationBuilder = None) -> RailEnv:
obs_builder_object: ObservationBuilder = None,
record_steps = False,
) -> RailEnv:
"""
Parameters
----------
......@@ -30,10 +32,10 @@ def load_flatland_environment_from_file(file_name: str,
obs_builder_object = TreeObsForRailEnv(
max_depth=2,
predictor=ShortestPathPredictorForRailEnv(max_depth=10))
environment = RailEnv(width=1, # will be overridden when loading from file
height=1, # will be overridden when loading from file
rail_generator=rail_from_file(file_name, load_from_package),
number_of_agents=1, # will be overridden when loading from file
schedule_generator=schedule_from_file(file_name, load_from_package),
obs_builder_object=obs_builder_object)
environment = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name, load_from_package),
line_generator=line_from_file(file_name, load_from_package),
number_of_agents=1,
obs_builder_object=obs_builder_object,
record_steps=record_steps,
)
return environment
This diff is collapsed.
from typing import NamedTuple, Tuple, List, Dict
# A way point is the entry into a cell defined by
# - the row and column coordinates of the cell entered
# - direction, in which the agent is facing to enter the cell.
# This induces a graph on top of the FLATland cells:
# - four possible way points per cell
# - edges are the possible transitions in the cell.
Waypoint = NamedTuple('Waypoint', [('position', Tuple[int, int]), ('direction', int)])
# A train run is represented by the waypoints traversed and the times of traversal
# The terminology follows https://github.com/crowdAI/train-schedule-optimisation-challenge-starter-kit/blob/master/documentation/output_data_model.md
TrainrunWaypoint = NamedTuple('TrainrunWaypoint', [
('scheduled_at', int),
('waypoint', Waypoint)
])
# A train run is the list of an agent's way points and their scheduled time
Trainrun = List[TrainrunWaypoint]
TrainrunDict = Dict[int, Trainrun]
This diff is collapsed.
from typing import List, NamedTuple
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.core.grid.grid_utils import IntVector2DArray
Schedule = NamedTuple('Schedule', [('agent_positions', IntVector2DArray),
('agent_directions', List[Grid4TransitionsEnum]),
('agent_targets', IntVector2DArray),
('agent_speeds', List[float]),
('agent_malfunction_rates', List[int])])
"""Rail generators (infrastructure manager, "Infrastrukturbetreiber")."""
import sys
import warnings
from typing import Callable, Tuple, Optional, Dict, List
import numpy as np
from numpy.random.mtrand import RandomState
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.core.grid.grid4_utils import get_direction, mirror, direction_to_point
from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
from flatland.core.grid.grid_utils import distance_on_rail, IntVector2DArray, IntVector2D, \
Vec2dOperations
from flatland.core.grid.rail_env_grid import RailEnvTransitions
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.grid4_generators_utils import connect_rail_in_grid_map, connect_straight_line_in_grid_map, \
fix_inner_nodes, align_cell_to_city
from flatland.envs import persistence
from flatland.envs.rail_generators import RailGeneratorProduct, RailGenerator
This diff is collapsed.
This diff is collapsed.