Commit 53d7dcc1 authored by Dipam Chakraborty's avatar Dipam Chakraborty
Browse files

minor refactors

parent 0fa980ba
......@@ -443,6 +443,25 @@ class RailEnv(Environment):
return reward
def preprocess_action(self, action, agent):
"""
Preprocess the provided action
* Change to DO_NOTHING if illegal action
* Block all actions when in waiting state
* Check MOVE_LEFT/MOVE_RIGHT actions on current position else try MOVE_FORWARD
"""
action = action_preprocessing.preprocess_raw_action(action, agent.state)
action = action_preprocessing.preprocess_action_when_waiting(action, agent.state)
# Try moving actions on current position
current_position, current_direction = agent.position, agent.direction
if current_position is None: # Agent not added on map yet
current_position, current_direction = agent.initial_position, agent.initial_direction
action = action_preprocessing.preprocess_moving_action(action, self.rail, current_position, current_direction)
return action
def step(self, action_dict_: Dict[int, RailEnvActions]):
"""
Updates rewards for the agents at a step.
......@@ -475,59 +494,51 @@ class RailEnv(Environment):
"speed": {},
"status": {},
}
have_all_agents_ended = True # boolean flag to check if all agents are done
self.motionCheck = ac.MotionCheck() # reset the motion check
temp_transition_data = {}
for i_agent, agent in enumerate(self.agents): # TODO: Important - Do not use i_agent like this, use agent.handle if needed
for agent in self.agents:
i_agent = agent.handle
# Generate malfunction
agent.malfunction_handler.generate_malfunction(self.malfunction_generator, self.np_random)
# Get action for the agent
action = action_dict_.get(i_agent, RailEnvActions.DO_NOTHING)
# TODO: Add the bottom stuff to separate function(s)
# Preprocess action
action = action_preprocessing.preprocess_raw_action(action, agent.state)
action = action_preprocessing.preprocess_action_when_waiting(action, agent.state)
# Try moving actions on current position
current_position, current_direction = agent.position, agent.direction
agent_not_on_map = current_position is None
if agent_not_on_map: # Agent not added on map yet
current_position, current_direction = agent.initial_position, agent.initial_direction
action = action_preprocessing.preprocess_moving_action(action, self.rail, current_position, current_direction)
preprocessed_action = self.preprocess_action(action, agent)
# Save moving actions in not already saved
agent.action_saver.save_action_if_allowed(action, agent.state)
agent.action_saver.save_action_if_allowed(preprocessed_action, agent.state)
# Calculate new position
# Add agent to the map if not on it yet
if agent_not_on_map and agent.action_saver.is_action_saved:
if agent.position is None and agent.action_saver.is_action_saved:
new_position = agent.initial_position
new_direction = agent.initial_direction
preprocessed_action = action
# When cell exit occurs apply saved action independent of other agents
elif agent.speed_counter.is_cell_exit and agent.action_saver.is_action_saved:
saved_action = agent.action_saver.saved_action
# Apply action independent of other agents and get temporary new position and direction
pd = self.apply_action_independent(saved_action, self.rail, agent.position, agent.direction)
new_position, new_direction = pd
new_position, new_direction = self.apply_action_independent(saved_action,
self.rail,
agent.position,
agent.direction)
preprocessed_action = saved_action
else:
new_position, new_direction = agent.position, agent.direction
preprocessed_action = action
temp_transition_data[i_agent] = AgentTransitionData(position=new_position,
direction=new_direction,
preprocessed_action=preprocessed_action)
# This is for checking conflicts of agents trying to occupy same cell
self.motionCheck.addAgent(i_agent, agent.position, new_position)
# Find conflicts
# TODO : Important - Modify conflicted positions and select one of them randomly to go to new position
self.motionCheck.find_conflicts()
for agent in self.agents:
......
......@@ -55,6 +55,6 @@ def preprocess_moving_action(action, rail, position, direction):
if action in [RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT]:
action = process_left_right(action, rail, position, direction)
if not check_valid_action(action, rail, position, direction): # TODO: Dipam - Not sure if this is needed
if not check_valid_action(action, rail, position, direction):
action = RailEnvActions.STOP_MOVING
return action
\ No newline at end of file
from attr import s
from flatland.envs.step_utils.states import TrainState, StateTransitionSignals
class TrainStateMachine:
......
......@@ -71,9 +71,7 @@ def check_action_on_agent(action, rail, position, direction):
# If transition validity hasn't been checked yet.
if transition_valid is None:
transition_valid = rail.get_transition( # TODO: Dipam - Read this one
(*position, direction),
new_direction)
transition_valid = rail.get_transition( (*position, direction), new_direction)
return new_cell_valid, new_direction, new_position, transition_valid
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment