Commit bbb0eff6 authored by Dipam Chakraborty's avatar Dipam Chakraborty
Browse files

Merge branch 'env-step-facelift' into 'env-step-facelift'

Env step facelift

See merge request !320
parents 34fa9914 2cc768c4
Pipeline #8444 failed with stages
in 4 minutes and 5 seconds
......@@ -2,21 +2,19 @@ from flatland.envs.rail_trainrun_data_structures import Waypoint
import numpy as np
from enum import IntEnum
from flatland.envs.step_utils.states import TrainState
from itertools import starmap
from typing import Tuple, Optional, NamedTuple, List
from attr import attr, attrs, attrib, Factory
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs.timetable_utils import Line
class RailAgentStatus(IntEnum):
WAITING = 0
READY_TO_DEPART = 1 # not in grid yet (position is None) -> prediction as if it were at initial position
ACTIVE = 2 # in grid (position is not None), not done -> prediction is remaining path
DONE = 3 # in grid (position is not None), but done -> prediction is stay at target forever
DONE_REMOVED = 4 # removed from grid (position is None) -> prediction is None
from flatland.envs.schedule_utils import Schedule
from flatland.envs.step_utils.action_saver import ActionSaver
from flatland.envs.step_utils.speed_counter import SpeedCounter
from flatland.envs.step_utils.state_machine import TrainStateMachine
from flatland.envs.step_utils.malfunction_handler import MalfunctionHandler
Agent = NamedTuple('Agent', [('initial_position', Tuple[int, int]),
('initial_direction', Grid4TransitionsEnum),
......@@ -28,11 +26,16 @@ Agent = NamedTuple('Agent', [('initial_position', Tuple[int, int]),
('speed_data', dict),
('malfunction_data', dict),
('handle', int),
('status', RailAgentStatus),
('position', Tuple[int, int]),
('arrival_time', int),
('old_direction', Grid4TransitionsEnum),
('old_position', Tuple[int, int])])
('old_position', Tuple[int, int]),
('speed_counter', SpeedCounter),
('action_saver', ActionSaver),
('state', TrainState),
('state_machine', TrainStateMachine),
('malfunction_handler', MalfunctionHandler),
])
@attrs
......@@ -65,7 +68,15 @@ class EnvAgent:
handle = attrib(default=None)
# INIT TILL HERE IN _from_line()
status = attrib(default=RailAgentStatus.WAITING, type=RailAgentStatus)
# Env step facelift
speed_counter = attrib(default = None, type=SpeedCounter)
action_saver = attrib(default = Factory(lambda: ActionSaver()), type=ActionSaver)
state_machine = attrib(default= Factory(lambda: TrainStateMachine(initial_state=TrainState.WAITING)) ,
type=TrainStateMachine)
malfunction_handler = attrib(default = Factory(lambda: MalfunctionHandler()), type=MalfunctionHandler)
state = attrib(default=TrainState.WAITING, type=TrainState)
position = attrib(default=None, type=Optional[Tuple[int, int]])
# NEW : EnvAgent Reward Handling
......@@ -75,6 +86,7 @@ class EnvAgent:
old_direction = attrib(default=None)
old_position = attrib(default=None)
def reset(self):
"""
Resets the agents to their initial values of the episode. Called after ScheduleTime generation.
......@@ -82,14 +94,6 @@ class EnvAgent:
self.position = None
# TODO: set direction to None: https://gitlab.aicrowd.com/flatland/flatland/issues/280
self.direction = self.initial_direction
if (self.earliest_departure == 0):
self.status = RailAgentStatus.READY_TO_DEPART
else:
self.status = RailAgentStatus.WAITING
self.arrival_time = None
self.old_position = None
self.old_direction = None
self.moving = False
......@@ -103,48 +107,42 @@ class EnvAgent:
self.malfunction_data['nr_malfunctions'] = 0
self.malfunction_data['moving_before_malfunction'] = False
# NEW : Callables
def get_shortest_path(self, distance_map) -> List[Waypoint]:
from flatland.envs.rail_env_shortest_paths import get_shortest_paths # Circular dep fix
return get_shortest_paths(distance_map=distance_map, agent_handle=self.handle)[self.handle]
def get_travel_time_on_shortest_path(self, distance_map) -> int:
shortest_path = self.get_shortest_path(distance_map)
if shortest_path is not None:
distance = len(shortest_path)
else:
distance = 0
speed = self.speed_data['speed']
return int(np.ceil(distance / speed))
def get_time_remaining_until_latest_arrival(self, elapsed_steps: int) -> int:
return self.latest_arrival - elapsed_steps
def get_current_delay(self, elapsed_steps: int, distance_map) -> int:
'''
+ve if arrival time is projected before latest arrival
-ve if arrival time is projected after latest arrival
'''
return self.get_time_remaining_until_latest_arrival(elapsed_steps) - \
self.get_travel_time_on_shortest_path(distance_map)
self.action_saver.clear_saved_action()
self.speed_counter.reset_counter()
self.state_machine.reset()
def to_agent(self) -> Agent:
return Agent(initial_position=self.initial_position, initial_direction=self.initial_direction,
direction=self.direction, target=self.target, moving=self.moving, earliest_departure=self.earliest_departure,
latest_arrival=self.latest_arrival, speed_data=self.speed_data, malfunction_data=self.malfunction_data,
handle=self.handle, status=self.status, position=self.position, arrival_time=self.arrival_time,
old_direction=self.old_direction, old_position=self.old_position)
return Agent(initial_position=self.initial_position,
initial_direction=self.initial_direction,
direction=self.direction,
target=self.target,
moving=self.moving,
earliest_departure=self.earliest_departure,
latest_arrival=self.latest_arrival,
speed_data=self.speed_data,
malfunction_data=self.malfunction_data,
handle=self.handle,
state=self.state,
position=self.position,
old_direction=self.old_direction,
old_position=self.old_position,
speed_counter=self.speed_counter,
action_saver=self.action_saver,
state_machine=self.state_machine,
malfunction_handler=self.malfunction_handler)
@classmethod
def from_line(cls, line: Line):
""" Create a list of EnvAgent from lists of positions, directions and targets
"""
speed_datas = []
for i in range(len(line.agent_positions)):
speed_counters = []
for i in range(len(schedule.agent_positions)):
speed = schedule.agent_speeds[i] if schedule.agent_speeds is not None else 1.0
speed_datas.append({'position_fraction': 0.0,
'speed': line.agent_speeds[i] if line.agent_speeds is not None else 1.0,
'speed': speed,
'transition_action_on_cellexit': 0})
speed_counters.append( SpeedCounter(speed=speed) )
malfunction_datas = []
for i in range(len(line.agent_positions)):
......@@ -153,17 +151,19 @@ class EnvAgent:
i] if line.agent_malfunction_rates is not None else 0.,
'next_malfunction': 0,
'nr_malfunctions': 0})
return list(starmap(EnvAgent, zip(line.agent_positions,
line.agent_directions,
line.agent_directions,
line.agent_targets,
[False] * len(line.agent_positions),
[None] * len(line.agent_positions), # earliest_departure
[None] * len(line.agent_positions), # latest_arrival
return list(starmap(EnvAgent, zip(schedule.agent_positions, # TODO : Dipam - Really want to change this way of loading agents
schedule.agent_directions,
schedule.agent_directions,
schedule.agent_targets,
[False] * len(schedule.agent_positions),
[None] * len(schedule.agent_positions), # earliest_departure
[None] * len(schedule.agent_positions), # latest_arrival
speed_datas,
malfunction_datas,
range(len(line.agent_positions)))))
range(len(schedule.agent_positions)),
speed_counters,
)))
@classmethod
def load_legacy_static_agent(cls, static_agents_data: Tuple):
......
......@@ -18,7 +18,7 @@ MalfunctionProcessData = NamedTuple('MalfunctionProcessData',
Malfunction = NamedTuple('Malfunction', [('num_broken_steps', int)])
# Why is the return value Optional? We always return a Malfunction.
MalfunctionGenerator = Callable[[EnvAgent, RandomState, bool], Optional[Malfunction]]
MalfunctionGenerator = Callable[[RandomState, bool], Malfunction]
def _malfunction_prob(rate: float) -> float:
"""
......@@ -42,21 +42,14 @@ class ParamMalfunctionGen(object):
#self.max_number_of_steps_broken = parameters.max_duration
self.MFP = parameters
def generate(self,
agent: EnvAgent = None,
np_random: RandomState = None,
reset=False) -> Optional[Malfunction]:
# Dummy reset function as we don't implement specific seeding here
if reset:
return Malfunction(0)
def generate(self, np_random: RandomState) -> Malfunction:
if agent.malfunction_data['malfunction'] < 1:
if np_random.rand() < _malfunction_prob(self.MFP.malfunction_rate):
num_broken_steps = np_random.randint(self.MFP.min_duration,
self.MFP.max_duration + 1) + 1
return Malfunction(num_broken_steps)
return Malfunction(0)
if np_random.rand() < _malfunction_prob(self.MFP.malfunction_rate):
num_broken_steps = np_random.randint(self.MFP.min_duration,
self.MFP.max_duration + 1) + 1
else:
num_broken_steps = 0
return Malfunction(num_broken_steps)
def get_process_data(self):
return MalfunctionProcessData(*self.MFP)
......@@ -103,7 +96,7 @@ def no_malfunction_generator() -> Tuple[MalfunctionGenerator, MalfunctionProcess
min_number_of_steps_broken = 0
max_number_of_steps_broken = 0
def generator(agent: EnvAgent = None, np_random: RandomState = None, reset=False) -> Optional[Malfunction]:
def generator(np_random: RandomState = None) -> Malfunction:
return Malfunction(0)
return generator, MalfunctionProcessData(mean_malfunction_rate, min_number_of_steps_broken,
......@@ -162,7 +155,7 @@ def single_malfunction_generator(earlierst_malfunction: int, malfunction_duratio
malfunction_calls[agent.handle] = 1
# Break an agent that is active at the time of the malfunction
if agent.status == RailAgentStatus.ACTIVE and malfunction_calls[agent.handle] >= earlierst_malfunction:
if agent.status == RailAgentStatus.ACTIVE and malfunction_calls[agent.handle] >= earlierst_malfunction: #TODO : Dipam : Is this needed?
global_nr_malfunctions += 1
return Malfunction(malfunction_duration)
else:
......
This diff is collapsed.
......@@ -19,6 +19,10 @@ class RailEnvActions(IntEnum):
4: 'S',
}[a]
@staticmethod
def is_moving_action(action):
return action in [1,2,3]
RailEnvGridPos = NamedTuple('RailEnvGridPos', [('r', int), ('c', int)])
RailEnvNextAction = NamedTuple('RailEnvNextAction', [('action', RailEnvActions), ('next_position', RailEnvGridPos),
......
from flatland.core.grid.grid_utils import position_to_coordinate
from flatland.envs.agent_utils import TrainState
from flatland.envs.rail_env_action import RailEnvActions
from flatland.envs.step_utils.transition_utils import check_valid_action
def process_illegal_action(action: RailEnvActions):
# TODO - Dipam : This check is kind of weird, change this
if action is None or action not in RailEnvActions._value2member_map_:
return RailEnvActions.DO_NOTHING
else:
return action
def process_do_nothing(state: TrainState):
if state == TrainState.MOVING:
action = RailEnvActions.MOVE_FORWARD
else:
action = RailEnvActions.STOP_MOVING
return action
def process_left_right(action, state, rail, position, direction):
if not check_valid_action(action, state, rail, position, direction):
action = RailEnvActions.MOVE_FORWARD
return action
def preprocess_action_when_waiting(action, state):
"""
Set action to DO_NOTHING if in waiting state
"""
if state == TrainState.WAITING:
action = RailEnvActions.DO_NOTHING
return action
def preprocess_raw_action(action, state):
"""
Preprocesses actions to handle different situations of usage of action based on context
- DO_NOTHING is converted to FORWARD if train is moving
- DO_NOTHING is converted to STOP_MOVING if train is moving
"""
action = process_illegal_action(action)
if action == RailEnvActions.DO_NOTHING:
action = process_do_nothing(state)
return action
def preprocess_moving_action(action, state, rail, position, direction):
"""
LEFT/RIGHT is converted to FORWARD if left/right is not available and train is moving
FORWARD is converted to STOP_MOVING if leading to dead end?
"""
if action in [RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT]:
action = process_left_right(action, rail, position, direction)
if not check_valid_action(action, rail, position, direction): # TODO: Dipam - Not sure if this is needed
action = RailEnvActions.STOP_MOVING
return action
\ No newline at end of file
from flatland.envs.rail_env_action import RailEnvActions
from flatland.envs.step_utils.states import TrainState
class ActionSaver:
def __init__(self):
self.saved_action = None
@property
def is_action_saved(self):
return self.saved_action is not None
def __repr__(self):
return f"is_action_saved: {self.is_action_saved}, saved_action: {self.saved_action}"
def save_action_if_allowed(self, action, state):
if not self.is_action_saved and \
RailEnvActions.is_moving_action(action) and \
not TrainState.is_malfunction_state(state):
self.saved_action = action
def clear_saved_action(self):
self.saved_action = None
def get_number_of_steps_to_break(malfunction_generator, np_random):
if hasattr(malfunction_generator, "generate"):
malfunction = malfunction_generator.generate(np_random)
else:
malfunction = malfunction_generator(np_random)
return malfunction.num_broken_steps
class MalfunctionHandler:
def __init__(self):
self._malfunction_down_counter = 0
@property
def in_malfunction(self):
return self._malfunction_down_counter > 0
@property
def malfunction_counter_complete(self):
return self._malfunction_down_counter == 0
@property
def malfunction_down_counter(self):
return self._malfunction_down_counter
@malfunction_down_counter.setter
def malfunction_down_counter(self, val):
self._set_malfunction_down_counter(val)
def _set_malfunction_down_counter(self, val):
if val < 0:
raise ValueError("Cannot set a negative value to malfunction down counter")
self._malfunction_down_counter = val
def generate_malfunction(self, malfunction_generator, np_random):
num_broken_steps = get_number_of_steps_to_break(malfunction_generator, np_random)
self._set_malfunction_down_counter(num_broken_steps)
def update_counter(self):
if self._malfunction_down_counter > 0:
self._malfunction_down_counter -= 1
import numpy as np
from flatland.envs.step_utils.states import TrainState
class SpeedCounter:
def __init__(self, speed):
self.speed = speed
self.max_count = int(1/speed)
def update_counter(self, state):
if state == TrainState.MOVING:
self.counter += 1
self.counter = self.counter % self.max_count
def __repr__(self):
return f"speed: {self.speed} \
max_count: {self.max_count} \
is_cell_entry: {self.is_cell_entry} \
is_cell_exit: {self.is_cell_exit} \
counter: {self.counter}"
def reset_counter(self):
self.counter = 0
@property
def is_cell_entry(self):
return self.counter == 0
@property
def is_cell_exit(self):
return self.counter == self.max_count - 1
from attr import s
from flatland.envs.step_utils.states import TrainState
class TrainStateMachine:
def __init__(self, initial_state=TrainState.WAITING):
self._initial_state = initial_state
self._state = initial_state
self.st_signals = {} # State Transition Signals # TODO: Make this namedtuple
self.next_state = None
def _handle_waiting(self):
"""" Waiting state goes to ready to depart when earliest departure is reached"""
# TODO: Important - The malfunction handling is not like proper state machine
# Both transition signals can happen at the same time
# Atleast mention it in the diagram
if self.st_signals['malfunction_onset']:
self.next_state = TrainState.MALFUNCTION_OFF_MAP
elif self.st_signals['earliest_departure_reached']:
self.next_state = TrainState.READY_TO_DEPART
else:
self.next_state = TrainState.WAITING
def _handle_ready_to_depart(self):
""" Can only go to MOVING if a valid action is provided """
if self.st_signals['malfunction_onset']:
self.next_state = TrainState.MALFUNCTION_OFF_MAP
elif self.st_signals['valid_movement_action_given']:
self.next_state = TrainState.MOVING
else:
self.next_state = TrainState.READY_TO_DEPART
def _handle_malfunction_off_map(self):
if self.st_signals['malfunction_counter_complete']:
if self.st_signals['earliest_departure_reached']:
self.next_state = TrainState.READY_TO_DEPART
else:
self.next_state = TrainState.STOPPED
else:
self.next_state = TrainState.WAITING
def _handle_moving(self):
if self.st_signals['malfunction_onset']:
self.next_state = TrainState.MALFUNCTION
elif self.st_signals['target_reached']:
self.next_state = TrainState.DONE
elif self.st_signals['stop_action_given'] or self.st_signals['movement_conflict']:
self.next_state = TrainState.STOPPED
else:
self.next_state = TrainState.MOVING
def _handle_stopped(self):
if self.st_signals['malfunction_onset']:
self.next_state = TrainState.MALFUNCTION
elif self.st_signals['valid_movement_action_given']:
self.next_state = TrainState.MOVING
else:
self.next_state = TrainState.STOPPED
def _handle_malfunction(self):
if self.st_signals['malfunction_counter_complete'] and \
self.st_signals['valid_movement_action_given']:
self.next_state = TrainState.MOVING
elif self.st_signals['malfunction_counter_complete'] and \
(self.st_signals['stop_action_given'] or self.st_signals['movement_conflict']):
self.next_state = TrainState.STOPPED
else:
self.next_state = TrainState.MALFUNCTION
def _handle_done(self):
"""" Done state is terminal """
self.next_state = TrainState.DONE
def calculate_next_state(self, current_state):
# _Handle the current state
if current_state == TrainState.WAITING:
self._handle_waiting()
elif current_state == TrainState.READY_TO_DEPART:
self._handle_ready_to_depart()
elif current_state == TrainState.MALFUNCTION_OFF_MAP:
self._handle_malfunction_off_map()
elif current_state == TrainState.MOVING:
self._handle_moving()
elif current_state == TrainState.STOPPED:
self._handle_stopped()
elif current_state == TrainState.MALFUNCTION:
self._handle_malfunction()
elif current_state == TrainState.DONE:
self._handle_done()
else:
raise ValueError(f"Got unexpected state {current_state}")
def step(self):
""" Steps the state machine to the next state """
current_state = self._state
# Clear next state
self.clear_next_state()
# Handle current state to get next_state
self.calculate_next_state(current_state)
# Set next state
self.set_state(self.next_state)
def clear_next_state(self):
self.next_state = None
def set_state(self, state):
if not TrainState.check_valid_state(state):
raise ValueError(f"Cannot set invalid state {state}")
self._state = state
def reset(self):
self._state = self._initial_state
self.st_signals = {}
self.clear_next_state()
@property
def state(self):
return self._state
@property
def state_transition_signals(self):
return self.st_signals
def set_transition_signals(self, state_transition_signals):
self.st_signals = state_transition_signals # TODO: Important: Check all keys are present and if not raise error
from enum import IntEnum
class TrainState(IntEnum):
WAITING = 0
READY_TO_DEPART = 1
MALFUNCTION_OFF_MAP = 2
MOVING = 3
STOPPED = 4
MALFUNCTION = 5
DONE = 6
@classmethod
def check_valid_state(cls, state):
return state in cls._value2member_map_
@staticmethod
def is_malfunction_state(state):
return state in [2, 5] # TODO: Can this be done with names instead?