Commit d4667187 authored by Dipam Chakraborty's avatar Dipam Chakraborty
Browse files

Fix for stopped to moving in fractional speeds

parent 9301e30b
......@@ -224,7 +224,7 @@ class EnvAgent:
old_position: {self.old_position} old_direction {self.old_direction} \n \
earliest_departure: {self.earliest_departure} latest_arrival: {self.latest_arrival} \n \
state: {str(self.state)} \n \
malfunction_data: {self.malfunction_data} \n \
malfunction_handler: {self.malfunction_handler} \n \
action_saver: {self.action_saver} \n \
speed_counter: {self.speed_counter}"
......
......@@ -373,7 +373,7 @@ class RailEnv(Environment):
st_signals = StateTransitionSignals()
# Malfunction starts when in_malfunction is set to true
st_signals.malfunction_onset = agent.malfunction_handler.in_malfunction
st_signals.in_malfunction = agent.malfunction_handler.in_malfunction
# Malfunction counter complete - Malfunction ends next timestep
st_signals.malfunction_counter_complete = agent.malfunction_handler.malfunction_counter_complete
......@@ -519,8 +519,8 @@ class RailEnv(Environment):
new_position = agent.initial_position
new_direction = agent.initial_direction
# When cell exit occurs apply saved action independent of other agents
elif agent.speed_counter.is_cell_exit and agent.action_saver.is_action_saved:
# If movement is allowed apply saved action independent of other agents
elif agent.action_saver.is_action_saved:
saved_action = agent.action_saver.saved_action
# Apply action independent of other agents and get temporary new position and direction
new_position, new_direction = self.apply_action_independent(saved_action,
......@@ -551,7 +551,10 @@ class RailEnv(Environment):
else:
movement_allowed = self.motionCheck.check_motion(i_agent, agent.position) # TODO: Remove final_new_postion from motioncheck
if movement_allowed:
# Position can be changed only if other cell is empty
# And either the speed counter completes or agent is being added to map
if movement_allowed and \
(agent.speed_counter.is_cell_exit or agent.position is None):
agent.position = agent_transition_data.position
agent.direction = agent_transition_data.direction
......@@ -576,6 +579,7 @@ class RailEnv(Environment):
## Update counters (malfunction and speed)
agent.speed_counter.update_counter(agent.state, agent.old_position)
# agent.state_machine.previous_state)
agent.malfunction_handler.update_counter()
# Clear old action when starting in new cell
......
......@@ -10,10 +10,17 @@ class ActionSaver:
return self.saved_action is not None
def __repr__(self):
return f"is_action_saved: {self.is_action_saved}, saved_action: {self.saved_action}"
return f"is_action_saved: {self.is_action_saved}, saved_action: {str(self.saved_action)}"
def save_action_if_allowed(self, action, state):
"""
Save the action if all conditions are met
1. It is a movement based action -> Forward, Left, Right
2. Action is not already saved
3. Not in a malfunction state
4. Agent is not already done
"""
if action.is_moving_action() and \
not self.is_action_saved and \
not state.is_malfunction_state() and \
......
......@@ -8,11 +8,13 @@ class SpeedCounter:
self.reset_counter()
def update_counter(self, state, old_position):
# When coming onto the map, do no update speed counter
# Can't start counting when adding train to the map
if state == TrainState.MOVING and old_position is not None:
self.counter += 1
self.counter = self.counter % (self.max_count + 1)
def __repr__(self):
return f"speed: {self.speed} \
max_count: {self.max_count} \
......
......@@ -406,26 +406,26 @@ def test_multispeed_actions_malfunction_no_blocking():
set_penalties_for_replay(env)
test_config = ReplayConfig(
replay=[
Replay(
Replay( # 0
position=(3, 9), # east dead-end
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD,
reward=env.start_penalty + env.step_penalty * 0.5 # starting and running at speed 0.5
),
Replay(
Replay( # 1
position=(3, 9),
direction=Grid4TransitionsEnum.EAST,
action=None,
reward=env.step_penalty * 0.5 # running at speed 0.5
),
Replay(
Replay( # 2
position=(3, 8),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_FORWARD,
reward=env.step_penalty * 0.5 # running at speed 0.5
),
# add additional step in the cell
Replay(
Replay( # 3
position=(3, 8),
direction=Grid4TransitionsEnum.WEST,
action=None,
......@@ -434,26 +434,26 @@ def test_multispeed_actions_malfunction_no_blocking():
reward=env.step_penalty * 0.5 # step penalty for speed 0.5 when malfunctioning
),
# agent recovers in this step
Replay(
Replay( # 4
position=(3, 8),
direction=Grid4TransitionsEnum.WEST,
action=None,
malfunction=1,
reward=env.step_penalty * 0.5 # recovered: running at speed 0.5
),
Replay(
Replay( # 5
position=(3, 8),
direction=Grid4TransitionsEnum.WEST,
action=None,
reward=env.step_penalty * 0.5 # running at speed 0.5
),
Replay(
Replay( # 6
position=(3, 7),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_FORWARD,
reward=env.step_penalty * 0.5 # running at speed 0.5
),
Replay(
Replay( # 7
position=(3, 7),
direction=Grid4TransitionsEnum.WEST,
action=None,
......@@ -462,57 +462,57 @@ def test_multispeed_actions_malfunction_no_blocking():
reward=env.step_penalty * 0.5 # step penalty for speed 0.5 when malfunctioning
),
# agent recovers in this step; since we're at the beginning, we provide a different action although we're broken!
Replay(
Replay( # 8
position=(3, 7),
direction=Grid4TransitionsEnum.WEST,
action=None,
malfunction=1,
reward=env.step_penalty * 0.5 # running at speed 0.5
),
Replay(
Replay( # 9
position=(3, 7),
direction=Grid4TransitionsEnum.WEST,
action=None,
reward=env.step_penalty * 0.5 # running at speed 0.5
),
Replay(
Replay( # 10
position=(3, 6),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.STOP_MOVING,
reward=env.stop_penalty + env.step_penalty * 0.5 # stopping and step penalty for speed 0.5
),
Replay(
Replay( # 11
position=(3, 6),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.STOP_MOVING,
reward=env.step_penalty * 0.5 # step penalty for speed 0.5 while stopped
),
Replay(
Replay( # 12
position=(3, 6),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_FORWARD,
reward=env.start_penalty + env.step_penalty * 0.5 # starting and running at speed 0.5
),
Replay(
Replay( # 13
position=(3, 6),
direction=Grid4TransitionsEnum.WEST,
action=None,
reward=env.step_penalty * 0.5 # running at speed 0.5
),
# DO_NOTHING keeps moving!
Replay(
Replay( # 14
position=(3, 5),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.DO_NOTHING,
reward=env.step_penalty * 0.5 # running at speed 0.5
),
Replay(
Replay( # 15
position=(3, 5),
direction=Grid4TransitionsEnum.WEST,
action=None,
reward=env.step_penalty * 0.5 # running at speed 0.5
),
Replay(
Replay( # 16
position=(3, 4),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_FORWARD,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment