Skip to content
Snippets Groups Projects
Commit 94feb039 authored by Dipam Chakraborty's avatar Dipam Chakraborty
Browse files

position update allowed on cell exit and stopped state

parent 1d63feb8
No related branches found
No related tags found
No related merge requests found
......@@ -431,7 +431,7 @@ class RailEnv(Environment):
* Block all actions when in waiting state
* Check MOVE_LEFT/MOVE_RIGHT actions on current position else try MOVE_FORWARD
"""
action = action_preprocessing.preprocess_raw_action(action, agent.state)
action = action_preprocessing.preprocess_raw_action(action, agent.state, agent.action_saver.saved_action)
action = action_preprocessing.preprocess_action_when_waiting(action, agent.state)
# Try moving actions on current position
......@@ -440,7 +440,6 @@ class RailEnv(Environment):
current_position, current_direction = agent.initial_position, agent.initial_direction
action = action_preprocessing.preprocess_moving_action(action, self.rail, current_position, current_direction)
return action
def clear_rewards_dict(self):
......@@ -513,6 +512,9 @@ class RailEnv(Environment):
# Save moving actions if not already saved
agent.action_saver.save_action_if_allowed(preprocessed_action, agent.state)
# Train's next position can change if it is currently stopped at a fractional speed or the train is at the cell's exit
position_update_allowed = (agent.speed_counter.is_cell_exit or agent.state == TrainState.STOPPED)
# Calculate new position
# Add agent to the map if not on it yet
if agent.position is None and agent.action_saver.is_action_saved:
......@@ -520,7 +522,7 @@ class RailEnv(Environment):
new_direction = agent.initial_direction
# If movement is allowed apply saved action independent of other agents
elif agent.action_saver.is_action_saved:
elif agent.action_saver.is_action_saved and position_update_allowed:
saved_action = agent.action_saver.saved_action
# Apply action independent of other agents and get temporary new position and direction
new_position, new_direction = self.apply_action_independent(saved_action,
......@@ -557,7 +559,7 @@ class RailEnv(Environment):
(agent.speed_counter.is_cell_exit or agent.position is None):
agent.position = agent_transition_data.position
agent.direction = agent_transition_data.direction
preprocessed_action = agent_transition_data.preprocessed_action
## Update states
......@@ -565,9 +567,8 @@ class RailEnv(Environment):
agent.state_machine.set_transition_signals(state_transition_signals)
agent.state_machine.step()
if agent.state.is_on_map_state() and agent.position is None:
raise ValueError("Agent ID {} Agent State {} not matching with Agent Position {} ".format(
agent.handle, str(agent.state), str(agent.position) ))
# Off map or on map state and position should match
state_position_sync_check(agent.state, agent.position, agent.handle)
# Handle done state actions, optionally remove agents
self.handle_done_state(agent)
......@@ -583,7 +584,7 @@ class RailEnv(Environment):
agent.malfunction_handler.update_counter()
# Clear old action when starting in new cell
if agent.speed_counter.is_cell_entry:
if agent.speed_counter.is_cell_entry and agent.position is not None:
agent.action_saver.clear_saved_action()
# Check if episode has ended and update rewards and dones
......@@ -687,3 +688,11 @@ def fast_position_equal(pos_1: (int, int), pos_2: (int, int)) -> bool:
return False
else:
return pos_1[0] == pos_2[0] and pos_1[1] == pos_2[1]
def state_position_sync_check(state, position, i_agent):
    """Check that an agent's state and position agree about being on the map.

    Args:
        state: Agent state object; must provide ``is_on_map_state()`` and
            ``is_off_map_state()``.
        position: Agent position, or ``None`` when the agent has no position.
        i_agent: Agent identifier, used only to build the error message.

    Raises:
        ValueError: If ``state`` is an on-map state while ``position`` is
            ``None``, or ``state`` is an off-map state while ``position`` is
            not ``None``.
    """
    if state.is_on_map_state() and position is None:
        # State claims the agent is on the map, but it has no position.
        raise ValueError("Agent ID {} Agent State {} is on map Agent Position {} is off map ".format(
            i_agent, str(state), str(position)))
    elif state.is_off_map_state() and position is not None:
        # State claims the agent is off the map, but it still has a position.
        raise ValueError("Agent ID {} Agent State {} is off map Agent Position {} is on map ".format(
            i_agent, str(state), str(position)))
......@@ -11,9 +11,11 @@ def process_illegal_action(action: RailEnvActions):
return RailEnvActions(action)
def process_do_nothing(state: TrainState):
def process_do_nothing(state: TrainState, saved_action: RailEnvActions):
    """Choose the concrete action that replaces a DO_NOTHING request.

    Args:
        state: Current state of the train.
        saved_action: Previously saved action for the agent; a falsy value
            (e.g. ``None``) means there is nothing to fall back on.

    Returns:
        ``MOVE_FORWARD`` when the train is in the ``MOVING`` state, otherwise
        ``saved_action`` when it is truthy, otherwise ``STOP_MOVING``.
    """
    # A moving train keeps going forward.
    if state == TrainState.MOVING:
        return RailEnvActions.MOVE_FORWARD
    # Otherwise reuse the saved action if there is one.
    if saved_action:
        return saved_action
    # Nothing to continue with: stop.
    return RailEnvActions.STOP_MOVING
......@@ -34,7 +36,7 @@ def preprocess_action_when_waiting(action, state):
return action
def preprocess_raw_action(action, state):
def preprocess_raw_action(action, state, saved_action):
"""
Preprocesses actions to handle different situations of usage of action based on context
- DO_NOTHING is converted to FORWARD if train is moving
......@@ -43,7 +45,7 @@ def preprocess_raw_action(action, state):
action = process_illegal_action(action)
if action == RailEnvActions.DO_NOTHING:
action = process_do_nothing(state)
action = process_do_nothing(state, saved_action)
return action
......@@ -55,6 +57,4 @@ def preprocess_moving_action(action, rail, position, direction):
if action in [RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT]:
action = process_left_right(action, rail, position, direction)
if not check_valid_action(action, rail, position, direction):
action = RailEnvActions.STOP_MOVING
return action
\ No newline at end of file
......@@ -10,6 +10,7 @@ def get_number_of_steps_to_break(malfunction_generator, np_random):
class MalfunctionHandler:
def __init__(self):
self._malfunction_down_counter = 0
self.num_malfunctions = 0
@property
def in_malfunction(self):
......@@ -33,6 +34,7 @@ class MalfunctionHandler:
# Only set new malfunction value if old malfunction is completed
if self._malfunction_down_counter == 0:
self._malfunction_down_counter = val
self.num_malfunctions += 1
def generate_malfunction(self, malfunction_generator, np_random):
num_broken_steps = get_number_of_steps_to_break(malfunction_generator, np_random)
......@@ -44,16 +46,20 @@ class MalfunctionHandler:
def __repr__(self):
return f"malfunction_down_counter: {self._malfunction_down_counter} \
in_malfunction: {self.in_malfunction}"
in_malfunction: {self.in_malfunction} \
num_malfunctions: {self.num_malfunctions}"
def to_dict(self):
return {"malfunction_down_counter": self._malfunction_down_counter}
return {"malfunction_down_counter": self._malfunction_down_counter,
"num_malfunctions": self.num_malfunctions}
def from_dict(self, load_dict):
self._malfunction_down_counter = load_dict['malfunction_down_counter']
self.num_malfunctions = load_dict['num_malfunctions']
def __eq__(self, other):
return self._malfunction_down_counter == other._malfunction_down_counter
return self._malfunction_down_counter == other._malfunction_down_counter and \
self.num_malfunctions == other.num_malfunctions
......
......@@ -31,12 +31,21 @@ class TrainStateMachine:
def _handle_malfunction_off_map(self):
if self.st_signals.malfunction_counter_complete:
if self.st_signals.earliest_departure_reached:
self.next_state = TrainState.READY_TO_DEPART
if self.st_signals.valid_movement_action_given:
self.next_state = TrainState.MOVING
elif self.st_signals.stop_action_given:
self.next_state = TrainState.STOPPED
else:
self.next_state = TrainState.READY_TO_DEPART
else:
self.next_state = TrainState.STOPPED
self.next_state = TrainState.WAITING
else:
self.next_state = TrainState.WAITING
self.next_state = TrainState.MALFUNCTION_OFF_MAP
def _handle_moving(self):
if self.st_signals.in_malfunction:
......@@ -61,7 +70,7 @@ class TrainStateMachine:
self.st_signals.valid_movement_action_given:
self.next_state = TrainState.MOVING
elif self.st_signals.malfunction_counter_complete and \
(self.st_signals.stop_action_given or self.st_signals.movement_conflict):
(self.st_signals.stop_action_given or self.st_signals.movement_conflict):
self.next_state = TrainState.STOPPED
else:
self.next_state = TrainState.MALFUNCTION
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment