Commit b2587841 authored by Dipam Chakraborty
Browse files

state machine initial commit without malfunctions

parent a236be35
......@@ -10,6 +10,7 @@ from flatland.envs.schedule_utils import Schedule
from flatland.envs.step_utils.action_saver import ActionSaver
from flatland.envs.step_utils.speed_counter import SpeedCounter
from flatland.envs.step_utils.state_machine import TrainStateMachine
class RailAgentStatus(IntEnum):
......@@ -35,6 +36,7 @@ Agent = NamedTuple('Agent', [('initial_position', Tuple[int, int]),
('speed_counter', SpeedCounter),
('action_saver', ActionSaver),
('state', TrainState),
('state_machine', TrainStateMachine),
])
......@@ -69,6 +71,8 @@ class EnvAgent:
# Env step facelift
action_saver = attrib(default=None)
speed_counter = attrib(default=None)
state_machine = attrib(default=None)
state = attrib(default=TrainState.WAITING, type=TrainState)
status = attrib(default=RailAgentStatus.READY_TO_DEPART, type=RailAgentStatus)
......@@ -102,6 +106,7 @@ class EnvAgent:
self.action_saver.clear_saved_action()
self.speed_counter.reset_counter()
self.state_machine.reset()
def to_agent(self) -> Agent:
return Agent(initial_position=self.initial_position,
......@@ -119,7 +124,8 @@ class EnvAgent:
old_direction=self.old_direction,
old_position=self.old_position,
speed_counter=self.speed_counter,
action_saver=self.action_saver)
action_saver=self.action_saver,
state_machine=self.state_machine)
@classmethod
def from_schedule(cls, schedule: Schedule):
......@@ -142,11 +148,13 @@ class EnvAgent:
action_savers = []
speed_counters = []
state_machines = []
num_agents = len(schedule.agent_positions)
agent_speeds = schedule.agent_speeds or ( [1.0] * num_agents )
for speed in schedule.agent_speeds:
for speed in agent_speeds:
speed_counters.append( SpeedCounter(speed=speed) )
action_savers.append( ActionSaver() )
state_machines.append( TrainStateMachine(initial_state=TrainState.WAITING) )
return list(starmap(EnvAgent, zip(schedule.agent_positions, # TODO : Dipam - Really want to change this way of loading agents
schedule.agent_directions,
......@@ -160,6 +168,7 @@ class EnvAgent:
range(len(schedule.agent_positions)),
action_savers,
speed_counters,
state_machines,
)))
@classmethod
......
......@@ -7,6 +7,7 @@ from enum import IntEnum
from typing import List, NamedTuple, Optional, Dict, Tuple
import numpy as np
from numpy.lib.shape_base import vsplit
from numpy.testing._private.utils import import_nose
......@@ -42,7 +43,7 @@ from flatland.envs.step_utils.states import TrainState
from flatland.envs.step_utils.transition_utils import check_action
# Env Step Facelift imports
from flatland.envs.step_utils.action_preprocessing import preprocess_raw_action, check_moving_action, preprocess_action_when_waiting
from flatland.envs.step_utils.action_preprocessing import preprocess_raw_action, preprocess_moving_action, preprocess_action_when_waiting
# Adrian Egli performance fix (the fast methods brings more than 50%)
def fast_isclose(a, b, rtol):
......@@ -67,7 +68,10 @@ def fast_argmax(possible_transitions: (int, int, int, int)) -> bool:
def fast_position_equal(pos_1: (int, int), pos_2: (int, int)) -> bool:
    """Return whether two (row, column) positions refer to the same cell.

    :param pos_1: first position, or ``None`` (an agent that has no position
        yet carries ``None``).
    :param pos_2: second position, indexable with ``[0]`` and ``[1]``.
    :return: ``False`` when ``pos_1`` is ``None``; otherwise ``True`` only if
        both components are equal.
    """
    # TODO: Dipam - Consider making default of agent.position as (-1, -1) instead of None
    # The guard must run *before* any indexing: a None position would
    # otherwise raise TypeError instead of simply comparing unequal.
    if pos_1 is None:
        return False
    return pos_1[0] == pos_2[0] and pos_1[1] == pos_2[1]
