Skip to content
Snippets Groups Projects
Commit 9301e30b authored by Dipam Chakraborty's avatar Dipam Chakraborty
Browse files

malfunction fix to let previous malfunction finish

parent 6db05ca9
No related branches found
No related tags found
No related merge requests found
......@@ -372,7 +372,7 @@ class RailEnv(Environment):
""" Generate State Transitions Signals used in the state machine """
st_signals = StateTransitionSignals()
# Malfunction onset - Malfunction starts
# Malfunction starts when in_malfunction is set to true
st_signals.malfunction_onset = agent.malfunction_handler.in_malfunction
# Malfunction counter complete - Malfunction ends next timestep
......@@ -563,7 +563,8 @@ class RailEnv(Environment):
agent.state_machine.step()
if agent.state.is_on_map_state() and agent.position is None:
import pdb; pdb.set_trace()
raise ValueError("Agent ID {} Agent State {} not matching with Agent Position {} ".format(
agent.handle, str(agent.state), str(agent.position) ))
# Handle done state actions, optionally remove agents
self.handle_done_state(agent)
......
......@@ -30,7 +30,9 @@ class MalfunctionHandler:
def _set_malfunction_down_counter(self, val):
if val < 0:
raise ValueError("Cannot set a negative value to malfunction down counter")
self._malfunction_down_counter = val
# Only set new malfunction value if old malfunction is completed
if self._malfunction_down_counter == 0:
self._malfunction_down_counter = val
def generate_malfunction(self, malfunction_generator, np_random):
num_broken_steps = get_number_of_steps_to_break(malfunction_generator, np_random)
......@@ -40,6 +42,10 @@ class MalfunctionHandler:
if self._malfunction_down_counter > 0:
self._malfunction_down_counter -= 1
def __repr__(self):
return f"malfunction_down_counter: {self._malfunction_down_counter} \
in_malfunction: {self.in_malfunction}"
def to_dict(self):
return {"malfunction_down_counter": self._malfunction_down_counter}
......
......@@ -13,7 +13,7 @@ class TrainStateMachine:
# TODO: Important - The malfunction handling is not like proper state machine
# Both transition signals can happen at the same time
# Atleast mention it in the diagram
if self.st_signals.malfunction_onset:
if self.st_signals.in_malfunction:
self.next_state = TrainState.MALFUNCTION_OFF_MAP
elif self.st_signals.earliest_departure_reached:
self.next_state = TrainState.READY_TO_DEPART
......@@ -22,7 +22,7 @@ class TrainStateMachine:
def _handle_ready_to_depart(self):
""" Can only go to MOVING if a valid action is provided """
if self.st_signals.malfunction_onset:
if self.st_signals.in_malfunction:
self.next_state = TrainState.MALFUNCTION_OFF_MAP
elif self.st_signals.valid_movement_action_given:
self.next_state = TrainState.MOVING
......@@ -39,7 +39,7 @@ class TrainStateMachine:
self.next_state = TrainState.WAITING
def _handle_moving(self):
if self.st_signals.malfunction_onset:
if self.st_signals.in_malfunction:
self.next_state = TrainState.MALFUNCTION
elif self.st_signals.target_reached:
self.next_state = TrainState.DONE
......@@ -49,7 +49,7 @@ class TrainStateMachine:
self.next_state = TrainState.MOVING
def _handle_stopped(self):
if self.st_signals.malfunction_onset:
if self.st_signals.in_malfunction:
self.next_state = TrainState.MALFUNCTION
elif self.st_signals.valid_movement_action_given:
self.next_state = TrainState.MOVING
......
......@@ -27,7 +27,7 @@ class TrainState(IntEnum):
@dataclass(repr=True)
class StateTransitionSignals:
malfunction_onset : bool = False
in_malfunction : bool = False
malfunction_counter_complete : bool = False
earliest_departure_reached : bool = False
stop_action_given : bool = False
......
......@@ -23,3 +23,4 @@ networkx
ipycanvas
graphviz
imageio
dataclasses
......@@ -107,11 +107,9 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering:
for a, test_config in enumerate(test_configs):
agent: EnvAgent = env.agents[a]
replay = test_config.replay[step]
print(agent.position, replay.position, agent.state, agent.speed_counter)
# import pdb; pdb.set_trace()
# _assert(a, agent.position, replay.position, 'position')
# _assert(a, agent.direction, replay.direction, 'direction')
_assert(a, agent.position, replay.position, 'position')
_assert(a, agent.direction, replay.direction, 'direction')
if replay.state is not None:
_assert(a, agent.state, replay.state, 'state')
......@@ -129,10 +127,8 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering:
# As we force malfunctions on the agents we have to set a positive rate that the env
# recognizes the agent as potentially malfuncitoning
# We also set next malfunction to infitiy to avoid interference with our tests
agent.malfunction_data['malfunction'] = replay.set_malfunction
agent.malfunction_data['moving_before_malfunction'] = agent.moving
agent.malfunction_data['fixed'] = False
# _assert(a, agent.malfunction_data['malfunction'], replay.malfunction, 'malfunction')
env.agents[a].malfunction_handler._set_malfunction_down_counter(replay.set_malfunction)
_assert(a, agent.malfunction_handler.malfunction_down_counter, replay.malfunction, 'malfunction')
print(step)
_, rewards_dict, _, info_dict = env.step(action_dict)
if rendering:
......@@ -143,8 +139,6 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering:
if not skip_reward_check:
_assert(a, rewards_dict[a], replay.reward, 'reward')
assert False
def create_and_save_env(file_name: str, line_generator: LineGenerator, rail_generator: RailGenerator):
stochastic_data = MalfunctionParameters(malfunction_rate=1000, # Rate of malfunction occurence
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment