From cc04d63ff36bb36e07b9b5ac966c6e691518f854 Mon Sep 17 00:00:00 2001 From: Dipam Chakraborty <dipam@aicrowd.com> Date: Tue, 14 Sep 2021 00:36:16 +0530 Subject: [PATCH] malfunction fix: can save actions during malfunction --- flatland/envs/malfunction_generators.py | 11 +++++----- flatland/envs/rail_env.py | 2 +- flatland/envs/step_utils/action_saver.py | 10 +++------ tests/test_flatland_malfunction.py | 27 ++++++++++++------------ tests/test_utils.py | 1 - 5 files changed, 23 insertions(+), 28 deletions(-) diff --git a/flatland/envs/malfunction_generators.py b/flatland/envs/malfunction_generators.py index 0dfafb36..086fd9ce 100644 --- a/flatland/envs/malfunction_generators.py +++ b/flatland/envs/malfunction_generators.py @@ -253,7 +253,7 @@ def malfunction_from_params(parameters: MalfunctionParameters) -> Tuple[Malfunct min_number_of_steps_broken = parameters.min_duration max_number_of_steps_broken = parameters.max_duration - def generator(agent: EnvAgent = None, np_random: RandomState = None, reset=False) -> Optional[Malfunction]: + def generator(np_random: RandomState = None, reset=False) -> Optional[Malfunction]: """ Generate malfunctions for agents Parameters @@ -270,11 +270,10 @@ def malfunction_from_params(parameters: MalfunctionParameters) -> Tuple[Malfunct if reset: return Malfunction(0) - if agent.malfunction_data['malfunction'] < 1: - if np_random.rand() < _malfunction_prob(mean_malfunction_rate): - num_broken_steps = np_random.randint(min_number_of_steps_broken, - max_number_of_steps_broken + 1) + 1 - return Malfunction(num_broken_steps) + if np_random.rand() < _malfunction_prob(mean_malfunction_rate): + num_broken_steps = np_random.randint(min_number_of_steps_broken, + max_number_of_steps_broken + 1) + return Malfunction(num_broken_steps) return Malfunction(0) return generator, MalfunctionProcessData(mean_malfunction_rate, min_number_of_steps_broken, diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index a6a15531..f3456485 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -549,7 +549,7 @@ class RailEnv(Environment): if agent.malfunction_handler.in_malfunction: movement_allowed = False else: - movement_allowed = self.motionCheck.check_motion(i_agent, agent.position) # TODO: Remove final_new_postion from motioncheck + movement_allowed = self.motionCheck.check_motion(i_agent, agent.position) # Position can be changed only if other cell is empty # And either the speed counter completes or agent is being added to map diff --git a/flatland/envs/step_utils/action_saver.py b/flatland/envs/step_utils/action_saver.py index bf61076e..913e9576 100644 --- a/flatland/envs/step_utils/action_saver.py +++ b/flatland/envs/step_utils/action_saver.py @@ -17,14 +17,10 @@ class ActionSaver: """ Save the action if all conditions are met 1. It is a movement based action -> Forward, Left, Right - 2. Action is not already saved - 3. Not in a malfunction state - 4. Agent is not already done + 2. Action is not already saved + 3. Agent is not already done """ - if action.is_moving_action() and \ - not self.is_action_saved and \ - not state.is_malfunction_state() and \ - not state == TrainState.DONE: + if action.is_moving_action() and not self.is_action_saved and not state == TrainState.DONE: self.saved_action = action def clear_saved_action(self): diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index d633351e..b4632f3e 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -190,8 +190,9 @@ def test_malfunction_before_entry(): # Test initial malfunction values for all agents # we want some agents to be malfuncitoning already and some to be working # we want different next_malfunction values for the agents - assert env.agents[0].malfunction_data['malfunction'] == 0 - assert env.agents[1].malfunction_data['malfunction'] == 10 + malfunction_values = [env.malfunction_generator(env.np_random).num_broken_steps for _ in range(1000)] + expected_value = (1 - np.exp(-0.5)) * 10 + assert np.allclose(np.mean(malfunction_values), expected_value, rtol=0.1), "Mean values of malfunction don't match rate" def test_malfunction_values_and_behavior(): @@ -257,7 +258,7 @@ def test_initial_malfunction(): set_penalties_for_replay(env) replay_config = ReplayConfig( replay=[ - Replay( + Replay( # 0 position=(3, 2), direction=Grid4TransitionsEnum.EAST, action=RailEnvActions.MOVE_FORWARD, @@ -265,7 +266,7 @@ def test_initial_malfunction(): malfunction=3, reward=env.step_penalty # full step penalty when malfunctioning ), - Replay( + Replay( # 1 position=(3, 2), direction=Grid4TransitionsEnum.EAST, action=RailEnvActions.MOVE_FORWARD, @@ -274,7 +275,7 @@ def test_initial_malfunction(): ), # malfunction stops in the next step and we're still at the beginning of the cell # --> if we take action MOVE_FORWARD, agent should restart and move to the next cell - Replay( + Replay( # 2 position=(3, 2), direction=Grid4TransitionsEnum.EAST, action=RailEnvActions.MOVE_FORWARD, @@ -282,14 +283,14 @@ def test_initial_malfunction(): reward=env.step_penalty ), # malfunctioning ends: starting and running at speed 1.0 - Replay( + Replay( # 3 position=(3, 2), direction=Grid4TransitionsEnum.EAST, action=RailEnvActions.MOVE_FORWARD, malfunction=0, reward=env.start_penalty + env.step_penalty * 1.0 # running at speed 1.0 ), - Replay( + Replay( # 4 position=(3, 3), direction=Grid4TransitionsEnum.EAST, action=RailEnvActions.MOVE_FORWARD, @@ -420,7 +421,7 @@ def test_initial_malfunction_do_nothing(): action=RailEnvActions.DO_NOTHING, malfunction=2, reward=env.step_penalty, # full step penalty while malfunctioning - state=TrainState.ACTIVE + state=TrainState.MOVING ), # malfunction stops in the next step and we're still at the beginning of the cell # --> if we take action DO_NOTHING, agent should restart without moving @@ -431,7 +432,7 @@ def test_initial_malfunction_do_nothing(): action=RailEnvActions.DO_NOTHING, malfunction=1, reward=env.step_penalty, # full step penalty while stopped - state=TrainState.ACTIVE + state=TrainState.MOVING ), # we haven't started moving yet --> stay here Replay( @@ -440,7 +441,7 @@ def test_initial_malfunction_do_nothing(): action=RailEnvActions.DO_NOTHING, malfunction=0, reward=env.step_penalty, # full step penalty while stopped - state=TrainState.ACTIVE + state=TrainState.MOVING ), Replay( @@ -449,7 +450,7 @@ def test_initial_malfunction_do_nothing(): action=RailEnvActions.MOVE_FORWARD, malfunction=0, reward=env.start_penalty + env.step_penalty * 1.0, # start penalty + step penalty for speed 1.0 - state=TrainState.ACTIVE + state=TrainState.MOVING ), # we start to move forward --> should go to next cell now Replay( position=(3, 3), @@ -457,7 +458,7 @@ def test_initial_malfunction_do_nothing(): action=RailEnvActions.MOVE_FORWARD, malfunction=0, reward=env.step_penalty * 1.0, # step penalty for speed 1.0 - state=TrainState.ACTIVE + state=TrainState.MOVING ) ], speed=env.agents[0].speed_counter.speed, @@ -546,7 +547,7 @@ def test_last_malfunction_step(): env.reset(False, False) for a_idx in range(len(env.agents)): env.agents[a_idx].position = env.agents[a_idx].initial_position - env.agents[a_idx].state = TrainState.ACTIVE + env.agents[a_idx].state = TrainState.MOVING # Force malfunction to be off at beginning and next malfunction to happen in 2 steps env.agents[0].malfunction_data['next_malfunction'] = 2 env.agents[0].malfunction_data['malfunction'] = 0 diff --git a/tests/test_utils.py b/tests/test_utils.py index 3469c9c8..391ba535 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -107,7 +107,6 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering: for a, test_config in enumerate(test_configs): agent: EnvAgent = env.agents[a] replay = test_config.replay[step] - print(agent.position, replay.position, agent.state, agent.speed_counter) _assert(a, agent.position, replay.position, 'position') _assert(a, agent.direction, replay.direction, 'direction') if replay.state is not None: -- GitLab