def fast_count_nonzero(possible_transitions: (int, int, int, int)):
......@@ -489,13 +493,27 @@ class RailEnv(Environment):
else:
new_position, new_direction = position, direction
return new_position, direction
def generate_state_transition_signals(self, agent, preprocessed_action, movement_allowed):
    """Build the signal dictionary consumed by the agent's state machine.

    :param agent: agent whose ``earliest_departure``, ``position``, ``target``
        and ``speed_counter`` are inspected.
    :param preprocessed_action: the action for this step after preprocessing.
    :param movement_allowed: result of the motion check for this agent.
    :return: dict mapping each state-transition signal name to its value.
    """
    # Malfunctions are not wired into the state machine yet, so both
    # malfunction signals are fixed to False.
    # TODO: derive them from agent.malfunction_data['malfunction']
    #       (onset: value > 0, counter complete: value == 0).
    return {
        'malfunction_onset': False,
        'malfunction_counter_complete': False,
        'earliest_departure_reached': self._elapsed_steps >= agent.earliest_departure,
        'stop_action_given': (preprocessed_action == RailEnvActions.STOP_MOVING),
        'valid_movement_action_given': RailEnvActions.is_moving_action(preprocessed_action),
        'target_reached': fast_position_equal(agent.position, agent.target),
        # A conflict is only reported when the agent is at a cell exit but the
        # motion check denied the move.
        # TODO: Modify motion check to provide proper conflict information
        'movement_conflict': (not movement_allowed) and agent.speed_counter.is_cell_exit,
    }
