Commit 94feb039 authored by Dipam Chakraborty's avatar Dipam Chakraborty
Browse files

position update allowed on cell exit and stopped state

parent 1d63feb8
......@@ -431,7 +431,7 @@ class RailEnv(Environment):
* Block all actions when in waiting state
* Check MOVE_LEFT/MOVE_RIGHT actions on current position else try MOVE_FORWARD
"""
action = action_preprocessing.preprocess_raw_action(action, agent.state)
action = action_preprocessing.preprocess_raw_action(action, agent.state, agent.action_saver.saved_action)
action = action_preprocessing.preprocess_action_when_waiting(action, agent.state)
# Try moving actions on current position
......@@ -440,7 +440,6 @@ class RailEnv(Environment):
current_position, current_direction = agent.initial_position, agent.initial_direction
action = action_preprocessing.preprocess_moving_action(action, self.rail, current_position, current_direction)
return action
def clear_rewards_dict(self):
......@@ -513,6 +512,9 @@ class RailEnv(Environment):
# Save moving actions in not already saved
agent.action_saver.save_action_if_allowed(preprocessed_action, agent.state)
# Train's next position can change if current stopped in a fractional speed or train is at cell's exit
position_update_allowed = (agent.speed_counter.is_cell_exit or agent.state == TrainState.STOPPED)
# Calculate new position
# Add agent to the map if not on it yet
if agent.position is None and agent.action_saver.is_action_saved:
......@@ -520,7 +522,7 @@ class RailEnv(Environment):
new_direction = agent.initial_direction
# If movement is allowed apply saved action independent of other agents
elif agent.action_saver.is_action_saved:
elif agent.action_saver.is_action_saved and position_update_allowed:
saved_action = agent.action_saver.saved_action
# Apply action independent of other agents and get temporary new position and direction
new_position, new_direction = self.apply_action_independent(saved_action,
......@@ -557,7 +559,7 @@ class RailEnv(Environment):
(agent.speed_counter.is_cell_exit or agent.position is None):
agent.position = agent_transition_data.position
agent.direction = agent_transition_data.direction
preprocessed_action = agent_transition_data.preprocessed_action
## Update states
......@@ -565,9 +567,8 @@ class RailEnv(Environment):
agent.state_machine.set_transition_signals(state_transition_signals)
agent.state_machine.step()
if agent.state.is_on_map_state() and agent.position is None:
raise ValueError("Agent ID {} Agent State {} not matching with Agent Position {} ".format(
agent.handle, str(agent.state), str(agent.position) ))
# Off map or on map state and position should match
state_position_sync_check(agent.state, agent.position, agent.handle)
# Handle done state actions, optionally remove agents
self.handle_done_state(agent)
......@@ -583,7 +584,7 @@ class RailEnv(Environment):
agent.malfunction_handler.update_counter()
# Clear old action when starting in new cell
if agent.speed_counter.is_cell_entry:
if agent.speed_counter.is_cell_entry and agent.position is not None:
agent.action_saver.clear_saved_action()
# Check if episode has ended and update rewards and dones
......@@ -687,3 +688,11 @@ def fast_position_equal(pos_1: (int, int), pos_2: (int, int)) -> bool:
return False
else:
return pos_1[0] == pos_2[0] and pos_1[1] == pos_2[1]
def state_position_sync_check(state, position, i_agent):
if state.is_on_map_state() and position is None:
raise ValueError("Agent ID {} Agent State {} is on map Agent Position {} if off map ".format(
i_agent, str(state), str(position) ))
elif state.is_off_map_state() and position is not None:
raise ValueError("Agent ID {} Agent State {} is off map Agent Position {} if on map ".format(
i_agent, str(state), str(position) ))
......@@ -11,9 +11,11 @@ def process_illegal_action(action: RailEnvActions):
return RailEnvActions(action)
def process_do_nothing(state: TrainState):
def process_do_nothing(state: TrainState, saved_action: RailEnvActions):
if state == TrainState.MOVING:
action = RailEnvActions.MOVE_FORWARD
elif saved_action:
action = saved_action
else:
action = RailEnvActions.STOP_MOVING
return action
......@@ -34,7 +36,7 @@ def preprocess_action_when_waiting(action, state):
return action
def preprocess_raw_action(action, state):
def preprocess_raw_action(action, state, saved_action):
"""
Preprocesses actions to handle different situations of usage of action based on context
- DO_NOTHING is converted to FORWARD if train is moving
......@@ -43,7 +45,7 @@ def preprocess_raw_action(action, state):
action = process_illegal_action(action)
if action == RailEnvActions.DO_NOTHING:
action = process_do_nothing(state)
action = process_do_nothing(state, saved_action)
return action
......@@ -55,6 +57,4 @@ def preprocess_moving_action(action, rail, position, direction):
if action in [RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT]:
action = process_left_right(action, rail, position, direction)
if not check_valid_action(action, rail, position, direction):
action = RailEnvActions.STOP_MOVING
return action
\ No newline at end of file
......@@ -10,6 +10,7 @@ def get_number_of_steps_to_break(malfunction_generator, np_random):
class MalfunctionHandler:
def __init__(self):
self._malfunction_down_counter = 0
self.num_malfunctions = 0
@property
def in_malfunction(self):
......@@ -33,6 +34,7 @@ class MalfunctionHandler:
# Only set new malfunction value if old malfunction is completed
if self._malfunction_down_counter == 0:
self._malfunction_down_counter = val
self.num_malfunctions += 1
def generate_malfunction(self, malfunction_generator, np_random):
num_broken_steps = get_number_of_steps_to_break(malfunction_generator, np_random)
......@@ -44,16 +46,20 @@ class MalfunctionHandler:
def __repr__(self):
return f"malfunction_down_counter: {self._malfunction_down_counter} \
in_malfunction: {self.in_malfunction}"
in_malfunction: {self.in_malfunction} \
num_malfunctions: {self.num_malfunctions}"
def to_dict(self):
return {"malfunction_down_counter": self._malfunction_down_counter}
return {"malfunction_down_counter": self._malfunction_down_counter,
"num_malfunctions": self.num_malfunctions}
def from_dict(self, load_dict):
self._malfunction_down_counter = load_dict['malfunction_down_counter']
self.num_malfunctions = load_dict['num_malfunctions']
def __eq__(self, other):
return self._malfunction_down_counter == other._malfunction_down_counter
return self._malfunction_down_counter == other._malfunction_down_counter and \
self.num_malfunctions == other.num_malfunctions
......
......@@ -31,12 +31,21 @@ class TrainStateMachine:
def _handle_malfunction_off_map(self):
if self.st_signals.malfunction_counter_complete:
if self.st_signals.earliest_departure_reached:
self.next_state = TrainState.READY_TO_DEPART
if self.st_signals.valid_movement_action_given:
self.next_state = TrainState.MOVING
elif self.st_signals.stop_action_given:
self.next_state = TrainState.STOPPED
else:
self.next_state = TrainState.READY_TO_DEPART
else:
self.next_state = TrainState.STOPPED
self.next_state = TrainState.WAITING
else:
self.next_state = TrainState.WAITING
self.next_state = TrainState.MALFUNCTION_OFF_MAP
def _handle_moving(self):
if self.st_signals.in_malfunction:
......@@ -61,7 +70,7 @@ class TrainStateMachine:
self.st_signals.valid_movement_action_given:
self.next_state = TrainState.MOVING
elif self.st_signals.malfunction_counter_complete and \
(self.st_signals.stop_action_given or self.st_signals.movement_conflict):
(self.st_signals.stop_action_given or self.st_signals.movement_conflict):
self.next_state = TrainState.STOPPED
else:
self.next_state = TrainState.MALFUNCTION
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment