diff --git a/flatland/envs/malfunction_generators.py b/flatland/envs/malfunction_generators.py index 0dfafb36447550bc275dadabc6943b50d07e2f4f..086fd9cef348ff8de6b6c358876926804dc673e5 100644 --- a/flatland/envs/malfunction_generators.py +++ b/flatland/envs/malfunction_generators.py @@ -253,7 +253,7 @@ def malfunction_from_params(parameters: MalfunctionParameters) -> Tuple[Malfunct min_number_of_steps_broken = parameters.min_duration max_number_of_steps_broken = parameters.max_duration - def generator(agent: EnvAgent = None, np_random: RandomState = None, reset=False) -> Optional[Malfunction]: + def generator(np_random: RandomState = None, reset=False) -> Optional[Malfunction]: """ Generate malfunctions for agents Parameters @@ -270,11 +270,10 @@ def malfunction_from_params(parameters: MalfunctionParameters) -> Tuple[Malfunct if reset: return Malfunction(0) - if agent.malfunction_data['malfunction'] < 1: - if np_random.rand() < _malfunction_prob(mean_malfunction_rate): - num_broken_steps = np_random.randint(min_number_of_steps_broken, - max_number_of_steps_broken + 1) + 1 - return Malfunction(num_broken_steps) + if np_random.rand() < _malfunction_prob(mean_malfunction_rate): + num_broken_steps = np_random.randint(min_number_of_steps_broken, + max_number_of_steps_broken + 1) + return Malfunction(num_broken_steps) return Malfunction(0) return generator, MalfunctionProcessData(mean_malfunction_rate, min_number_of_steps_broken, diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index a6a15531bf42a45b17f93fecf82dc5c1d2ef92af..f345648533c7ab554403b7031245139e39e51b75 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -549,7 +549,7 @@ class RailEnv(Environment): if agent.malfunction_handler.in_malfunction: movement_allowed = False else: - movement_allowed = self.motionCheck.check_motion(i_agent, agent.position) # TODO: Remove final_new_postion from motioncheck + movement_allowed = self.motionCheck.check_motion(i_agent, agent.position) # Position can be changed only if other cell is empty # And either the speed counter completes or agent is being added to map diff --git a/flatland/envs/step_utils/action_saver.py b/flatland/envs/step_utils/action_saver.py index bf61076e1f1136b3eca28bd05a51ba614a9d0eb6..913e9576d923a7e67ff7a498237803df3d9d0a43 100644 --- a/flatland/envs/step_utils/action_saver.py +++ b/flatland/envs/step_utils/action_saver.py @@ -17,14 +17,10 @@ class ActionSaver: """ Save the action if all conditions are met 1. It is a movement based action -> Forward, Left, Right - 2. Action is not already saved - 3. Not in a malfunction state - 4. Agent is not already done + 2. Action is not already saved + 3. Agent is not already done """ - if action.is_moving_action() and \ - not self.is_action_saved and \ - not state.is_malfunction_state() and \ - not state == TrainState.DONE: + if action.is_moving_action() and not self.is_action_saved and not state == TrainState.DONE: self.saved_action = action def clear_saved_action(self): diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index d633351ed3624499aa2e30df9f09031b0b4cf581..b4632f3e37bd5f879349488d8b24a74c4c5d9759 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -190,8 +190,9 @@ def test_malfunction_before_entry(): # Test initial malfunction values for all agents # we want some agents to be malfuncitoning already and some to be working # we want different next_malfunction values for the agents - assert env.agents[0].malfunction_data['malfunction'] == 0 - assert env.agents[1].malfunction_data['malfunction'] == 10 + malfunction_values = [env.malfunction_generator(env.np_random).num_broken_steps for _ in range(1000)] + expected_value = (1 - np.exp(-0.5)) * 10 + assert np.allclose(np.mean(malfunction_values), expected_value, rtol=0.1), "Mean values of malfunction don't match rate" def test_malfunction_values_and_behavior(): @@ -257,7 +258,7 @@ def test_initial_malfunction(): set_penalties_for_replay(env) replay_config = ReplayConfig( replay=[ - Replay( + Replay( # 0 position=(3, 2), direction=Grid4TransitionsEnum.EAST, action=RailEnvActions.MOVE_FORWARD, @@ -265,7 +266,7 @@ def test_initial_malfunction(): malfunction=3, reward=env.step_penalty # full step penalty when malfunctioning ), - Replay( + Replay( # 1 position=(3, 2), direction=Grid4TransitionsEnum.EAST, action=RailEnvActions.MOVE_FORWARD, @@ -274,7 +275,7 @@ def test_initial_malfunction(): ), # malfunction stops in the next step and we're still at the beginning of the cell # --> if we take action MOVE_FORWARD, agent should restart and move to the next cell - Replay( + Replay( # 2 position=(3, 2), direction=Grid4TransitionsEnum.EAST, action=RailEnvActions.MOVE_FORWARD, @@ -282,14 +283,14 @@ def test_initial_malfunction(): reward=env.step_penalty ), # malfunctioning ends: starting and running at speed 1.0 - Replay( + Replay( # 3 position=(3, 2), direction=Grid4TransitionsEnum.EAST, action=RailEnvActions.MOVE_FORWARD, malfunction=0, reward=env.start_penalty + env.step_penalty * 1.0 # running at speed 1.0 ), - Replay( + Replay( # 4 position=(3, 3), direction=Grid4TransitionsEnum.EAST, action=RailEnvActions.MOVE_FORWARD, @@ -420,7 +421,7 @@ def test_initial_malfunction_do_nothing(): action=RailEnvActions.DO_NOTHING, malfunction=2, reward=env.step_penalty, # full step penalty while malfunctioning - state=TrainState.ACTIVE + state=TrainState.MOVING ), # malfunction stops in the next step and we're still at the beginning of the cell # --> if we take action DO_NOTHING, agent should restart without moving @@ -431,7 +432,7 @@ def test_initial_malfunction_do_nothing(): action=RailEnvActions.DO_NOTHING, malfunction=1, reward=env.step_penalty, # full step penalty while stopped - state=TrainState.ACTIVE + state=TrainState.MOVING ), # we haven't started moving yet --> stay here Replay( @@ -440,7 +441,7 @@ def test_initial_malfunction_do_nothing(): action=RailEnvActions.DO_NOTHING, malfunction=0, reward=env.step_penalty, # full step penalty while stopped - state=TrainState.ACTIVE + state=TrainState.MOVING ), Replay( @@ -449,7 +450,7 @@ def test_initial_malfunction_do_nothing(): action=RailEnvActions.MOVE_FORWARD, malfunction=0, reward=env.start_penalty + env.step_penalty * 1.0, # start penalty + step penalty for speed 1.0 - state=TrainState.ACTIVE + state=TrainState.MOVING ), # we start to move forward --> should go to next cell now Replay( position=(3, 3), @@ -457,7 +458,7 @@ def test_initial_malfunction_do_nothing(): action=RailEnvActions.MOVE_FORWARD, malfunction=0, reward=env.step_penalty * 1.0, # step penalty for speed 1.0 - state=TrainState.ACTIVE + state=TrainState.MOVING ) ], speed=env.agents[0].speed_counter.speed, @@ -546,7 +547,7 @@ def test_last_malfunction_step(): env.reset(False, False) for a_idx in range(len(env.agents)): env.agents[a_idx].position = env.agents[a_idx].initial_position - env.agents[a_idx].state = TrainState.ACTIVE + env.agents[a_idx].state = TrainState.MOVING # Force malfunction to be off at beginning and next malfunction to happen in 2 steps env.agents[0].malfunction_data['next_malfunction'] = 2 env.agents[0].malfunction_data['malfunction'] = 0 diff --git a/tests/test_utils.py b/tests/test_utils.py index 3469c9c82ed3f66ade78f2f335850a9c4809ddc6..391ba535a343c25df03df6fa18f8651543fab06f 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -107,7 +107,6 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering: for a, test_config in enumerate(test_configs): agent: EnvAgent = env.agents[a] replay = test_config.replay[step] - print(agent.position, replay.position, agent.state, agent.speed_counter) _assert(a, agent.position, replay.position, 'position') _assert(a, agent.direction, replay.direction, 'direction') if replay.state is not None: