Commit 9301e30b authored by Dipam Chakraborty's avatar Dipam Chakraborty
Browse files

malfunction fix to let previous malfunction finish

parent 6db05ca9
...@@ -372,7 +372,7 @@ class RailEnv(Environment): ...@@ -372,7 +372,7 @@ class RailEnv(Environment):
""" Generate State Transitions Signals used in the state machine """ """ Generate State Transitions Signals used in the state machine """
st_signals = StateTransitionSignals() st_signals = StateTransitionSignals()
# Malfunction onset - Malfunction starts # Malfunction starts when in_malfunction is set to true
st_signals.malfunction_onset = agent.malfunction_handler.in_malfunction st_signals.malfunction_onset = agent.malfunction_handler.in_malfunction
# Malfunction counter complete - Malfunction ends next timestep # Malfunction counter complete - Malfunction ends next timestep
...@@ -563,7 +563,8 @@ class RailEnv(Environment): ...@@ -563,7 +563,8 @@ class RailEnv(Environment):
agent.state_machine.step() agent.state_machine.step()
if agent.state.is_on_map_state() and agent.position is None: if agent.state.is_on_map_state() and agent.position is None:
import pdb; pdb.set_trace() raise ValueError("Agent ID {} Agent State {} not matching with Agent Position {} ".format(
agent.handle, str(agent.state), str(agent.position) ))
# Handle done state actions, optionally remove agents # Handle done state actions, optionally remove agents
self.handle_done_state(agent) self.handle_done_state(agent)
......
...@@ -30,7 +30,9 @@ class MalfunctionHandler: ...@@ -30,7 +30,9 @@ class MalfunctionHandler:
def _set_malfunction_down_counter(self, val): def _set_malfunction_down_counter(self, val):
if val < 0: if val < 0:
raise ValueError("Cannot set a negative value to malfunction down counter") raise ValueError("Cannot set a negative value to malfunction down counter")
self._malfunction_down_counter = val # Only set new malfunction value if old malfunction is completed
if self._malfunction_down_counter == 0:
self._malfunction_down_counter = val
def generate_malfunction(self, malfunction_generator, np_random): def generate_malfunction(self, malfunction_generator, np_random):
num_broken_steps = get_number_of_steps_to_break(malfunction_generator, np_random) num_broken_steps = get_number_of_steps_to_break(malfunction_generator, np_random)
...@@ -40,6 +42,10 @@ class MalfunctionHandler: ...@@ -40,6 +42,10 @@ class MalfunctionHandler:
if self._malfunction_down_counter > 0: if self._malfunction_down_counter > 0:
self._malfunction_down_counter -= 1 self._malfunction_down_counter -= 1
def __repr__(self):
return f"malfunction_down_counter: {self._malfunction_down_counter} \
in_malfunction: {self.in_malfunction}"
def to_dict(self): def to_dict(self):
return {"malfunction_down_counter": self._malfunction_down_counter} return {"malfunction_down_counter": self._malfunction_down_counter}
......
...@@ -13,7 +13,7 @@ class TrainStateMachine: ...@@ -13,7 +13,7 @@ class TrainStateMachine:
# TODO: Important - The malfunction handling is not like proper state machine # TODO: Important - The malfunction handling is not like proper state machine
# Both transition signals can happen at the same time # Both transition signals can happen at the same time
# Atleast mention it in the diagram # Atleast mention it in the diagram
if self.st_signals.malfunction_onset: if self.st_signals.in_malfunction:
self.next_state = TrainState.MALFUNCTION_OFF_MAP self.next_state = TrainState.MALFUNCTION_OFF_MAP
elif self.st_signals.earliest_departure_reached: elif self.st_signals.earliest_departure_reached:
self.next_state = TrainState.READY_TO_DEPART self.next_state = TrainState.READY_TO_DEPART
...@@ -22,7 +22,7 @@ class TrainStateMachine: ...@@ -22,7 +22,7 @@ class TrainStateMachine:
def _handle_ready_to_depart(self): def _handle_ready_to_depart(self):
""" Can only go to MOVING if a valid action is provided """ """ Can only go to MOVING if a valid action is provided """
if self.st_signals.malfunction_onset: if self.st_signals.in_malfunction:
self.next_state = TrainState.MALFUNCTION_OFF_MAP self.next_state = TrainState.MALFUNCTION_OFF_MAP
elif self.st_signals.valid_movement_action_given: elif self.st_signals.valid_movement_action_given:
self.next_state = TrainState.MOVING self.next_state = TrainState.MOVING
...@@ -39,7 +39,7 @@ class TrainStateMachine: ...@@ -39,7 +39,7 @@ class TrainStateMachine:
self.next_state = TrainState.WAITING self.next_state = TrainState.WAITING
def _handle_moving(self): def _handle_moving(self):
if self.st_signals.malfunction_onset: if self.st_signals.in_malfunction:
self.next_state = TrainState.MALFUNCTION self.next_state = TrainState.MALFUNCTION
elif self.st_signals.target_reached: elif self.st_signals.target_reached:
self.next_state = TrainState.DONE self.next_state = TrainState.DONE
...@@ -49,7 +49,7 @@ class TrainStateMachine: ...@@ -49,7 +49,7 @@ class TrainStateMachine:
self.next_state = TrainState.MOVING self.next_state = TrainState.MOVING
def _handle_stopped(self): def _handle_stopped(self):
if self.st_signals.malfunction_onset: if self.st_signals.in_malfunction:
self.next_state = TrainState.MALFUNCTION self.next_state = TrainState.MALFUNCTION
elif self.st_signals.valid_movement_action_given: elif self.st_signals.valid_movement_action_given:
self.next_state = TrainState.MOVING self.next_state = TrainState.MOVING
......
...@@ -27,7 +27,7 @@ class TrainState(IntEnum): ...@@ -27,7 +27,7 @@ class TrainState(IntEnum):
@dataclass(repr=True) @dataclass(repr=True)
class StateTransitionSignals: class StateTransitionSignals:
malfunction_onset : bool = False in_malfunction : bool = False
malfunction_counter_complete : bool = False malfunction_counter_complete : bool = False
earliest_departure_reached : bool = False earliest_departure_reached : bool = False
stop_action_given : bool = False stop_action_given : bool = False
......
...@@ -23,3 +23,4 @@ networkx ...@@ -23,3 +23,4 @@ networkx
ipycanvas ipycanvas
graphviz graphviz
imageio imageio
dataclasses
...@@ -107,11 +107,9 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering: ...@@ -107,11 +107,9 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering:
for a, test_config in enumerate(test_configs): for a, test_config in enumerate(test_configs):
agent: EnvAgent = env.agents[a] agent: EnvAgent = env.agents[a]
replay = test_config.replay[step] replay = test_config.replay[step]
print(agent.position, replay.position, agent.state, agent.speed_counter) print(agent.position, replay.position, agent.state, agent.speed_counter)
# import pdb; pdb.set_trace() _assert(a, agent.position, replay.position, 'position')
# _assert(a, agent.position, replay.position, 'position') _assert(a, agent.direction, replay.direction, 'direction')
# _assert(a, agent.direction, replay.direction, 'direction')
if replay.state is not None: if replay.state is not None:
_assert(a, agent.state, replay.state, 'state') _assert(a, agent.state, replay.state, 'state')
...@@ -129,10 +127,8 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering: ...@@ -129,10 +127,8 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering:
# As we force malfunctions on the agents we have to set a positive rate that the env # As we force malfunctions on the agents we have to set a positive rate that the env
# recognizes the agent as potentially malfuncitoning # recognizes the agent as potentially malfuncitoning
# We also set next malfunction to infitiy to avoid interference with our tests # We also set next malfunction to infitiy to avoid interference with our tests
agent.malfunction_data['malfunction'] = replay.set_malfunction env.agents[a].malfunction_handler._set_malfunction_down_counter(replay.set_malfunction)
agent.malfunction_data['moving_before_malfunction'] = agent.moving _assert(a, agent.malfunction_handler.malfunction_down_counter, replay.malfunction, 'malfunction')
agent.malfunction_data['fixed'] = False
# _assert(a, agent.malfunction_data['malfunction'], replay.malfunction, 'malfunction')
print(step) print(step)
_, rewards_dict, _, info_dict = env.step(action_dict) _, rewards_dict, _, info_dict = env.step(action_dict)
if rendering: if rendering:
...@@ -143,8 +139,6 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering: ...@@ -143,8 +139,6 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering:
if not skip_reward_check: if not skip_reward_check:
_assert(a, rewards_dict[a], replay.reward, 'reward') _assert(a, rewards_dict[a], replay.reward, 'reward')
assert False
def create_and_save_env(file_name: str, line_generator: LineGenerator, rail_generator: RailGenerator): def create_and_save_env(file_name: str, line_generator: LineGenerator, rail_generator: RailGenerator):
stochastic_data = MalfunctionParameters(malfunction_rate=1000, # Rate of malfunction occurence stochastic_data = MalfunctionParameters(malfunction_rate=1000, # Rate of malfunction occurence
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment