Commit cc04d63f authored by Dipam Chakraborty's avatar Dipam Chakraborty
Browse files

malfunction fix: can save actions during malfunction

parent d4667187
Pipeline #8481 failed with stages
in 4 minutes and 39 seconds
...@@ -253,7 +253,7 @@ def malfunction_from_params(parameters: MalfunctionParameters) -> Tuple[Malfunct ...@@ -253,7 +253,7 @@ def malfunction_from_params(parameters: MalfunctionParameters) -> Tuple[Malfunct
min_number_of_steps_broken = parameters.min_duration min_number_of_steps_broken = parameters.min_duration
max_number_of_steps_broken = parameters.max_duration max_number_of_steps_broken = parameters.max_duration
def generator(agent: EnvAgent = None, np_random: RandomState = None, reset=False) -> Optional[Malfunction]: def generator(np_random: RandomState = None, reset=False) -> Optional[Malfunction]:
""" """
Generate malfunctions for agents Generate malfunctions for agents
Parameters Parameters
...@@ -270,11 +270,10 @@ def malfunction_from_params(parameters: MalfunctionParameters) -> Tuple[Malfunct ...@@ -270,11 +270,10 @@ def malfunction_from_params(parameters: MalfunctionParameters) -> Tuple[Malfunct
if reset: if reset:
return Malfunction(0) return Malfunction(0)
if agent.malfunction_data['malfunction'] < 1: if np_random.rand() < _malfunction_prob(mean_malfunction_rate):
if np_random.rand() < _malfunction_prob(mean_malfunction_rate): num_broken_steps = np_random.randint(min_number_of_steps_broken,
num_broken_steps = np_random.randint(min_number_of_steps_broken, max_number_of_steps_broken + 1)
max_number_of_steps_broken + 1) + 1 return Malfunction(num_broken_steps)
return Malfunction(num_broken_steps)
return Malfunction(0) return Malfunction(0)
return generator, MalfunctionProcessData(mean_malfunction_rate, min_number_of_steps_broken, return generator, MalfunctionProcessData(mean_malfunction_rate, min_number_of_steps_broken,
......
...@@ -549,7 +549,7 @@ class RailEnv(Environment): ...@@ -549,7 +549,7 @@ class RailEnv(Environment):
if agent.malfunction_handler.in_malfunction: if agent.malfunction_handler.in_malfunction:
movement_allowed = False movement_allowed = False
else: else:
movement_allowed = self.motionCheck.check_motion(i_agent, agent.position) # TODO: Remove final_new_postion from motioncheck movement_allowed = self.motionCheck.check_motion(i_agent, agent.position)
# Position can be changed only if other cell is empty # Position can be changed only if other cell is empty
# And either the speed counter completes or agent is being added to map # And either the speed counter completes or agent is being added to map
......
...@@ -17,14 +17,10 @@ class ActionSaver: ...@@ -17,14 +17,10 @@ class ActionSaver:
""" """
Save the action if all conditions are met Save the action if all conditions are met
1. It is a movement based action -> Forward, Left, Right 1. It is a movement based action -> Forward, Left, Right
2. Action is not already saved 2. Action is not already saved
3. Not in a malfunction state 3. Agent is not already done
4. Agent is not already done
""" """
if action.is_moving_action() and \ if action.is_moving_action() and not self.is_action_saved and not state == TrainState.DONE:
not self.is_action_saved and \
not state.is_malfunction_state() and \
not state == TrainState.DONE:
self.saved_action = action self.saved_action = action
def clear_saved_action(self): def clear_saved_action(self):
......
...@@ -190,8 +190,9 @@ def test_malfunction_before_entry(): ...@@ -190,8 +190,9 @@ def test_malfunction_before_entry():
# Test initial malfunction values for all agents # Test initial malfunction values for all agents
# we want some agents to be malfunctioning already and some to be working # we want some agents to be malfunctioning already and some to be working
# we want different next_malfunction values for the agents # we want different next_malfunction values for the agents
assert env.agents[0].malfunction_data['malfunction'] == 0 malfunction_values = [env.malfunction_generator(env.np_random).num_broken_steps for _ in range(1000)]
assert env.agents[1].malfunction_data['malfunction'] == 10 expected_value = (1 - np.exp(-0.5)) * 10
assert np.allclose(np.mean(malfunction_values), expected_value, rtol=0.1), "Mean values of malfunction don't match rate"
def test_malfunction_values_and_behavior(): def test_malfunction_values_and_behavior():
...@@ -257,7 +258,7 @@ def test_initial_malfunction(): ...@@ -257,7 +258,7 @@ def test_initial_malfunction():
set_penalties_for_replay(env) set_penalties_for_replay(env)
replay_config = ReplayConfig( replay_config = ReplayConfig(
replay=[ replay=[
Replay( Replay( # 0
position=(3, 2), position=(3, 2),
direction=Grid4TransitionsEnum.EAST, direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD, action=RailEnvActions.MOVE_FORWARD,
...@@ -265,7 +266,7 @@ def test_initial_malfunction(): ...@@ -265,7 +266,7 @@ def test_initial_malfunction():
malfunction=3, malfunction=3,
reward=env.step_penalty # full step penalty when malfunctioning reward=env.step_penalty # full step penalty when malfunctioning
), ),
Replay( Replay( # 1
position=(3, 2), position=(3, 2),
direction=Grid4TransitionsEnum.EAST, direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD, action=RailEnvActions.MOVE_FORWARD,
...@@ -274,7 +275,7 @@ def test_initial_malfunction(): ...@@ -274,7 +275,7 @@ def test_initial_malfunction():
), ),
# malfunction stops in the next step and we're still at the beginning of the cell # malfunction stops in the next step and we're still at the beginning of the cell
# --> if we take action MOVE_FORWARD, agent should restart and move to the next cell # --> if we take action MOVE_FORWARD, agent should restart and move to the next cell
Replay( Replay( # 2
position=(3, 2), position=(3, 2),
direction=Grid4TransitionsEnum.EAST, direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD, action=RailEnvActions.MOVE_FORWARD,
...@@ -282,14 +283,14 @@ def test_initial_malfunction(): ...@@ -282,14 +283,14 @@ def test_initial_malfunction():
reward=env.step_penalty reward=env.step_penalty
), # malfunctioning ends: starting and running at speed 1.0 ), # malfunctioning ends: starting and running at speed 1.0
Replay( Replay( # 3
position=(3, 2), position=(3, 2),
direction=Grid4TransitionsEnum.EAST, direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD, action=RailEnvActions.MOVE_FORWARD,
malfunction=0, malfunction=0,
reward=env.start_penalty + env.step_penalty * 1.0 # running at speed 1.0 reward=env.start_penalty + env.step_penalty * 1.0 # running at speed 1.0
), ),
Replay( Replay( # 4
position=(3, 3), position=(3, 3),
direction=Grid4TransitionsEnum.EAST, direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD, action=RailEnvActions.MOVE_FORWARD,
...@@ -420,7 +421,7 @@ def test_initial_malfunction_do_nothing(): ...@@ -420,7 +421,7 @@ def test_initial_malfunction_do_nothing():
action=RailEnvActions.DO_NOTHING, action=RailEnvActions.DO_NOTHING,
malfunction=2, malfunction=2,
reward=env.step_penalty, # full step penalty while malfunctioning reward=env.step_penalty, # full step penalty while malfunctioning
state=TrainState.ACTIVE state=TrainState.MOVING
), ),
# malfunction stops in the next step and we're still at the beginning of the cell # malfunction stops in the next step and we're still at the beginning of the cell
# --> if we take action DO_NOTHING, agent should restart without moving # --> if we take action DO_NOTHING, agent should restart without moving
...@@ -431,7 +432,7 @@ def test_initial_malfunction_do_nothing(): ...@@ -431,7 +432,7 @@ def test_initial_malfunction_do_nothing():
action=RailEnvActions.DO_NOTHING, action=RailEnvActions.DO_NOTHING,
malfunction=1, malfunction=1,
reward=env.step_penalty, # full step penalty while stopped reward=env.step_penalty, # full step penalty while stopped
state=TrainState.ACTIVE state=TrainState.MOVING
), ),
# we haven't started moving yet --> stay here # we haven't started moving yet --> stay here
Replay( Replay(
...@@ -440,7 +441,7 @@ def test_initial_malfunction_do_nothing(): ...@@ -440,7 +441,7 @@ def test_initial_malfunction_do_nothing():
action=RailEnvActions.DO_NOTHING, action=RailEnvActions.DO_NOTHING,
malfunction=0, malfunction=0,
reward=env.step_penalty, # full step penalty while stopped reward=env.step_penalty, # full step penalty while stopped
state=TrainState.ACTIVE state=TrainState.MOVING
), ),
Replay( Replay(
...@@ -449,7 +450,7 @@ def test_initial_malfunction_do_nothing(): ...@@ -449,7 +450,7 @@ def test_initial_malfunction_do_nothing():
action=RailEnvActions.MOVE_FORWARD, action=RailEnvActions.MOVE_FORWARD,
malfunction=0, malfunction=0,
reward=env.start_penalty + env.step_penalty * 1.0, # start penalty + step penalty for speed 1.0 reward=env.start_penalty + env.step_penalty * 1.0, # start penalty + step penalty for speed 1.0
state=TrainState.ACTIVE state=TrainState.MOVING
), # we start to move forward --> should go to next cell now ), # we start to move forward --> should go to next cell now
Replay( Replay(
position=(3, 3), position=(3, 3),
...@@ -457,7 +458,7 @@ def test_initial_malfunction_do_nothing(): ...@@ -457,7 +458,7 @@ def test_initial_malfunction_do_nothing():
action=RailEnvActions.MOVE_FORWARD, action=RailEnvActions.MOVE_FORWARD,
malfunction=0, malfunction=0,
reward=env.step_penalty * 1.0, # step penalty for speed 1.0 reward=env.step_penalty * 1.0, # step penalty for speed 1.0
state=TrainState.ACTIVE state=TrainState.MOVING
) )
], ],
speed=env.agents[0].speed_counter.speed, speed=env.agents[0].speed_counter.speed,
...@@ -546,7 +547,7 @@ def test_last_malfunction_step(): ...@@ -546,7 +547,7 @@ def test_last_malfunction_step():
env.reset(False, False) env.reset(False, False)
for a_idx in range(len(env.agents)): for a_idx in range(len(env.agents)):
env.agents[a_idx].position = env.agents[a_idx].initial_position env.agents[a_idx].position = env.agents[a_idx].initial_position
env.agents[a_idx].state = TrainState.ACTIVE env.agents[a_idx].state = TrainState.MOVING
# Force malfunction to be off at beginning and next malfunction to happen in 2 steps # Force malfunction to be off at beginning and next malfunction to happen in 2 steps
env.agents[0].malfunction_data['next_malfunction'] = 2 env.agents[0].malfunction_data['next_malfunction'] = 2
env.agents[0].malfunction_data['malfunction'] = 0 env.agents[0].malfunction_data['malfunction'] = 0
......
...@@ -107,7 +107,6 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering: ...@@ -107,7 +107,6 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering:
for a, test_config in enumerate(test_configs): for a, test_config in enumerate(test_configs):
agent: EnvAgent = env.agents[a] agent: EnvAgent = env.agents[a]
replay = test_config.replay[step] replay = test_config.replay[step]
print(agent.position, replay.position, agent.state, agent.speed_counter)
_assert(a, agent.position, replay.position, 'position') _assert(a, agent.position, replay.position, 'position')
_assert(a, agent.direction, replay.direction, 'direction') _assert(a, agent.direction, replay.direction, 'direction')
if replay.state is not None: if replay.state is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment