Skip to content
Snippets Groups Projects
Commit 8216361d authored by Dipam Chakraborty
Browse files

fixes for ready to depart and malfunction

parent 1ed22efa
No related branches found
No related tags found
No related merge requests found
...@@ -531,14 +531,17 @@ class RailEnv(Environment): ...@@ -531,14 +531,17 @@ class RailEnv(Environment):
agent.action_saver.save_action_if_allowed(preprocessed_action, agent.state) agent.action_saver.save_action_if_allowed(preprocessed_action, agent.state)
# Train's next position can change if current stopped in a fractional speed or train is at cell's exit # Train's next position can change if current stopped in a fractional speed or train is at cell's exit
position_update_allowed = (agent.speed_counter.is_cell_exit or agent.state == TrainState.STOPPED) position_update_allowed = (agent.speed_counter.is_cell_exit or agent.state == TrainState.STOPPED) and \
not agent.state.is_malfunction_state()
# Calculate new position # Calculate new position
# Keep agent in same place if already done
if agent.state == TrainState.DONE:
new_position, new_direction = agent.position, agent.direction
# Add agent to the map if not on it yet # Add agent to the map if not on it yet
if agent.position is None and agent.action_saver.is_action_saved: elif agent.position is None and agent.action_saver.is_action_saved:
new_position = agent.initial_position new_position = agent.initial_position
new_direction = agent.initial_direction new_direction = agent.initial_direction
# If movement is allowed apply saved action independent of other agents # If movement is allowed apply saved action independent of other agents
elif agent.action_saver.is_action_saved and position_update_allowed: elif agent.action_saver.is_action_saved and position_update_allowed:
saved_action = agent.action_saver.saved_action saved_action = agent.action_saver.saved_action
...@@ -554,7 +557,7 @@ class RailEnv(Environment): ...@@ -554,7 +557,7 @@ class RailEnv(Environment):
temp_transition_data[i_agent] = env_utils.AgentTransitionData(position=new_position, temp_transition_data[i_agent] = env_utils.AgentTransitionData(position=new_position,
direction=new_direction, direction=new_direction,
preprocessed_action=preprocessed_action) preprocessed_action=preprocessed_action)
# This is for storing and later checking for conflicts of agents trying to occupy same cell # This is for storing and later checking for conflicts of agents trying to occupy same cell
self.motionCheck.addAgent(i_agent, agent.position, new_position) self.motionCheck.addAgent(i_agent, agent.position, new_position)
...@@ -570,8 +573,6 @@ class RailEnv(Environment): ...@@ -570,8 +573,6 @@ class RailEnv(Environment):
else: else:
movement_allowed = self.motionCheck.check_motion(i_agent, agent.position) movement_allowed = self.motionCheck.check_motion(i_agent, agent.position)
# Fetch the saved transition data # Fetch the saved transition data
agent_transition_data = temp_transition_data[i_agent] agent_transition_data = temp_transition_data[i_agent]
preprocessed_action = agent_transition_data.preprocessed_action preprocessed_action = agent_transition_data.preprocessed_action
...@@ -618,8 +619,15 @@ class RailEnv(Environment): ...@@ -618,8 +619,15 @@ class RailEnv(Environment):
# Check if episode has ended and update rewards and dones # Check if episode has ended and update rewards and dones
self.end_of_episode_update(have_all_agents_ended) self.end_of_episode_update(have_all_agents_ended)
old_agent_positions = self.agent_positions.copy()
self._update_agent_positions_map() self._update_agent_positions_map()
for ag in self.agents:
if ag.state == TrainState.READY_TO_DEPART and action_dict_.get(ag.handle, 0) in [1, 2, 3] and \
self.agent_positions[ag.initial_position] == -1 and ag.state_machine.previous_state == TrainState.READY_TO_DEPART:
print(old_agent_positions[ag.initial_position])
import pdb; pdb.set_trace()
return self._get_observations(), self.rewards_dict, self.dones, self.get_info_dict() return self._get_observations(), self.rewards_dict, self.dones, self.get_info_dict()
def record_timestep(self, dActions): def record_timestep(self, dActions):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment