Commit 88110ae6 authored by Dipam Chakraborty
Browse files

malfunction added WIP

parent b2587841
......@@ -11,13 +11,7 @@ from flatland.envs.schedule_utils import Schedule
from flatland.envs.step_utils.action_saver import ActionSaver
from flatland.envs.step_utils.speed_counter import SpeedCounter
from flatland.envs.step_utils.state_machine import TrainStateMachine
class RailAgentStatus(IntEnum):
    """Lifecycle stage of an agent; each member notes the prediction used for it."""

    # Not in the grid yet (position is None);
    # prediction is made as if the agent were at its initial position.
    READY_TO_DEPART = 0

    # In the grid (position is not None) and not done;
    # prediction is the remaining path.
    ACTIVE = 1

    # In the grid (position is not None) but done;
    # prediction is to stay at the target forever.
    DONE = 2

    # Removed from the grid (position is None);
    # prediction is None.
    DONE_REMOVED = 3
from flatland.envs.step_utils.malfunction_handler import MalfunctionHandler
Agent = NamedTuple('Agent', [('initial_position', Tuple[int, int]),
('initial_direction', Grid4TransitionsEnum),
......@@ -29,7 +23,6 @@ Agent = NamedTuple('Agent', [('initial_position', Tuple[int, int]),
('speed_data', dict),
('malfunction_data', dict),
('handle', int),
('status', RailAgentStatus),
('position', Tuple[int, int]),
('old_direction', Grid4TransitionsEnum),
('old_position', Tuple[int, int]),
......@@ -37,6 +30,7 @@ Agent = NamedTuple('Agent', [('initial_position', Tuple[int, int]),
('action_saver', ActionSaver),
('state', TrainState),
('state_machine', TrainStateMachine),
('malfunction_handler', MalfunctionHandler),
])
......@@ -69,13 +63,14 @@ class EnvAgent:
handle = attrib(default=None)
# Env step facelift
action_saver = attrib(default=None)
speed_counter = attrib(default=None)
state_machine = attrib(default=None)
speed_counter = attrib(default = None, type=SpeedCounter)
action_saver = attrib(default = Factory(lambda: ActionSaver()), type=ActionSaver)
state_machine = attrib(default= Factory(lambda: TrainStateMachine(initial_state=TrainState.WAITING)) ,
type=TrainStateMachine)
malfunction_handler = attrib(default = Factory(lambda: MalfunctionHandler()), type=MalfunctionHandler)
state = attrib(default=TrainState.WAITING, type=TrainState)
status = attrib(default=RailAgentStatus.READY_TO_DEPART, type=RailAgentStatus)
position = attrib(default=None, type=Optional[Tuple[int, int]])
# used in rendering
......@@ -90,7 +85,6 @@ class EnvAgent:
self.position = None
# TODO: set direction to None: https://gitlab.aicrowd.com/flatland/flatland/issues/280
self.direction = self.initial_direction
self.status = RailAgentStatus.READY_TO_DEPART
self.old_position = None
self.old_direction = None
self.moving = False
......@@ -119,24 +113,27 @@ class EnvAgent:
speed_data=self.speed_data,
malfunction_data=self.malfunction_data,
handle=self.handle,
status=self.status,
state=self.state,
position=self.position,
old_direction=self.old_direction,
old_position=self.old_position,
speed_counter=self.speed_counter,
action_saver=self.action_saver,
state_machine=self.state_machine)
state_machine=self.state_machine,
malfunction_handler=self.malfunction_handler)
@classmethod
def from_schedule(cls, schedule: Schedule):
""" Create a list of EnvAgent from lists of positions, directions and targets
"""
speed_datas = []
speed_counters = []
for i in range(len(schedule.agent_positions)):
speed = schedule.agent_speeds[i] if schedule.agent_speeds is not None else 1.0
speed_datas.append({'position_fraction': 0.0,
'speed': schedule.agent_speeds[i] if schedule.agent_speeds is not None else 1.0,
'speed': speed,
'transition_action_on_cellexit': 0})
speed_counters.append( SpeedCounter(speed=speed) )
malfunction_datas = []
for i in range(len(schedule.agent_positions)):
......@@ -145,16 +142,6 @@ class EnvAgent:
i] if schedule.agent_malfunction_rates is not None else 0.,
'next_malfunction': 0,
'nr_malfunctions': 0})
action_savers = []
speed_counters = []
state_machines = []
num_agents = len(schedule.agent_positions)
agent_speeds = schedule.agent_speeds or ( [1.0] * num_agents )
for speed in agent_speeds:
speed_counters.append( SpeedCounter(speed=speed) )
action_savers.append( ActionSaver() )
state_machines.append( TrainStateMachine(initial_state=TrainState.WAITING) )
return list(starmap(EnvAgent, zip(schedule.agent_positions, # TODO : Dipam - Really want to change this way of loading agents
schedule.agent_directions,
......@@ -166,9 +153,7 @@ class EnvAgent:
speed_datas,
malfunction_datas,
range(len(schedule.agent_positions)),
action_savers,
speed_counters,
state_machines,
)))
@classmethod
......
......@@ -18,7 +18,7 @@ MalfunctionProcessData = NamedTuple('MalfunctionProcessData',
Malfunction = NamedTuple('Malfunction', [('num_broken_steps', int)])
# Why is the return value Optional? We always return a Malfunction.
MalfunctionGenerator = Callable[[EnvAgent, RandomState, bool], Optional[Malfunction]]
MalfunctionGenerator = Callable[[RandomState, bool], Malfunction]
def _malfunction_prob(rate: float) -> float:
"""
......@@ -42,21 +42,14 @@ class ParamMalfunctionGen(object):
#self.max_number_of_steps_broken = parameters.max_duration
self.MFP = parameters
def generate(self, np_random: RandomState) -> Malfunction:
    """Draw the malfunction outcome for a single call.

    A uniform sample from ``np_random.rand()`` is compared against
    ``_malfunction_prob(self.MFP.malfunction_rate)``. When the sample is
    below that threshold a duration is drawn, otherwise the result reports
    zero broken steps.

    This is the single, agent-independent form of the method: it no longer
    takes ``agent`` or ``reset`` and never reads ``agent.malfunction_data``.

    :param np_random: random state used for both the trigger and the duration draw
    :return: ``Malfunction`` whose ``num_broken_steps`` is 0 when nothing breaks
    """
    if np_random.rand() < _malfunction_prob(self.MFP.malfunction_rate):
        # The duration is drawn from [min_duration, max_duration] (randint's
        # upper bound is exclusive) and then incremented by one.
        # NOTE(review): the extra + 1 is carried over unchanged from the
        # previous implementation; its intent is not visible here - confirm.
        num_broken_steps = np_random.randint(self.MFP.min_duration,
                                             self.MFP.max_duration + 1) + 1
    else:
        num_broken_steps = 0
    return Malfunction(num_broken_steps)
def get_process_data(self):
    """Return the stored parameters (``self.MFP``) unpacked into a MalfunctionProcessData."""
    parameters = self.MFP
    return MalfunctionProcessData(*parameters)
......@@ -103,7 +96,7 @@ def no_malfunction_generator() -> Tuple[MalfunctionGenerator, MalfunctionProcess
min_number_of_steps_broken = 0
max_number_of_steps_broken = 0
def generator(agent: EnvAgent = None, np_random: RandomState = None, reset=False) -> Optional[Malfunction]:
def generator(np_random: RandomState = None) -> Malfunction:
return Malfunction(0)
return generator, MalfunctionProcessData(mean_malfunction_rate, min_number_of_steps_broken,
......@@ -162,7 +155,7 @@ def single_malfunction_generator(earlierst_malfunction: int, malfunction_duratio
malfunction_calls[agent.handle] = 1
# Break an agent that is active at the time of the malfunction
if agent.status == RailAgentStatus.ACTIVE and malfunction_calls[agent.handle] >= earlierst_malfunction:
if agent.status == RailAgentStatus.ACTIVE and malfunction_calls[agent.handle] >= earlierst_malfunction: #TODO : Dipam : Is this needed?
global_nr_malfunctions += 1
return Malfunction(malfunction_duration)
else:
......
This diff is collapsed.
......@@ -2,19 +2,24 @@ from flatland.envs.rail_env_action import RailEnvActions
from flatland.envs.step_utils.states import TrainState
class ActionSaver:
    """Holds at most one saved action for an agent until it is cleared.

    Every member is defined exactly once; the duplicated definitions of
    ``__init__``, ``is_action_saved``, ``save_action_if_allowed`` and
    ``clear_saved_action`` have been collapsed.
    """

    def __init__(self):
        # None means "no action saved"
        self.saved_action = None

    @property
    def is_action_saved(self):
        """True while an action is being held."""
        # Compare against None explicitly so that an action whose value is
        # falsy (e.g. 0) still counts as saved.
        return self.saved_action is not None

    def __repr__(self):
        return f"is_action_saved: {self.is_action_saved}, saved_action: {self.saved_action}"

    def save_action_if_allowed(self, action, state):
        """Store ``action`` unless one is already held.

        The action is saved only when all of the following hold:
        no action is currently saved, ``RailEnvActions.is_moving_action(action)``
        is true, and ``TrainState.is_malfunction_state(state)`` is false.
        Otherwise the call is a no-op.
        """
        if self.is_action_saved:
            # An already-saved action is never overwritten
            return
        if not RailEnvActions.is_moving_action(action):
            return
        if TrainState.is_malfunction_state(state):
            return
        self.saved_action = action

    def clear_saved_action(self):
        """Forget the held action, if any."""
        self.saved_action = None
def get_number_of_steps_to_break(malfunction_generator, np_random):
    """Ask a malfunction generator for a draw and return its ``num_broken_steps``.

    ``malfunction_generator`` may be an object exposing ``generate(np_random)``
    or a plain callable taking ``np_random``; either way the returned value
    must have a ``num_broken_steps`` attribute.
    """
    # Use the object's .generate when it has one, else call the generator itself
    draw = getattr(malfunction_generator, "generate", malfunction_generator)
    return draw(np_random).num_broken_steps
class MalfunctionHandler:
    """Counts down the remaining steps of an agent's malfunction.

    The counter is 0 when the agent is not broken and positive while it is.
    """

    def __init__(self):
        self._malfunction_down_counter = 0

    def __repr__(self):
        return (f"malfunction_down_counter: {self._malfunction_down_counter}, "
                f"in_malfunction: {self.in_malfunction}")

    @property
    def in_malfunction(self):
        """True while the down counter is positive."""
        return self._malfunction_down_counter > 0

    @property
    def malfunction_counter_complete(self):
        """True when the down counter is exactly 0."""
        return self._malfunction_down_counter == 0

    @property
    def malfunction_down_counter(self):
        """Remaining number of broken steps."""
        return self._malfunction_down_counter

    @malfunction_down_counter.setter
    def malfunction_down_counter(self, val):
        self._set_malfunction_down_counter(val)

    def _set_malfunction_down_counter(self, val):
        """Overwrite the counter unconditionally.

        :raises ValueError: if ``val`` is negative
        """
        if val < 0:
            raise ValueError("Cannot set a negative value to malfunction down counter")
        self._malfunction_down_counter = val

    def generate_malfunction(self, malfunction_generator, np_random):
        """Draw a new malfunction from the generator unless one is still running.

        Generators return 0 broken steps whenever nothing breaks, so drawing
        during an ongoing malfunction would overwrite the remaining duration
        (usually with 0). The draw is therefore skipped entirely while
        ``in_malfunction`` is true. Use the ``malfunction_down_counter``
        setter to force a value.
        """
        if self.in_malfunction:
            return
        num_broken_steps = get_number_of_steps_to_break(malfunction_generator, np_random)
        self._set_malfunction_down_counter(num_broken_steps)

    def update_counter(self):
        """Decrement the counter by one step, never going below 0."""
        if self._malfunction_down_counter > 0:
            self._malfunction_down_counter -= 1
......@@ -2,22 +2,30 @@ import numpy as np
from flatland.envs.step_utils.states import TrainState
class SpeedCounter:
    """Counts the steps an agent spends inside a cell, based on its speed.

    ``max_count`` is ``int(1 / speed)``; the counter cycles through
    ``0 .. max_count - 1``.
    """

    def __init__(self, speed):
        self.speed = speed
        self.max_count = int(1/speed)
        # Start from a defined counter so the properties and __repr__ work
        # before the first explicit reset_counter() call.
        self.reset_counter()

    def update_counter(self, state):
        """Advance the counter by one, wrapping at ``max_count``, only when ``state`` is MOVING."""
        if state == TrainState.MOVING:
            self.counter += 1
            self.counter = self.counter % self.max_count

    def __repr__(self):
        return (f"speed: {self.speed} "
                f"max_count: {self.max_count} "
                f"is_cell_entry: {self.is_cell_entry} "
                f"is_cell_exit: {self.is_cell_exit} "
                f"counter: {self.counter}")

    def reset_counter(self):
        """Put the counter back to 0."""
        self.counter = 0

    @property
    def is_cell_entry(self):
        """True when the counter is 0."""
        return self.counter == 0

    @property
    def is_cell_exit(self):
        """True when the counter is at its last value, ``max_count - 1``."""
        return self.counter == self.max_count - 1
\ No newline at end of file
......@@ -5,7 +5,7 @@ class TrainStateMachine:
def __init__(self, initial_state=TrainState.WAITING):
self._initial_state = initial_state
self._state = initial_state
self.st_signals = {} # State Transition Signals
self.st_signals = {} # State Transition Signals # TODO: Make this namedtuple
self.next_state = None
def _handle_waiting(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment