Commit 94feb039 authored by Dipam Chakraborty's avatar Dipam Chakraborty
Browse files

position update allowed on cell exit and stopped state

parent 1d63feb8
...@@ -431,7 +431,7 @@ class RailEnv(Environment): ...@@ -431,7 +431,7 @@ class RailEnv(Environment):
* Block all actions when in waiting state * Block all actions when in waiting state
* Check MOVE_LEFT/MOVE_RIGHT actions on current position else try MOVE_FORWARD * Check MOVE_LEFT/MOVE_RIGHT actions on current position else try MOVE_FORWARD
""" """
action = action_preprocessing.preprocess_raw_action(action, agent.state) action = action_preprocessing.preprocess_raw_action(action, agent.state, agent.action_saver.saved_action)
action = action_preprocessing.preprocess_action_when_waiting(action, agent.state) action = action_preprocessing.preprocess_action_when_waiting(action, agent.state)
# Try moving actions on current position # Try moving actions on current position
...@@ -440,7 +440,6 @@ class RailEnv(Environment): ...@@ -440,7 +440,6 @@ class RailEnv(Environment):
current_position, current_direction = agent.initial_position, agent.initial_direction current_position, current_direction = agent.initial_position, agent.initial_direction
action = action_preprocessing.preprocess_moving_action(action, self.rail, current_position, current_direction) action = action_preprocessing.preprocess_moving_action(action, self.rail, current_position, current_direction)
return action return action
def clear_rewards_dict(self): def clear_rewards_dict(self):
...@@ -513,6 +512,9 @@ class RailEnv(Environment): ...@@ -513,6 +512,9 @@ class RailEnv(Environment):
# Save moving actions in not already saved # Save moving actions in not already saved
agent.action_saver.save_action_if_allowed(preprocessed_action, agent.state) agent.action_saver.save_action_if_allowed(preprocessed_action, agent.state)
# Train's next position can change if current stopped in a fractional speed or train is at cell's exit
position_update_allowed = (agent.speed_counter.is_cell_exit or agent.state == TrainState.STOPPED)
# Calculate new position # Calculate new position
# Add agent to the map if not on it yet # Add agent to the map if not on it yet
if agent.position is None and agent.action_saver.is_action_saved: if agent.position is None and agent.action_saver.is_action_saved:
...@@ -520,7 +522,7 @@ class RailEnv(Environment): ...@@ -520,7 +522,7 @@ class RailEnv(Environment):
new_direction = agent.initial_direction new_direction = agent.initial_direction
# If movement is allowed apply saved action independent of other agents # If movement is allowed apply saved action independent of other agents
elif agent.action_saver.is_action_saved: elif agent.action_saver.is_action_saved and position_update_allowed:
saved_action = agent.action_saver.saved_action saved_action = agent.action_saver.saved_action
# Apply action independent of other agents and get temporary new position and direction # Apply action independent of other agents and get temporary new position and direction
new_position, new_direction = self.apply_action_independent(saved_action, new_position, new_direction = self.apply_action_independent(saved_action,
...@@ -557,7 +559,7 @@ class RailEnv(Environment): ...@@ -557,7 +559,7 @@ class RailEnv(Environment):
(agent.speed_counter.is_cell_exit or agent.position is None): (agent.speed_counter.is_cell_exit or agent.position is None):
agent.position = agent_transition_data.position agent.position = agent_transition_data.position
agent.direction = agent_transition_data.direction agent.direction = agent_transition_data.direction
preprocessed_action = agent_transition_data.preprocessed_action preprocessed_action = agent_transition_data.preprocessed_action
## Update states ## Update states
...@@ -565,9 +567,8 @@ class RailEnv(Environment): ...@@ -565,9 +567,8 @@ class RailEnv(Environment):
agent.state_machine.set_transition_signals(state_transition_signals) agent.state_machine.set_transition_signals(state_transition_signals)
agent.state_machine.step() agent.state_machine.step()
if agent.state.is_on_map_state() and agent.position is None: # Off map or on map state and position should match
raise ValueError("Agent ID {} Agent State {} not matching with Agent Position {} ".format( state_position_sync_check(agent.state, agent.position, agent.handle)
agent.handle, str(agent.state), str(agent.position) ))
# Handle done state actions, optionally remove agents # Handle done state actions, optionally remove agents
self.handle_done_state(agent) self.handle_done_state(agent)
...@@ -583,7 +584,7 @@ class RailEnv(Environment): ...@@ -583,7 +584,7 @@ class RailEnv(Environment):
agent.malfunction_handler.update_counter() agent.malfunction_handler.update_counter()
# Clear old action when starting in new cell # Clear old action when starting in new cell
if agent.speed_counter.is_cell_entry: if agent.speed_counter.is_cell_entry and agent.position is not None:
agent.action_saver.clear_saved_action() agent.action_saver.clear_saved_action()
# Check if episode has ended and update rewards and dones # Check if episode has ended and update rewards and dones
...@@ -687,3 +688,11 @@ def fast_position_equal(pos_1: (int, int), pos_2: (int, int)) -> bool: ...@@ -687,3 +688,11 @@ def fast_position_equal(pos_1: (int, int), pos_2: (int, int)) -> bool:
return False return False
else: else:
return pos_1[0] == pos_2[0] and pos_1[1] == pos_2[1] return pos_1[0] == pos_2[0] and pos_1[1] == pos_2[1]
def state_position_sync_check(state, position, i_agent):
if state.is_on_map_state() and position is None:
raise ValueError("Agent ID {} Agent State {} is on map Agent Position {} if off map ".format(
i_agent, str(state), str(position) ))
elif state.is_off_map_state() and position is not None:
raise ValueError("Agent ID {} Agent State {} is off map Agent Position {} if on map ".format(
i_agent, str(state), str(position) ))
...@@ -11,9 +11,11 @@ def process_illegal_action(action: RailEnvActions): ...@@ -11,9 +11,11 @@ def process_illegal_action(action: RailEnvActions):
return RailEnvActions(action) return RailEnvActions(action)
def process_do_nothing(state: TrainState): def process_do_nothing(state: TrainState, saved_action: RailEnvActions):
if state == TrainState.MOVING: if state == TrainState.MOVING:
action = RailEnvActions.MOVE_FORWARD action = RailEnvActions.MOVE_FORWARD
elif saved_action:
action = saved_action
else: else:
action = RailEnvActions.STOP_MOVING action = RailEnvActions.STOP_MOVING
return action return action
...@@ -34,7 +36,7 @@ def preprocess_action_when_waiting(action, state): ...@@ -34,7 +36,7 @@ def preprocess_action_when_waiting(action, state):
return action return action
def preprocess_raw_action(action, state): def preprocess_raw_action(action, state, saved_action):
""" """
Preprocesses actions to handle different situations of usage of action based on context Preprocesses actions to handle different situations of usage of action based on context
- DO_NOTHING is converted to FORWARD if train is moving - DO_NOTHING is converted to FORWARD if train is moving
...@@ -43,7 +45,7 @@ def preprocess_raw_action(action, state): ...@@ -43,7 +45,7 @@ def preprocess_raw_action(action, state):
action = process_illegal_action(action) action = process_illegal_action(action)
if action == RailEnvActions.DO_NOTHING: if action == RailEnvActions.DO_NOTHING:
action = process_do_nothing(state) action = process_do_nothing(state, saved_action)
return action return action
...@@ -55,6 +57,4 @@ def preprocess_moving_action(action, rail, position, direction): ...@@ -55,6 +57,4 @@ def preprocess_moving_action(action, rail, position, direction):
if action in [RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT]: if action in [RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT]:
action = process_left_right(action, rail, position, direction) action = process_left_right(action, rail, position, direction)
if not check_valid_action(action, rail, position, direction):
action = RailEnvActions.STOP_MOVING
return action return action
\ No newline at end of file
...@@ -10,6 +10,7 @@ def get_number_of_steps_to_break(malfunction_generator, np_random): ...@@ -10,6 +10,7 @@ def get_number_of_steps_to_break(malfunction_generator, np_random):
class MalfunctionHandler: class MalfunctionHandler:
def __init__(self): def __init__(self):
self._malfunction_down_counter = 0 self._malfunction_down_counter = 0
self.num_malfunctions = 0
@property @property
def in_malfunction(self): def in_malfunction(self):
...@@ -33,6 +34,7 @@ class MalfunctionHandler: ...@@ -33,6 +34,7 @@ class MalfunctionHandler:
# Only set new malfunction value if old malfunction is completed # Only set new malfunction value if old malfunction is completed
if self._malfunction_down_counter == 0: if self._malfunction_down_counter == 0:
self._malfunction_down_counter = val self._malfunction_down_counter = val
self.num_malfunctions += 1
def generate_malfunction(self, malfunction_generator, np_random): def generate_malfunction(self, malfunction_generator, np_random):
num_broken_steps = get_number_of_steps_to_break(malfunction_generator, np_random) num_broken_steps = get_number_of_steps_to_break(malfunction_generator, np_random)
...@@ -44,16 +46,20 @@ class MalfunctionHandler: ...@@ -44,16 +46,20 @@ class MalfunctionHandler:
def __repr__(self): def __repr__(self):
return f"malfunction_down_counter: {self._malfunction_down_counter} \ return f"malfunction_down_counter: {self._malfunction_down_counter} \
in_malfunction: {self.in_malfunction}" in_malfunction: {self.in_malfunction} \
num_malfunctions: {self.num_malfunctions}"
def to_dict(self): def to_dict(self):
return {"malfunction_down_counter": self._malfunction_down_counter} return {"malfunction_down_counter": self._malfunction_down_counter,
"num_malfunctions": self.num_malfunctions}
def from_dict(self, load_dict): def from_dict(self, load_dict):
self._malfunction_down_counter = load_dict['malfunction_down_counter'] self._malfunction_down_counter = load_dict['malfunction_down_counter']
self.num_malfunctions = load_dict['num_malfunctions']
def __eq__(self, other): def __eq__(self, other):
return self._malfunction_down_counter == other._malfunction_down_counter return self._malfunction_down_counter == other._malfunction_down_counter and \
self.num_malfunctions == other.num_malfunctions
......
...@@ -31,12 +31,21 @@ class TrainStateMachine: ...@@ -31,12 +31,21 @@ class TrainStateMachine:
def _handle_malfunction_off_map(self): def _handle_malfunction_off_map(self):
if self.st_signals.malfunction_counter_complete: if self.st_signals.malfunction_counter_complete:
if self.st_signals.earliest_departure_reached: if self.st_signals.earliest_departure_reached:
self.next_state = TrainState.READY_TO_DEPART
if self.st_signals.valid_movement_action_given:
self.next_state = TrainState.MOVING
elif self.st_signals.stop_action_given:
self.next_state = TrainState.STOPPED
else:
self.next_state = TrainState.READY_TO_DEPART
else: else:
self.next_state = TrainState.STOPPED self.next_state = TrainState.WAITING
else: else:
self.next_state = TrainState.WAITING self.next_state = TrainState.MALFUNCTION_OFF_MAP
def _handle_moving(self): def _handle_moving(self):
if self.st_signals.in_malfunction: if self.st_signals.in_malfunction:
...@@ -61,7 +70,7 @@ class TrainStateMachine: ...@@ -61,7 +70,7 @@ class TrainStateMachine:
self.st_signals.valid_movement_action_given: self.st_signals.valid_movement_action_given:
self.next_state = TrainState.MOVING self.next_state = TrainState.MOVING
elif self.st_signals.malfunction_counter_complete and \ elif self.st_signals.malfunction_counter_complete and \
(self.st_signals.stop_action_given or self.st_signals.movement_conflict): (self.st_signals.stop_action_given or self.st_signals.movement_conflict):
self.next_state = TrainState.STOPPED self.next_state = TrainState.STOPPED
else: else:
self.next_state = TrainState.MALFUNCTION self.next_state = TrainState.MALFUNCTION
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment