Skip to content
Snippets Groups Projects
Commit bbb0eff6 authored by Dipam Chakraborty's avatar Dipam Chakraborty
Browse files

Merge branch 'env-step-facelift' into 'env-step-facelift'

Env step facelift

See merge request flatland/flatland!320
parents 34fa9914 2cc768c4
No related branches found
No related tags found
No related merge requests found
Showing with 820 additions and 664 deletions
...@@ -2,21 +2,19 @@ from flatland.envs.rail_trainrun_data_structures import Waypoint ...@@ -2,21 +2,19 @@ from flatland.envs.rail_trainrun_data_structures import Waypoint
import numpy as np import numpy as np
from enum import IntEnum from enum import IntEnum
from flatland.envs.step_utils.states import TrainState
from itertools import starmap from itertools import starmap
from typing import Tuple, Optional, NamedTuple, List from typing import Tuple, Optional, NamedTuple, List
from attr import attr, attrs, attrib, Factory from attr import attr, attrs, attrib, Factory
from flatland.core.grid.grid4 import Grid4TransitionsEnum from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs.timetable_utils import Line from flatland.envs.schedule_utils import Schedule
class RailAgentStatus(IntEnum):
WAITING = 0
READY_TO_DEPART = 1 # not in grid yet (position is None) -> prediction as if it were at initial position
ACTIVE = 2 # in grid (position is not None), not done -> prediction is remaining path
DONE = 3 # in grid (position is not None), but done -> prediction is stay at target forever
DONE_REMOVED = 4 # removed from grid (position is None) -> prediction is None
from flatland.envs.step_utils.action_saver import ActionSaver
from flatland.envs.step_utils.speed_counter import SpeedCounter
from flatland.envs.step_utils.state_machine import TrainStateMachine
from flatland.envs.step_utils.malfunction_handler import MalfunctionHandler
Agent = NamedTuple('Agent', [('initial_position', Tuple[int, int]), Agent = NamedTuple('Agent', [('initial_position', Tuple[int, int]),
('initial_direction', Grid4TransitionsEnum), ('initial_direction', Grid4TransitionsEnum),
...@@ -28,11 +26,16 @@ Agent = NamedTuple('Agent', [('initial_position', Tuple[int, int]), ...@@ -28,11 +26,16 @@ Agent = NamedTuple('Agent', [('initial_position', Tuple[int, int]),
('speed_data', dict), ('speed_data', dict),
('malfunction_data', dict), ('malfunction_data', dict),
('handle', int), ('handle', int),
('status', RailAgentStatus),
('position', Tuple[int, int]), ('position', Tuple[int, int]),
('arrival_time', int), ('arrival_time', int),
('old_direction', Grid4TransitionsEnum), ('old_direction', Grid4TransitionsEnum),
('old_position', Tuple[int, int])]) ('old_position', Tuple[int, int]),
('speed_counter', SpeedCounter),
('action_saver', ActionSaver),
('state', TrainState),
('state_machine', TrainStateMachine),
('malfunction_handler', MalfunctionHandler),
])
@attrs @attrs
...@@ -65,7 +68,15 @@ class EnvAgent: ...@@ -65,7 +68,15 @@ class EnvAgent:
handle = attrib(default=None) handle = attrib(default=None)
# INIT TILL HERE IN _from_line() # INIT TILL HERE IN _from_line()
status = attrib(default=RailAgentStatus.WAITING, type=RailAgentStatus) # Env step facelift
speed_counter = attrib(default = None, type=SpeedCounter)
action_saver = attrib(default = Factory(lambda: ActionSaver()), type=ActionSaver)
state_machine = attrib(default= Factory(lambda: TrainStateMachine(initial_state=TrainState.WAITING)) ,
type=TrainStateMachine)
malfunction_handler = attrib(default = Factory(lambda: MalfunctionHandler()), type=MalfunctionHandler)
state = attrib(default=TrainState.WAITING, type=TrainState)
position = attrib(default=None, type=Optional[Tuple[int, int]]) position = attrib(default=None, type=Optional[Tuple[int, int]])
# NEW : EnvAgent Reward Handling # NEW : EnvAgent Reward Handling
...@@ -75,6 +86,7 @@ class EnvAgent: ...@@ -75,6 +86,7 @@ class EnvAgent:
old_direction = attrib(default=None) old_direction = attrib(default=None)
old_position = attrib(default=None) old_position = attrib(default=None)
def reset(self): def reset(self):
""" """
Resets the agents to their initial values of the episode. Called after ScheduleTime generation. Resets the agents to their initial values of the episode. Called after ScheduleTime generation.
...@@ -82,14 +94,6 @@ class EnvAgent: ...@@ -82,14 +94,6 @@ class EnvAgent:
self.position = None self.position = None
# TODO: set direction to None: https://gitlab.aicrowd.com/flatland/flatland/issues/280 # TODO: set direction to None: https://gitlab.aicrowd.com/flatland/flatland/issues/280
self.direction = self.initial_direction self.direction = self.initial_direction
if (self.earliest_departure == 0):
self.status = RailAgentStatus.READY_TO_DEPART
else:
self.status = RailAgentStatus.WAITING
self.arrival_time = None
self.old_position = None self.old_position = None
self.old_direction = None self.old_direction = None
self.moving = False self.moving = False
...@@ -103,48 +107,42 @@ class EnvAgent: ...@@ -103,48 +107,42 @@ class EnvAgent:
self.malfunction_data['nr_malfunctions'] = 0 self.malfunction_data['nr_malfunctions'] = 0
self.malfunction_data['moving_before_malfunction'] = False self.malfunction_data['moving_before_malfunction'] = False
# NEW : Callables self.action_saver.clear_saved_action()
def get_shortest_path(self, distance_map) -> List[Waypoint]: self.speed_counter.reset_counter()
from flatland.envs.rail_env_shortest_paths import get_shortest_paths # Circular dep fix self.state_machine.reset()
return get_shortest_paths(distance_map=distance_map, agent_handle=self.handle)[self.handle]
def get_travel_time_on_shortest_path(self, distance_map) -> int:
shortest_path = self.get_shortest_path(distance_map)
if shortest_path is not None:
distance = len(shortest_path)
else:
distance = 0
speed = self.speed_data['speed']
return int(np.ceil(distance / speed))
def get_time_remaining_until_latest_arrival(self, elapsed_steps: int) -> int:
return self.latest_arrival - elapsed_steps
def get_current_delay(self, elapsed_steps: int, distance_map) -> int:
'''
+ve if arrival time is projected before latest arrival
-ve if arrival time is projected after latest arrival
'''
return self.get_time_remaining_until_latest_arrival(elapsed_steps) - \
self.get_travel_time_on_shortest_path(distance_map)
def to_agent(self) -> Agent: def to_agent(self) -> Agent:
return Agent(initial_position=self.initial_position, initial_direction=self.initial_direction, return Agent(initial_position=self.initial_position,
direction=self.direction, target=self.target, moving=self.moving, earliest_departure=self.earliest_departure, initial_direction=self.initial_direction,
latest_arrival=self.latest_arrival, speed_data=self.speed_data, malfunction_data=self.malfunction_data, direction=self.direction,
handle=self.handle, status=self.status, position=self.position, arrival_time=self.arrival_time, target=self.target,
old_direction=self.old_direction, old_position=self.old_position) moving=self.moving,
earliest_departure=self.earliest_departure,
latest_arrival=self.latest_arrival,
speed_data=self.speed_data,
malfunction_data=self.malfunction_data,
handle=self.handle,
state=self.state,
position=self.position,
old_direction=self.old_direction,
old_position=self.old_position,
speed_counter=self.speed_counter,
action_saver=self.action_saver,
state_machine=self.state_machine,
malfunction_handler=self.malfunction_handler)
@classmethod @classmethod
def from_line(cls, line: Line): def from_line(cls, line: Line):
""" Create a list of EnvAgent from lists of positions, directions and targets """ Create a list of EnvAgent from lists of positions, directions and targets
""" """
speed_datas = [] speed_datas = []
speed_counters = []
for i in range(len(line.agent_positions)): for i in range(len(schedule.agent_positions)):
speed = schedule.agent_speeds[i] if schedule.agent_speeds is not None else 1.0
speed_datas.append({'position_fraction': 0.0, speed_datas.append({'position_fraction': 0.0,
'speed': line.agent_speeds[i] if line.agent_speeds is not None else 1.0, 'speed': speed,
'transition_action_on_cellexit': 0}) 'transition_action_on_cellexit': 0})
speed_counters.append( SpeedCounter(speed=speed) )
malfunction_datas = [] malfunction_datas = []
for i in range(len(line.agent_positions)): for i in range(len(line.agent_positions)):
...@@ -153,17 +151,19 @@ class EnvAgent: ...@@ -153,17 +151,19 @@ class EnvAgent:
i] if line.agent_malfunction_rates is not None else 0., i] if line.agent_malfunction_rates is not None else 0.,
'next_malfunction': 0, 'next_malfunction': 0,
'nr_malfunctions': 0}) 'nr_malfunctions': 0})
return list(starmap(EnvAgent, zip(line.agent_positions, return list(starmap(EnvAgent, zip(schedule.agent_positions, # TODO : Dipam - Really want to change this way of loading agents
line.agent_directions, schedule.agent_directions,
line.agent_directions, schedule.agent_directions,
line.agent_targets, schedule.agent_targets,
[False] * len(line.agent_positions), [False] * len(schedule.agent_positions),
[None] * len(line.agent_positions), # earliest_departure [None] * len(schedule.agent_positions), # earliest_departure
[None] * len(line.agent_positions), # latest_arrival [None] * len(schedule.agent_positions), # latest_arrival
speed_datas, speed_datas,
malfunction_datas, malfunction_datas,
range(len(line.agent_positions))))) range(len(schedule.agent_positions)),
speed_counters,
)))
@classmethod @classmethod
def load_legacy_static_agent(cls, static_agents_data: Tuple): def load_legacy_static_agent(cls, static_agents_data: Tuple):
......
...@@ -18,7 +18,7 @@ MalfunctionProcessData = NamedTuple('MalfunctionProcessData', ...@@ -18,7 +18,7 @@ MalfunctionProcessData = NamedTuple('MalfunctionProcessData',
Malfunction = NamedTuple('Malfunction', [('num_broken_steps', int)]) Malfunction = NamedTuple('Malfunction', [('num_broken_steps', int)])
# Why is the return value Optional? We always return a Malfunction. # Why is the return value Optional? We always return a Malfunction.
MalfunctionGenerator = Callable[[EnvAgent, RandomState, bool], Optional[Malfunction]] MalfunctionGenerator = Callable[[RandomState, bool], Malfunction]
def _malfunction_prob(rate: float) -> float: def _malfunction_prob(rate: float) -> float:
""" """
...@@ -42,21 +42,14 @@ class ParamMalfunctionGen(object): ...@@ -42,21 +42,14 @@ class ParamMalfunctionGen(object):
#self.max_number_of_steps_broken = parameters.max_duration #self.max_number_of_steps_broken = parameters.max_duration
self.MFP = parameters self.MFP = parameters
def generate(self, def generate(self, np_random: RandomState) -> Malfunction:
agent: EnvAgent = None,
np_random: RandomState = None,
reset=False) -> Optional[Malfunction]:
# Dummy reset function as we don't implement specific seeding here
if reset:
return Malfunction(0)
if agent.malfunction_data['malfunction'] < 1: if np_random.rand() < _malfunction_prob(self.MFP.malfunction_rate):
if np_random.rand() < _malfunction_prob(self.MFP.malfunction_rate): num_broken_steps = np_random.randint(self.MFP.min_duration,
num_broken_steps = np_random.randint(self.MFP.min_duration, self.MFP.max_duration + 1) + 1
self.MFP.max_duration + 1) + 1 else:
return Malfunction(num_broken_steps) num_broken_steps = 0
return Malfunction(0) return Malfunction(num_broken_steps)
def get_process_data(self): def get_process_data(self):
return MalfunctionProcessData(*self.MFP) return MalfunctionProcessData(*self.MFP)
...@@ -103,7 +96,7 @@ def no_malfunction_generator() -> Tuple[MalfunctionGenerator, MalfunctionProcess ...@@ -103,7 +96,7 @@ def no_malfunction_generator() -> Tuple[MalfunctionGenerator, MalfunctionProcess
min_number_of_steps_broken = 0 min_number_of_steps_broken = 0
max_number_of_steps_broken = 0 max_number_of_steps_broken = 0
def generator(agent: EnvAgent = None, np_random: RandomState = None, reset=False) -> Optional[Malfunction]: def generator(np_random: RandomState = None) -> Malfunction:
return Malfunction(0) return Malfunction(0)
return generator, MalfunctionProcessData(mean_malfunction_rate, min_number_of_steps_broken, return generator, MalfunctionProcessData(mean_malfunction_rate, min_number_of_steps_broken,
...@@ -162,7 +155,7 @@ def single_malfunction_generator(earlierst_malfunction: int, malfunction_duratio ...@@ -162,7 +155,7 @@ def single_malfunction_generator(earlierst_malfunction: int, malfunction_duratio
malfunction_calls[agent.handle] = 1 malfunction_calls[agent.handle] = 1
# Break an agent that is active at the time of the malfunction # Break an agent that is active at the time of the malfunction
if agent.status == RailAgentStatus.ACTIVE and malfunction_calls[agent.handle] >= earlierst_malfunction: if agent.status == RailAgentStatus.ACTIVE and malfunction_calls[agent.handle] >= earlierst_malfunction: #TODO : Dipam : Is this needed?
global_nr_malfunctions += 1 global_nr_malfunctions += 1
return Malfunction(malfunction_duration) return Malfunction(malfunction_duration)
else: else:
......
This diff is collapsed.
...@@ -19,6 +19,10 @@ class RailEnvActions(IntEnum): ...@@ -19,6 +19,10 @@ class RailEnvActions(IntEnum):
4: 'S', 4: 'S',
}[a] }[a]
@staticmethod
def is_moving_action(action):
    """Tell whether ``action`` is one of the raw action values 1, 2 or 3.

    NOTE(review): presumably these are the left / forward / right movement
    members of the enum -- confirm against the member definitions, which are
    not visible from here.
    """
    moving_values = (1, 2, 3)
    return action in moving_values
RailEnvGridPos = NamedTuple('RailEnvGridPos', [('r', int), ('c', int)]) RailEnvGridPos = NamedTuple('RailEnvGridPos', [('r', int), ('c', int)])
RailEnvNextAction = NamedTuple('RailEnvNextAction', [('action', RailEnvActions), ('next_position', RailEnvGridPos), RailEnvNextAction = NamedTuple('RailEnvNextAction', [('action', RailEnvActions), ('next_position', RailEnvGridPos),
......
from flatland.core.grid.grid_utils import position_to_coordinate
from flatland.envs.agent_utils import TrainState
from flatland.envs.rail_env_action import RailEnvActions
from flatland.envs.step_utils.transition_utils import check_valid_action
def process_illegal_action(action: RailEnvActions):
    """Replace a missing or unknown action with ``RailEnvActions.DO_NOTHING``.

    Any value that is a valid ``RailEnvActions`` value is returned untouched.
    """
    # TODO - Dipam : This check is kind of weird, change this
    known_values = RailEnvActions._value2member_map_
    if action is not None and action in known_values:
        return action
    return RailEnvActions.DO_NOTHING
def process_do_nothing(state: TrainState):
    """Choose the concrete action that stands in for DO_NOTHING.

    A train in the MOVING state keeps going (MOVE_FORWARD); in every other
    state the result is STOP_MOVING.
    """
    if state == TrainState.MOVING:
        return RailEnvActions.MOVE_FORWARD
    return RailEnvActions.STOP_MOVING
def process_left_right(action, state, rail, position, direction):
    """Fall back to MOVE_FORWARD when the requested turn is not possible.

    Parameters
    ----------
    action : RailEnvActions
        The requested action (expected to be a left or right turn).
    state : TrainState
        Kept in the signature for existing callers; not used by the check.
    rail, position, direction
        Grid, cell and heading forwarded to ``check_valid_action``.

    Returns
    -------
    RailEnvActions
        ``action`` itself if it is valid from this cell, else MOVE_FORWARD.
    """
    # check_valid_action takes (action, rail, position, direction); passing
    # `state` as well shifted every argument and raised a TypeError.
    if not check_valid_action(action, rail, position, direction):
        action = RailEnvActions.MOVE_FORWARD
    return action
def preprocess_action_when_waiting(action, state):
    """Force DO_NOTHING while the train is WAITING; otherwise pass ``action`` through."""
    return RailEnvActions.DO_NOTHING if state == TrainState.WAITING else action
def preprocess_raw_action(action, state):
    """
    Preprocesses actions to handle different situations of usage of action based on context
        - An illegal or missing action is treated as DO_NOTHING
        - DO_NOTHING is converted to MOVE_FORWARD if the train is moving
        - DO_NOTHING is converted to STOP_MOVING if the train is not moving
    """
    action = process_illegal_action(action)

    # DO_NOTHING means "keep doing what you were doing", which depends on state.
    if action == RailEnvActions.DO_NOTHING:
        action = process_do_nothing(state)

    return action
def preprocess_moving_action(action, state, rail, position, direction):
    """
    Resolve a movement action against the track layout of the current cell.

    - LEFT/RIGHT is converted to MOVE_FORWARD if that turn is not available.
    - Whatever action remains is converted to STOP_MOVING if it is still not
      a valid move from this cell.

    Parameters
    ----------
    action : RailEnvActions
    state : TrainState
        Forwarded to ``process_left_right``.
    rail, position, direction
        Grid, cell and heading used for the validity checks.

    Returns
    -------
    RailEnvActions
    """
    if action in [RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT]:
        # process_left_right expects `state` as its second argument; omitting
        # it raised a TypeError for every turn action.
        action = process_left_right(action, state, rail, position, direction)

    if not check_valid_action(action, rail, position, direction):  # TODO: Dipam - Not sure if this is needed
        action = RailEnvActions.STOP_MOVING

    return action
\ No newline at end of file
from flatland.envs.rail_env_action import RailEnvActions
from flatland.envs.step_utils.states import TrainState
class ActionSaver:
    """Holds at most one pending action until it is explicitly cleared."""

    def __init__(self):
        self.saved_action = None

    @property
    def is_action_saved(self):
        """True while an action is being held."""
        return self.saved_action is not None

    def __repr__(self):
        return f"is_action_saved: {self.is_action_saved}, saved_action: {self.saved_action}"

    def save_action_if_allowed(self, action, state):
        """Store ``action`` unless one is already held, it is not a moving
        action, or ``state`` is a malfunction state."""
        if self.is_action_saved:
            return
        if not RailEnvActions.is_moving_action(action):
            return
        if TrainState.is_malfunction_state(state):
            return
        self.saved_action = action

    def clear_saved_action(self):
        """Drop the held action, if any."""
        self.saved_action = None
def get_number_of_steps_to_break(malfunction_generator, np_random):
    """Ask the generator for a malfunction and return its ``num_broken_steps``.

    The generator may be an object exposing ``generate(np_random)`` or a plain
    callable taking ``np_random``.
    """
    produce = getattr(malfunction_generator, "generate", malfunction_generator)
    return produce(np_random).num_broken_steps
class MalfunctionHandler:
    """Tracks how many steps of malfunction remain via a non-negative down counter."""

    def __init__(self):
        self._malfunction_down_counter = 0

    @property
    def in_malfunction(self):
        """True while the down counter is above zero."""
        return self._malfunction_down_counter > 0

    @property
    def malfunction_counter_complete(self):
        """True once the down counter is exactly zero."""
        return self._malfunction_down_counter == 0

    @property
    def malfunction_down_counter(self):
        return self._malfunction_down_counter

    @malfunction_down_counter.setter
    def malfunction_down_counter(self, val):
        self._set_malfunction_down_counter(val)

    def _set_malfunction_down_counter(self, val):
        # Single place where the counter is assigned, so the sign check always runs.
        if val < 0:
            raise ValueError("Cannot set a negative value to malfunction down counter")
        self._malfunction_down_counter = val

    def generate_malfunction(self, malfunction_generator, np_random):
        """Draw a malfunction from the generator and load its duration into the counter."""
        self.malfunction_down_counter = get_number_of_steps_to_break(malfunction_generator, np_random)

    def update_counter(self):
        """Count one step down; a counter already at zero stays at zero."""
        if self.in_malfunction:
            self._malfunction_down_counter -= 1
import numpy as np
from flatland.envs.step_utils.states import TrainState
class SpeedCounter:
    """Counts the steps a train spends inside one cell.

    ``max_count`` is ``int(1 / speed)``. The counter advances only while the
    train is MOVING and wraps back to 0 after ``max_count`` steps.
    """

    def __init__(self, speed):
        self.speed = speed
        self.max_count = int(1/speed)
        # Start from a defined value: without this, every property and
        # update_counter() raised AttributeError until reset_counter() ran.
        self.reset_counter()

    def update_counter(self, state):
        """Advance the counter by one step if ``state`` is MOVING, wrapping at ``max_count``."""
        if state == TrainState.MOVING:
            self.counter += 1
            self.counter = self.counter % self.max_count

    def __repr__(self):
        return (f"speed: {self.speed} "
                f"max_count: {self.max_count} "
                f"is_cell_entry: {self.is_cell_entry} "
                f"is_cell_exit: {self.is_cell_exit} "
                f"counter: {self.counter}")

    def reset_counter(self):
        """Put the counter back to 0 (the cell-entry position)."""
        self.counter = 0

    @property
    def is_cell_entry(self):
        """True when the counter is 0."""
        return self.counter == 0

    @property
    def is_cell_exit(self):
        """True when the counter is on its last value, ``max_count - 1``."""
        return self.counter == self.max_count - 1
from attr import s
from flatland.envs.step_utils.states import TrainState
class TrainStateMachine:
    """State machine over ``TrainState`` driven by a dict of transition signals.

    Callers provide the signals with ``set_transition_signals`` and then call
    ``step``; each ``_handle_*`` method reads the signals relevant to one
    state and writes the outcome to ``next_state``.
    """

    def __init__(self, initial_state=TrainState.WAITING):
        self._initial_state = initial_state
        self._state = initial_state
        self.st_signals = {}  # State Transition Signals # TODO: Make this namedtuple
        self.next_state = None

    def _handle_waiting(self):
        """Waiting state goes to ready to depart when earliest departure is reached"""
        # TODO: Important - The malfunction handling is not like proper state machine
        #                   Both transition signals can happen at the same time
        #                   Atleast mention it in the diagram
        if self.st_signals['malfunction_onset']:
            self.next_state = TrainState.MALFUNCTION_OFF_MAP
        elif self.st_signals['earliest_departure_reached']:
            self.next_state = TrainState.READY_TO_DEPART
        else:
            self.next_state = TrainState.WAITING

    def _handle_ready_to_depart(self):
        """ Can only go to MOVING if a valid action is provided """
        if self.st_signals['malfunction_onset']:
            self.next_state = TrainState.MALFUNCTION_OFF_MAP
        elif self.st_signals['valid_movement_action_given']:
            self.next_state = TrainState.MOVING
        else:
            self.next_state = TrainState.READY_TO_DEPART

    def _handle_malfunction_off_map(self):
        """Stay broken until the counter completes, then rejoin the off-map flow.

        Once the malfunction is over the train becomes READY_TO_DEPART if its
        earliest departure has been reached, otherwise it goes back to WAITING.
        Previously an unfinished malfunction was dropped straight to WAITING,
        and a finished one before earliest departure became STOPPED, which is
        the state ``_handle_moving`` / ``_handle_stopped`` use for trains that
        are already moving on the map.
        """
        if self.st_signals['malfunction_counter_complete']:
            if self.st_signals['earliest_departure_reached']:
                self.next_state = TrainState.READY_TO_DEPART
            else:
                self.next_state = TrainState.WAITING
        else:
            self.next_state = TrainState.MALFUNCTION_OFF_MAP

    def _handle_moving(self):
        """Malfunction takes priority, then reaching the target, then stopping."""
        if self.st_signals['malfunction_onset']:
            self.next_state = TrainState.MALFUNCTION
        elif self.st_signals['target_reached']:
            self.next_state = TrainState.DONE
        elif self.st_signals['stop_action_given'] or self.st_signals['movement_conflict']:
            self.next_state = TrainState.STOPPED
        else:
            self.next_state = TrainState.MOVING

    def _handle_stopped(self):
        """A stopped train resumes only on a valid movement action."""
        if self.st_signals['malfunction_onset']:
            self.next_state = TrainState.MALFUNCTION
        elif self.st_signals['valid_movement_action_given']:
            self.next_state = TrainState.MOVING
        else:
            self.next_state = TrainState.STOPPED

    def _handle_malfunction(self):
        """Leave MALFUNCTION only once the counter is complete and an action or conflict decides where to."""
        if self.st_signals['malfunction_counter_complete'] and \
                self.st_signals['valid_movement_action_given']:
            self.next_state = TrainState.MOVING
        elif self.st_signals['malfunction_counter_complete'] and \
                (self.st_signals['stop_action_given'] or self.st_signals['movement_conflict']):
            self.next_state = TrainState.STOPPED
        else:
            self.next_state = TrainState.MALFUNCTION

    def _handle_done(self):
        """Done state is terminal"""
        self.next_state = TrainState.DONE

    def calculate_next_state(self, current_state):
        """Dispatch to the handler for ``current_state``; the result lands in ``next_state``.

        Raises
        ------
        ValueError
            If ``current_state`` is not one of the known ``TrainState`` members.
        """
        # _Handle the current state
        if current_state == TrainState.WAITING:
            self._handle_waiting()

        elif current_state == TrainState.READY_TO_DEPART:
            self._handle_ready_to_depart()

        elif current_state == TrainState.MALFUNCTION_OFF_MAP:
            self._handle_malfunction_off_map()

        elif current_state == TrainState.MOVING:
            self._handle_moving()

        elif current_state == TrainState.STOPPED:
            self._handle_stopped()

        elif current_state == TrainState.MALFUNCTION:
            self._handle_malfunction()

        elif current_state == TrainState.DONE:
            self._handle_done()

        else:
            raise ValueError(f"Got unexpected state {current_state}")

    def step(self):
        """ Steps the state machine to the next state """
        current_state = self._state

        # Clear next state
        self.clear_next_state()

        # Handle current state to get next_state
        self.calculate_next_state(current_state)

        # Set next state
        self.set_state(self.next_state)

    def clear_next_state(self):
        self.next_state = None

    def set_state(self, state):
        """Set the current state, rejecting values ``TrainState.check_valid_state`` does not accept."""
        if not TrainState.check_valid_state(state):
            raise ValueError(f"Cannot set invalid state {state}")
        self._state = state

    def reset(self):
        """Return to the initial state and forget signals and any pending next state."""
        self._state = self._initial_state
        self.st_signals = {}
        self.clear_next_state()

    @property
    def state(self):
        return self._state

    @property
    def state_transition_signals(self):
        return self.st_signals

    def set_transition_signals(self, state_transition_signals):
        self.st_signals = state_transition_signals  # TODO: Important: Check all keys are present and if not raise error
from enum import IntEnum
class TrainState(IntEnum):
    """States a train can be in, as integer-valued enum members."""

    WAITING = 0
    READY_TO_DEPART = 1
    MALFUNCTION_OFF_MAP = 2
    MOVING = 3
    STOPPED = 4
    MALFUNCTION = 5
    DONE = 6

    @classmethod
    def check_valid_state(cls, state):
        """Return True if ``state`` is a member of this enum or equal to a member's value."""
        return state in cls._value2member_map_

    @staticmethod
    def is_malfunction_state(state):
        """Return True for MALFUNCTION_OFF_MAP and MALFUNCTION (members or their raw values)."""
        # Named members instead of the literals 2 and 5; IntEnum members
        # compare equal to their integer values, so raw ints still match.
        return state in (TrainState.MALFUNCTION_OFF_MAP, TrainState.MALFUNCTION)
from typing import Tuple
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.rail_env_action import RailEnvActions
def check_action(action, position, direction, rail):
    """
    Work out the heading an action leads to from a cell, and whether that can
    already be judged valid from the number of available transitions alone.

    Parameters
    ----------
    action : RailEnvActions
    position : cell coordinates, unpacked into ``rail.get_transitions``
    direction : current heading; an integer, since it is adjusted by +/-1 and
        taken modulo 4
    rail : object providing ``get_transitions(*position, direction)``

    Returns
    -------
    Tuple[new_direction, transition_valid]
        ``transition_valid`` is False for a left/right turn in a cell with at
        most one transition, True for MOVE_FORWARD in a cell with exactly one
        transition, and None when no verdict was reached here.
    """
    transition_valid = None
    possible_transitions = rail.get_transitions(*position, direction)
    num_transitions = fast_count_nonzero(possible_transitions)

    new_direction = direction
    if action == RailEnvActions.MOVE_LEFT:
        new_direction = direction - 1
        # A turn needs more than one way out of the cell.
        if num_transitions <= 1:
            transition_valid = False

    elif action == RailEnvActions.MOVE_RIGHT:
        new_direction = direction + 1
        if num_transitions <= 1:
            transition_valid = False

    # Wrap back into 0..3: turning left from 0 gives -1 and right from 3 gives 4.
    new_direction %= 4

    if action == RailEnvActions.MOVE_FORWARD and num_transitions == 1:
        # - dead-end, straight line or curved line;
        # new_direction will be the only valid transition
        # - take only available transition
        new_direction = fast_argmax(possible_transitions)
        transition_valid = True
    return new_direction, transition_valid
def check_action_on_agent(action, rail, position, direction):
    """
    Evaluate where ``action`` would take a train standing on ``position``
    with heading ``direction``.

    Parameters
    ----------
    action : RailEnvActions
    rail : grid object providing ``height``, ``width``, ``get_transitions``,
        ``get_transition`` and ``get_full_transitions``
    position : current cell coordinates
    direction : current heading

    Returns
    -------
    Tuple[new_cell_valid, new_direction, new_position, transition_valid]
        new_cell_valid   : the target cell lies inside the grid and is not
                           empty (its full transitions value is > 0)
        new_direction    : heading after the action
        new_position     : target cell
        transition_valid : the current cell allows moving to ``new_direction``
    """
    # compute number of possible transitions in the current
    # cell used to check for invalid actions
    new_direction, transition_valid = check_action(action, position, direction, rail)
    new_position = get_new_position(position, new_direction)

    cell_inside_grid = check_bounds(new_position, rail.height, rail.width)
    # Query the grid only for cells that exist: the lookup used to run even
    # for out-of-bounds positions, where it can fail or read an unrelated cell.
    new_cell_valid = cell_inside_grid and rail.get_full_transitions(*new_position) > 0

    # If transition validity hasn't been checked yet.
    if transition_valid is None:
        transition_valid = rail.get_transition(  # TODO: Dipam - Read this one
            (*position, direction),
            new_direction)

    return new_cell_valid, new_direction, new_position, transition_valid
def check_valid_action(action, rail, position, direction):
    """Combine the target-cell check and the transition check from ``check_action_on_agent``."""
    outcome = check_action_on_agent(action, rail, position, direction)
    cell_ok, transition_ok = outcome[0], outcome[3]
    return cell_ok and transition_ok
def fast_argmax(possible_transitions: Tuple[int, int, int, int]) -> int:
    """Return the index of the first entry equal to 1, or 3 if none of the first three is.

    The return annotation previously said ``bool``; the function returns an
    index in 0..3.
    """
    for index in range(3):
        if possible_transitions[index] == 1:
            return index
    # Entry 3 is never inspected: it is the answer whenever 0..2 did not match.
    return 3
def fast_count_nonzero(possible_transitions: Tuple[int, int, int, int]):
    """Add up the four entries; for 0/1 flags this is the number of set entries."""
    return sum(possible_transitions[slot] for slot in range(4))
def check_bounds(position, height, width):
    """True when ``position`` (row, column) lies inside a ``height`` x ``width`` grid."""
    row, col = position[0], position[1]
    return row >= 0 and col >= 0 and row < height and col < width
\ No newline at end of file
import numpy as np
import numpy as np
import os
from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters, ParamMalfunctionGen
from flatland.envs.observations import GlobalObsForRailEnv
# First of all we import the Flatland rail environment
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_env import RailEnvActions
from flatland.envs.rail_generators import sparse_rail_generator
#from flatland.envs.sparse_rail_gen import SparseRailGen
from flatland.envs.schedule_generators import sparse_schedule_generator
def get_small_two_agent_env():
    """Build a seeded 30x15 sparse-rail environment with 2 cities and 2 trains.

    The environment is reset once before being returned.
    """
    seed = 42  # Random seed, shared by the rail generator and the environment

    # --- Map layout ---------------------------------------------------------
    width, height = 30, 15  # Width and height of the map
    nr_trains = 2  # Number of trains that have an assigned task in the env
    cities_in_map = 2  # Number of cities where agents can start or end
    max_rail_in_cities = 6  # Max number of parallel tracks within a city
    rail_generator = sparse_rail_generator(
        max_num_cities=cities_in_map,
        seed=seed,
        grid_mode=False,  # False: cities are placed randomly rather than on a grid
        max_rails_between_cities=2,  # Number of entry points to a city
        max_rail_pairs_in_city=max_rail_in_cities//2,
    )

    # --- Timetable: four speed classes, each with weight 0.25 ----------------
    speed_ratio_map = {
        1.: 0.25,  # Fast passenger train
        1. / 2.: 0.25,  # Fast freight train
        1. / 3.: 0.25,  # Slow commuter train
        1. / 4.: 0.25,  # Slow freight train
    }
    schedule_generator = sparse_schedule_generator(speed_ratio_map)

    # --- Malfunctions --------------------------------------------------------
    stochastic_data = MalfunctionParameters(
        malfunction_rate=1/10000,  # Rate of malfunction occurrence
        min_duration=15,  # Minimal duration of malfunction
        max_duration=50,  # Max duration of malfunction
    )

    env = RailEnv(
        width=width,
        height=height,
        rail_generator=rail_generator,
        schedule_generator=schedule_generator,
        number_of_agents=nr_trains,
        obs_builder_object=GlobalObsForRailEnv(),
        malfunction_generator=ParamMalfunctionGen(stochastic_data),
        remove_agents_at_target=True,
        random_seed=seed,
    )
    env.reset()
    return env
\ No newline at end of file
from test_env_step_utils import get_small_two_agent_env
from flatland.envs.rail_env_action import RailEnvActions
from flatland.envs.step_utils.states import TrainState
from flatland.envs.malfunction_generators import Malfunction
class NoMalfunctionGenerator:
    """Test double: every call yields a malfunction of zero broken steps."""

    def generate(self, np_random):
        no_break = Malfunction(0)
        return no_break
class AlwaysThreeStepMalfunction:
    """Test double: every call yields a malfunction of three broken steps."""

    def generate(self, np_random):
        three_step_break = Malfunction(3)
        return three_step_break
def test_waiting_no_transition():
    """Before its earliest departure an agent stays WAITING even when told to move."""
    env = get_small_two_agent_env()
    env.malfunction_generator = NoMalfunctionGenerator()
    handle = 0
    steps_before_departure = env.agents[handle].earliest_departure - 1
    for _ in range(steps_before_departure):
        env.step({handle: RailEnvActions.MOVE_FORWARD})
        # Checked after every step, not just at the end.
        assert env.agents[handle].state == TrainState.WAITING
def test_waiting_to_ready_to_depart():
    """After ``earliest_departure`` idle steps the agent is READY_TO_DEPART."""
    env = get_small_two_agent_env()
    env.malfunction_generator = NoMalfunctionGenerator()
    handle = 0
    departure = env.agents[handle].earliest_departure
    for _ in range(departure):
        env.step({handle: RailEnvActions.DO_NOTHING})
    agent = env.agents[handle]
    assert agent.state == TrainState.READY_TO_DEPART
def test_ready_to_depart_to_moving():
    """A MOVE_FORWARD issued once the agent is ready to depart puts it in MOVING."""
    env = get_small_two_agent_env()
    env.malfunction_generator = NoMalfunctionGenerator()
    handle = 0
    departure = env.agents[handle].earliest_departure
    # Idle until the earliest departure has passed.
    for _ in range(departure):
        env.step({handle: RailEnvActions.DO_NOTHING})

    env.step({handle: RailEnvActions.MOVE_FORWARD})
    agent = env.agents[handle]
    assert agent.state == TrainState.MOVING
def test_moving_to_stopped():
    """STOP_MOVING issued to a moving agent puts it in STOPPED."""
    env = get_small_two_agent_env()
    env.malfunction_generator = NoMalfunctionGenerator()
    handle = 0
    departure = env.agents[handle].earliest_departure
    # Idle until the earliest departure has passed.
    for _ in range(departure):
        env.step({handle: RailEnvActions.DO_NOTHING})

    # Start moving, then stop.
    for command in (RailEnvActions.MOVE_FORWARD, RailEnvActions.STOP_MOVING):
        env.step({handle: command})
    assert env.agents[handle].state == TrainState.STOPPED
def test_stopped_to_moving():
    """A stopped agent returns to MOVING on the next MOVE_FORWARD."""
    env = get_small_two_agent_env()
    env.malfunction_generator = NoMalfunctionGenerator()
    handle = 0
    departure = env.agents[handle].earliest_departure
    # Idle until the earliest departure has passed.
    for _ in range(departure):
        env.step({handle: RailEnvActions.DO_NOTHING})

    # Move, stop, then move again.
    command_sequence = (
        RailEnvActions.MOVE_FORWARD,
        RailEnvActions.STOP_MOVING,
        RailEnvActions.MOVE_FORWARD,
    )
    for command in command_sequence:
        env.step({handle: command})
    assert env.agents[handle].state == TrainState.MOVING
def test_moving_to_done():
    """Agent 1 is DONE after idling to its departure and then 50 MOVE_FORWARD steps."""
    env = get_small_two_agent_env()
    env.malfunction_generator = NoMalfunctionGenerator()
    handle = 1
    departure = env.agents[handle].earliest_departure
    forward_steps = 50

    for _ in range(departure):
        env.step({handle: RailEnvActions.DO_NOTHING})
    for _ in range(forward_steps):
        env.step({handle: RailEnvActions.MOVE_FORWARD})

    assert env.agents[handle].state == TrainState.DONE
def test_waiting_to_malfunction():
    """With a generator that always breaks, the first step puts agent 1 in MALFUNCTION_OFF_MAP."""
    env = get_small_two_agent_env()
    env.malfunction_generator = AlwaysThreeStepMalfunction()
    handle = 1
    env.step({handle: RailEnvActions.DO_NOTHING})
    agent = env.agents[handle]
    assert agent.state == TrainState.MALFUNCTION_OFF_MAP
def test_ready_to_depart_to_malfunction_off_map():
    """A malfunction that starts after the departure wait puts agent 1 in MALFUNCTION_OFF_MAP."""
    env = get_small_two_agent_env()
    env.malfunction_generator = NoMalfunctionGenerator()
    handle = 1

    # One initial step, then one step per unit of earliest departure.
    env.step({handle: RailEnvActions.DO_NOTHING})
    departure = env.agents[handle].earliest_departure
    for _ in range(departure):
        env.step({handle: RailEnvActions.DO_NOTHING})  # This should get into ready to depart

    # From here on every generated malfunction lasts three steps.
    env.malfunction_generator = AlwaysThreeStepMalfunction()
    env.step({handle: RailEnvActions.DO_NOTHING})

    assert env.agents[handle].state == TrainState.MALFUNCTION_OFF_MAP
def test_malfunction_off_map_to_waiting():
    """Drive agent 1 to the point where a three-step malfunction starts.

    NOTE(review): the statements below are identical to
    test_ready_to_depart_to_malfunction_off_map, and the final assertion
    checks MALFUNCTION_OFF_MAP. The MALFUNCTION_OFF_MAP -> WAITING transition
    named in the test title is never exercised -- confirm the intended
    scenario and extend the test.
    """
    env = get_small_two_agent_env()
    env.malfunction_generator = NoMalfunctionGenerator()
    i_agent = 1
    env.step({i_agent: RailEnvActions.DO_NOTHING})
    ed = env.agents[i_agent].earliest_departure
    for _ in range(ed):
        env.step({i_agent: RailEnvActions.DO_NOTHING})  # This should get into ready to depart
    # Switch generators so the next step produces a malfunction.
    env.malfunction_generator = AlwaysThreeStepMalfunction()
    env.step({i_agent: RailEnvActions.DO_NOTHING})
    assert env.agents[i_agent].state == TrainState.MALFUNCTION_OFF_MAP
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment