diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 0a642a4c5b6063a259dee4b8adda3f6948234e6b..6859497f129022877d2174b95a51da076ff3d725 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -372,7 +372,7 @@ class RailEnv(Environment): """ Generate State Transitions Signals used in the state machine """ st_signals = StateTransitionSignals() - # Malfunction onset - Malfunction starts + # Malfunction starts when in_malfunction is set to true st_signals.malfunction_onset = agent.malfunction_handler.in_malfunction # Malfunction counter complete - Malfunction ends next timestep @@ -563,7 +563,8 @@ class RailEnv(Environment): agent.state_machine.step() if agent.state.is_on_map_state() and agent.position is None: - import pdb; pdb.set_trace() + raise ValueError("Agent ID {} Agent State {} not matching with Agent Position {} ".format( + agent.handle, str(agent.state), str(agent.position) )) # Handle done state actions, optionally remove agents self.handle_done_state(agent) diff --git a/flatland/envs/step_utils/malfunction_handler.py b/flatland/envs/step_utils/malfunction_handler.py index a45aa02499983ed8d6444ef4cab11dbd8650f1ed..2ba72643f7c28729b99facf702528f6ffd174add 100644 --- a/flatland/envs/step_utils/malfunction_handler.py +++ b/flatland/envs/step_utils/malfunction_handler.py @@ -30,7 +30,9 @@ class MalfunctionHandler: def _set_malfunction_down_counter(self, val): if val < 0: raise ValueError("Cannot set a negative value to malfunction down counter") - self._malfunction_down_counter = val + # Only set new malfunction value if old malfunction is completed + if self._malfunction_down_counter == 0: + self._malfunction_down_counter = val def generate_malfunction(self, malfunction_generator, np_random): num_broken_steps = get_number_of_steps_to_break(malfunction_generator, np_random) @@ -40,6 +42,10 @@ class MalfunctionHandler: if self._malfunction_down_counter > 0: self._malfunction_down_counter -= 1 + def __repr__(self): + return f"malfunction_down_counter: {self._malfunction_down_counter} \ + in_malfunction: {self.in_malfunction}" + def to_dict(self): return {"malfunction_down_counter": self._malfunction_down_counter} diff --git a/flatland/envs/step_utils/state_machine.py b/flatland/envs/step_utils/state_machine.py index d1938f4f5157497c75969b01ea15a072d6a3a31c..78c9883fb85b3a4f967356515cddfc95235346f9 100644 --- a/flatland/envs/step_utils/state_machine.py +++ b/flatland/envs/step_utils/state_machine.py @@ -13,7 +13,7 @@ class TrainStateMachine: # TODO: Important - The malfunction handling is not like proper state machine # Both transition signals can happen at the same time # Atleast mention it in the diagram - if self.st_signals.malfunction_onset: + if self.st_signals.in_malfunction: self.next_state = TrainState.MALFUNCTION_OFF_MAP elif self.st_signals.earliest_departure_reached: self.next_state = TrainState.READY_TO_DEPART @@ -22,7 +22,7 @@ class TrainStateMachine: def _handle_ready_to_depart(self): """ Can only go to MOVING if a valid action is provided """ - if self.st_signals.malfunction_onset: + if self.st_signals.in_malfunction: self.next_state = TrainState.MALFUNCTION_OFF_MAP elif self.st_signals.valid_movement_action_given: self.next_state = TrainState.MOVING @@ -39,7 +39,7 @@ class TrainStateMachine: self.next_state = TrainState.WAITING def _handle_moving(self): - if self.st_signals.malfunction_onset: + if self.st_signals.in_malfunction: self.next_state = TrainState.MALFUNCTION elif self.st_signals.target_reached: self.next_state = TrainState.DONE @@ -49,7 +49,7 @@ class TrainStateMachine: self.next_state = TrainState.MOVING def _handle_stopped(self): - if self.st_signals.malfunction_onset: + if self.st_signals.in_malfunction: self.next_state = TrainState.MALFUNCTION elif self.st_signals.valid_movement_action_given: self.next_state = TrainState.MOVING diff --git a/flatland/envs/step_utils/states.py b/flatland/envs/step_utils/states.py index d5d85b1bf30c52b770a337a183af52cb7333989b..806113e524112e7aa0a0704ddffce1b8d2db5ffa 100644 --- a/flatland/envs/step_utils/states.py +++ b/flatland/envs/step_utils/states.py @@ -27,7 +27,7 @@ class TrainState(IntEnum): @dataclass(repr=True) class StateTransitionSignals: - malfunction_onset : bool = False + in_malfunction : bool = False malfunction_counter_complete : bool = False earliest_departure_reached : bool = False stop_action_given : bool = False diff --git a/requirements_dev.txt b/requirements_dev.txt index 93414562b79e3c0d5e1a77e42b967dc0ea4028fe..51473c19d41ddfbc14507c643758aece381e62e2 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -23,3 +23,4 @@ networkx ipycanvas graphviz imageio +dataclasses diff --git a/tests/test_utils.py b/tests/test_utils.py index 56b4befc9b6d594750dbf87f5db7d8f8a03aa24a..3469c9c82ed3f66ade78f2f335850a9c4809ddc6 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -107,11 +107,9 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering: for a, test_config in enumerate(test_configs): agent: EnvAgent = env.agents[a] replay = test_config.replay[step] - print(agent.position, replay.position, agent.state, agent.speed_counter) - # import pdb; pdb.set_trace() - # _assert(a, agent.position, replay.position, 'position') - # _assert(a, agent.direction, replay.direction, 'direction') + _assert(a, agent.position, replay.position, 'position') + _assert(a, agent.direction, replay.direction, 'direction') if replay.state is not None: _assert(a, agent.state, replay.state, 'state') @@ -129,10 +127,8 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering: # As we force malfunctions on the agents we have to set a positive rate that the env # recognizes the agent as potentially malfuncitoning # We also set next malfunction to infitiy to avoid interference with our tests - agent.malfunction_data['malfunction'] = replay.set_malfunction - agent.malfunction_data['moving_before_malfunction'] = agent.moving - agent.malfunction_data['fixed'] = False - # _assert(a, agent.malfunction_data['malfunction'], replay.malfunction, 'malfunction') + env.agents[a].malfunction_handler._set_malfunction_down_counter(replay.set_malfunction) + _assert(a, agent.malfunction_handler.malfunction_down_counter, replay.malfunction, 'malfunction') print(step) _, rewards_dict, _, info_dict = env.step(action_dict) if rendering: @@ -143,8 +139,6 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering: if not skip_reward_check: _assert(a, rewards_dict[a], replay.reward, 'reward') - assert False - def create_and_save_env(file_name: str, line_generator: LineGenerator, rail_generator: RailGenerator): stochastic_data = MalfunctionParameters(malfunction_rate=1000, # Rate of malfunction occurence