def step(self, action_dict):
self._elapsed_steps += 1
# If we're done, set reward and info_dict and step() is done.
if self.dones["__all__"]:
if self.dones["__all__"]: # TODO: Move boilerplate to different function
self.rewards_dict = {}
info_dict = {
"action_required": {},
......@@ -524,11 +542,14 @@ class RailEnv(Environment):
self.motionCheck = ac.MotionCheck() # reset the motion check
temp_pos_dirs = {} # TODO - Dipam - Needs renaming
for i_agent, agent in enumerate(self.agents):
temp_saved_data = {} # TODO : Change name
for i_agent, agent in enumerate(self.agents): # TODO: Important - Do not use i_agent like this, use agent.handle if needed
# Get action for the agent
action = action_dict.get(i_agent, RailEnvActions.DO_NOTHING)
# TODO: Add the bottom stuff to separate function(s)
# Preprocess action
action = preprocess_raw_action(action, agent.state)
action = preprocess_action_when_waiting(action, agent.state)
......@@ -538,58 +559,66 @@ class RailEnv(Environment):
agent_not_on_map = current_position is None
if agent_not_on_map: # Agent not added on map yet
current_position, current_direction = agent.initial_position, agent.initial_direction
action = check_moving_action(action, agent.state, self.rail, current_position, current_direction)
action = preprocess_moving_action(action, agent.state, self.rail, current_position, current_direction)
# Save moving actions in not already saved
agent.action_saver.save_action_if_allowed(action) # TODO : Important - Can't save action in malfunction state?
agent.action_saver.save_action_if_allowed(action, agent.state)
# Calculate new position
# Add agent to the map if not on it yet
if agent_not_on_map and agent.action_saver.is_action_saved:
temp_new_position = agent.initial_position
temp_new_direction = agent.initial_direction
temp_pos_dirs[i_agent] = temp_new_position, temp_new_direction
# When cell exit occurs apply saved action independent of other agents
elif agent.speed_counter.is_cell_exit and agent.action_saver.is_action_saved:
saved_action = agent.action_saver.saved_action
# Apply action independent of other agents and get temporary new position and direction
import pdb; pdb.set_trace()
temp_pd = self.apply_action_independent(saved_action, self.rail, agent.position, agent.direction)
temp_new_position, temp_new_direction = temp_pd
else:
temp_new_position, temp_new_direction = agent.position, agent.direction
# TODO: Saving temporary positon shouldn't be needed if recheck of position is not needed later (see TAG#1)
temp_saved_data[i_agent] = temp_new_position, temp_new_direction, action
self.motionCheck.addAgent(i_agent, agent.position, temp_new_position)
# Find conflicts
# self.motionCheck.find_conflicts()
# TODO : Important - Modify conflicted positions and select one of them randomly to go to new position
self.motionCheck.find_conflicts()
for i_agent, agent in enumerate(self.agents):
## Update posiitions
final_new_position = temp_pos_dirs[i_agent][0]
final_new_direction = temp_pos_dirs[i_agent][1]
## Update positions
movement_allowed, _ = self.motionCheck.check_motion(i_agent, agent.position) # TODO: Remove final_new_postion from motioncheck
# TODO : Important : Original code rechecks the next position here again - not sure why? TAG#1
preprocessed_action = temp_saved_data[i_agent][2] # TODO : Important : Make this namedtuple or class
if movement_allowed:
final_new_position, final_new_direction = temp_saved_data[i_agent][:2] # TODO : Important : Make this namedtuple or class
else:
final_new_position = agent.position
final_new_direction = agent.direction
agent.position = final_new_position
agent.direction = final_new_direction
## Update states
# agent.state_machine.step()
# agent.state = agent.state_machine.state
state_transition_signals = self.generate_state_transition_signals(agent, preprocessed_action, movement_allowed)
agent.state_machine.set_transition_signals(state_transition_signals)
agent.state_machine.step()
agent.state = agent.state_machine.state # TODO : Make this a property instead?
## Update rewards
# agent.update_rewards()
# self.update_rewards(i_agent, agent, rail)
## Update counters (malfunction and speed)
agent.speed_counter.update_counter(agent.state)
# agent.malfunction_counter.update_counter()
agent.speed_counter.update_counter(agent.state)
# agent.malfunction_counter.update_counter() # TODO : Update this to interface with current malfunction code
# Clear old action when starting in new cell
if agent.speed_counter.is_cell_entry:
agent.action_saver.clear_saved_action()
self.rewards_dict = {i_agent: 0 for i_agent in range(len(self.agents))}
return self._get_observations(), self.rewards_dict, self.dones, info_dict
self.rewards_dict = {i_agent: 0 for i_agent in range(len(self.agents))} # TODO : Remove this
return self._get_observations(), self.rewards_dict, self.dones, info_dict # TODO : Will need changes?
......@@ -634,8 +663,6 @@ class RailEnv(Environment):
self.motionCheck = ac.MotionCheck() # reset the motion check
import pdb; pdb.set_trace()
if not self.close_following:
for i_agent, agent in enumerate(self.agents):
# Reset the step rewards
......
......@@ -6,7 +6,7 @@ from flatland.envs.step_utils.transition_utils import check_valid_action
def process_illegal_action(action: RailEnvActions):
    """Replace a missing or unknown action with ``RailEnvActions.DO_NOTHING``.

    :param action: raw action value supplied for an agent; may be ``None``.
    :return: ``RailEnvActions.DO_NOTHING`` when ``action`` is ``None`` or is
        not one of the ``RailEnvActions`` values, otherwise ``action``
        unchanged.
    """
    # TODO - Dipam : This check is kind of weird, change this
    # Membership is tested against the enum's value map so that plain
    # integers are accepted as well as enum members.
    if action is None or action not in RailEnvActions._value2member_map_:
        return RailEnvActions.DO_NOTHING
    return action
......@@ -48,7 +48,7 @@ def preprocess_raw_action(action, state):
return action
def check_moving_action(action, state, rail, position, direction):
def preprocess_moving_action(action, state, rail, position, direction):
"""
LEFT/RIGHT is converted to FORWARD if left/right is not available and train is moving
FORWARD is converted to STOP_MOVING if leading to dead end?
......
from flatland.envs.rail_env_action import RailEnvActions
from flatland.envs.step_utils.states import TrainState
class ActionSaver:
def __init__(self):
......@@ -8,8 +9,10 @@ class ActionSaver:
def is_action_saved(self):
return self.saved_action is not None
def save_action_if_allowed(self, action, state):
    """Remember ``action`` unless one is already saved or saving is not allowed.

    The action is stored in ``self.saved_action`` only when all of the
    following hold: no action is currently saved, ``action`` is a moving
    action, and ``state`` is not a malfunction state.

    :param action: preprocessed action for the current step.
    :param state: current ``TrainState`` of the agent.
    """
    # An already-saved action is never overwritten.
    if self.is_action_saved:
        return
    if RailEnvActions.is_moving_action(action) and \
            not TrainState.is_malfunction_state(state):
        self.saved_action = action
def clear_saved_action(self):
......
......@@ -4,7 +4,7 @@ from flatland.envs.step_utils.states import TrainState
class SpeedCounter:
def __init__(self, speed):
self.speed = speed
self.max_count = int(np.ceil(1/speed))
self.max_count = int(1/speed)
def update_counter(self, state):
if state == TrainState.MOVING:
......
from attr import s
from flatland.envs.step_utils.states import TrainState
class TrainStateMachine:
    """State machine that tracks the ``TrainState`` of a single agent.

    Usage per environment step: call :meth:`set_transition_signals` with the
    signals computed for the agent, then :meth:`step`; the resulting state is
    available through :attr:`state`.
    """

    # Every signal read by the state handlers below. A signal dict that lacks
    # one of these would only fail later, with a KeyError, in whichever state
    # happens to read the missing key.
    _REQUIRED_SIGNALS = frozenset((
        'malfunction_onset',
        'malfunction_counter_complete',
        'earliest_departure_reached',
        'stop_action_given',
        'valid_movement_action_given',
        'target_reached',
        'movement_conflict',
    ))

    def __init__(self, initial_state=TrainState.WAITING):
        """
        :param initial_state: state the machine starts in and returns to on
            :meth:`reset`.
        """
        self._initial_state = initial_state
        self._state = initial_state
        self.st_signals = {}  # State Transition Signals
        self.next_state = None

    def _handle_waiting(self):
        """Waiting state goes to ready to depart when earliest departure is reached."""
        # TODO: Important - The malfunction handling is not like proper state machine
        #       Both transition signals can happen at the same time
        #       Atleast mention it in the diagram
        # Malfunction onset takes priority over reaching the departure time.
        if self.st_signals['malfunction_onset']:
            self.next_state = TrainState.MALFUNCTION_OFF_MAP
        elif self.st_signals['earliest_departure_reached']:
            self.next_state = TrainState.READY_TO_DEPART
        else:
            self.next_state = TrainState.WAITING

    def _handle_ready_to_depart(self):
        """Can only go to MOVING if a valid action is provided."""
        if self.st_signals['malfunction_onset']:
            self.next_state = TrainState.MALFUNCTION_OFF_MAP
        elif self.st_signals['valid_movement_action_given']:
            self.next_state = TrainState.MOVING
        else:
            self.next_state = TrainState.READY_TO_DEPART

    def _handle_malfunction_off_map(self):
        """Stay malfunctioning until the counter completes, then return to the
        matching off-map state."""
        if not self.st_signals['malfunction_counter_complete']:
            # The malfunction is still running: the agent must not leave the
            # malfunction state early.
            self.next_state = TrainState.MALFUNCTION_OFF_MAP
        elif self.st_signals['earliest_departure_reached']:
            self.next_state = TrainState.READY_TO_DEPART
        else:
            # Recovered before its departure time: the agent is still off the
            # map, so it goes back to WAITING rather than the on-map STOPPED.
            self.next_state = TrainState.WAITING

    def _handle_moving(self):
        """Moving agents may malfunction, finish, stop, or keep moving
        (checked in that priority order)."""
        if self.st_signals['malfunction_onset']:
            self.next_state = TrainState.MALFUNCTION
        elif self.st_signals['target_reached']:
            self.next_state = TrainState.DONE
        elif self.st_signals['stop_action_given'] or self.st_signals['movement_conflict']:
            self.next_state = TrainState.STOPPED
        else:
            self.next_state = TrainState.MOVING

    def _handle_stopped(self):
        """Stopped agents resume only on a valid movement action."""
        if self.st_signals['malfunction_onset']:
            self.next_state = TrainState.MALFUNCTION
        elif self.st_signals['valid_movement_action_given']:
            self.next_state = TrainState.MOVING
        else:
            self.next_state = TrainState.STOPPED

    def _handle_malfunction(self):
        """Leave MALFUNCTION only once the malfunction counter has completed."""
        counter_complete = self.st_signals['malfunction_counter_complete']
        if counter_complete and self.st_signals['valid_movement_action_given']:
            self.next_state = TrainState.MOVING
        elif counter_complete and \
                (self.st_signals['stop_action_given'] or self.st_signals['movement_conflict']):
            self.next_state = TrainState.STOPPED
        else:
            self.next_state = TrainState.MALFUNCTION

    def _handle_done(self):
        """Done state is terminal."""
        self.next_state = TrainState.DONE

    def calculate_next_state(self, current_state):
        """Store in ``self.next_state`` the successor of ``current_state``
        under the current transition signals.

        :raises ValueError: if ``current_state`` is not a known ``TrainState``.
        """
        if current_state == TrainState.WAITING:
            self._handle_waiting()
        elif current_state == TrainState.READY_TO_DEPART:
            self._handle_ready_to_depart()
        elif current_state == TrainState.MALFUNCTION_OFF_MAP:
            self._handle_malfunction_off_map()
        elif current_state == TrainState.MOVING:
            self._handle_moving()
        elif current_state == TrainState.STOPPED:
            self._handle_stopped()
        elif current_state == TrainState.MALFUNCTION:
            self._handle_malfunction()
        elif current_state == TrainState.DONE:
            self._handle_done()
        else:
            raise ValueError(f"Got unexpected state {current_state}")

    def step(self):
        """Steps the state machine to the next state."""
        current_state = self._state

        # Clear next state
        self.clear_next_state()

        # Handle current state to get next_state
        self.calculate_next_state(current_state)

        # Set next state
        self.set_state(self.next_state)

    def clear_next_state(self):
        """Forget the previously computed next state."""
        self.next_state = None

    def set_state(self, state):
        """Force the machine into ``state``.

        :raises ValueError: if ``state`` is not a valid ``TrainState`` value.
        """
        if not TrainState.check_valid_state(state):
            raise ValueError(f"Cannot set invalid state {state}")
        self._state = state

    def reset(self):
        """Return to the initial state and drop all signals."""
        self._state = self._initial_state
        self.st_signals = {}
        self.clear_next_state()

    @property
    def state(self):
        """Current ``TrainState`` of the machine."""
        return self._state

    @property
    def state_transition_signals(self):
        """The signal dict most recently set (empty after a reset)."""
        return self.st_signals

    def set_transition_signals(self, state_transition_signals):
        """Install the signals used by the next :meth:`step`.

        :param state_transition_signals: dict containing every signal the
            state handlers read (see ``_REQUIRED_SIGNALS``).
        :raises ValueError: if any required signal is missing.
        """
        missing = self._REQUIRED_SIGNALS.difference(state_transition_signals)
        if missing:
            raise ValueError(f"Missing state transition signals: {sorted(missing)}")
        self.st_signals = state_transition_signals
......@@ -3,7 +3,19 @@ from enum import IntEnum
class TrainState(IntEnum):
    """States an agent's train can be in.

    Each member name is defined exactly once; ``IntEnum`` rejects a class
    body that assigns the same member name twice.
    """
    WAITING = 0
    READY_TO_DEPART = 1
    MALFUNCTION_OFF_MAP = 2
    MOVING = 3
    STOPPED = 4
    MALFUNCTION = 5
    DONE = 6

    @classmethod
    def check_valid_state(cls, state):
        """Return whether ``state`` is the value of one of the members.

        Plain integers are accepted as well as ``TrainState`` members.
        """
        return state in cls._value2member_map_

    @staticmethod
    def is_malfunction_state(state):
        """Return whether ``state`` is one of the two malfunction states."""
        # Members are referenced by name rather than by their numeric values,
        # so renumbering the enum cannot silently break this check.
        return state in (TrainState.MALFUNCTION_OFF_MAP, TrainState.MALFUNCTION)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment