Commit 632707ac authored by Dipam Chakraborty's avatar Dipam Chakraborty
Browse files

step updates WIP

parent d9d19b88
from enum import IntEnum
from flatland.envs.malfunction_generators import Malfunction
from itertools import starmap
from typing import Tuple, Optional, NamedTuple
......@@ -13,7 +14,13 @@ class RailAgentStatus(IntEnum):
ACTIVE = 1 # in grid (position is not None), not done -> prediction is remaining path
DONE = 2 # in grid (position is not None), but done -> prediction is stay at target forever
DONE_REMOVED = 3 # removed from grid (position is None) -> prediction is None
class TrainState(IntEnum):
DONE = 4
Agent = NamedTuple('Agent', [('initial_position', Tuple[int, int]),
('initial_direction', Grid4TransitionsEnum),
......@@ -7,6 +7,7 @@ from enum import IntEnum
from typing import List, NamedTuple, Optional, Dict, Tuple
import numpy as np
from numpy.testing._private.utils import import_nose
from flatland.core.env import Environment
......@@ -84,6 +85,10 @@ class RailEnvActions(IntEnum):
4: 'S',
def is_moving_action(action):
return action in [1,2,3]
RailEnvGridPos = NamedTuple('RailEnvGridPos', [('r', int), ('c', int)])
RailEnvNextAction = NamedTuple('RailEnvNextAction', [('action', RailEnvActions), ('next_position', RailEnvGridPos),
......@@ -493,6 +498,69 @@ class RailEnv(Environment):
agent.malfunction_data['nr_malfunctions'] += 1
def preprocess_action(self, *args, **kwargs):
# TODO : Dipam - Temporarily added - Though I kind of like this system - need thoughts from others?
from flatland.envs.step_utils.action_preprocessing import preprocess_action
return preprocess_action(*args, **kwargs)
def apply_action_independent(self, action, rail, position, direction):
from flatland.envs.step_utils.action_preprocessing import check_action
if RailEnvActions.is_moving_action(action):
new_direction, _ = check_action(agent, action)
new_position = get_new_position(position, new_direction)
new_position, new_direction = position, direction
return new_position, direction
def step_new(self, action_dict):
# TODO: Dipam - Add basic bookkeeping code
for i_agent, agent in enumerate(self.agents):
action = action_dict[i_agent]
# Skipping if action saved for efficiency
if not self.agent.action_saver.is_action_saved:
# Preprocess action
action = self.preprocess_action(action, agent.state, self.rail, agent.position, agent.direction)
# Speed counting
if agent.speed_counter.is_cell_entry:
# Save action
# When cell exit occurs apply saved action independent of other agents
if agent.speed_counter.is_cell_exit and agent.action_saver.is_action_saved:
saved_action = agent.action_saver.saved_action
# Apply action and get temporary new position and direction
# TODO: Dipam - Could change name here to make it more obvious that its without conflict checks
# TODO: Dipam - Important - This won't handle all the possible additions to motion check like "None, None"
# TODO: Dipam - Important - Stop penalty will not be applied if saved_action checks for only moving actions
temp_new_position, temp_new_direction = self.apply_action_independent(saved_action, self.rail, agent.position, agent.direction)
temp_new_position, temp_new_direction = agent.position, agent.direction
self.motionCheck.addAgent(i_agent, agent.position, temp_new_position)
# Find conflicts
# Modify conflicted positions and select one of them randomly to go to new position
# for i_agent, agent in enumerate(self.agents):
# Update posiitions
# Update states
# Update rewards
# Update counters (malfunction and speed)
def step(self, action_dict_: Dict[int, RailEnvActions]):
from flatland.envs.agent_utils import TrainState
from flatland.envs.rail_env import RailEnvActions
def process_illegal_action(action: RailEnvActions):
# TODO - Dipam : This check is kind of weird, change this
if action is None or action not in RailEnvActions._value2member_map_:
return RailEnvActions.DO_NOTHING
return action
def process_do_nothing(state: TrainState):
if state == TrainState.MOVING:
action = RailEnvActions.MOVE_FORWARD
action = RailEnvActions.STOP_MOVING
return action
def process_left_right(action, state, rail, position, direction):
if not check_valid_action(action, state, rail, position, direction):
action = RailEnvActions.MOVE_FORWARD
return action
def check_valid_action(action, state, rail, position, direction):
_, new_cell_valid, _, _, transition_valid = check_action_on_agent(action, state, rail, position, direction)
action_is_valid = new_cell_valid and transition_valid
return action_is_valid
def preprocess_action(action, state, rail, position, direction):
Preprocesses actions to handle different situations of usage of action based on context
- LEFT/RIGHT is converted to FORWARD if left/right is not available and train is moving
- DO_NOTHING is converted to FORWARD if train is moving
- DO_NOTHING is converted to STOP_MOVING if train is moving
if state == TrainState.WAITING:
action = RailEnvActions.DO_NOTHING
action = process_illegal_action(action)
if action == RailEnvActions.DO_NOTHING:
action = process_do_nothing(state)
elif action in [RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT]:
action = process_left_right(action, state, rail, position, direction)
if not check_valid_action(action, state, rail, position, direction):
action = RailEnvActions.STOP_MOVING
return action
# TODO - Placeholder - these will be renamed and moved out later
from flatland.envs.rail_env import fast_position_equal, fast_count_nonzero, fast_argmax, fast_clip, get_new_position
# TODO - Dipam - Improve these functions?
def check_action(action):
agent : EnvAgent
action : RailEnvActions
transition_valid = None
possible_transitions = rail.get_transitions(*position, direction)
num_transitions = fast_count_nonzero(possible_transitions)
new_direction = direction
if action == RailEnvActions.MOVE_LEFT:
new_direction = RailEnvActions.MOVE_FORWARD
if num_transitions <= 1:
transition_valid = False
elif action == RailEnvActions.MOVE_RIGHT:
new_direction = RailEnvActions.MOVE_FORWARD
if num_transitions <= 1:
transition_valid = False
new_direction %= 4 # Dipam : Why?
if action == RailEnvActions.MOVE_FORWARD and num_transitions == 1:
# - dead-end, straight line or curved line;
# new_direction will be the only valid transition
# - take only available transition
new_direction = fast_argmax(possible_transitions)
transition_valid = True
return new_direction, transition_valid
def check_action_on_agent(action, state, rail, position, direction):
action : RailEnvActions
agent : EnvAgent
Is it a legal move?
1) transition allows the new_direction in the cell,
2) the new cell is not empty (case 0),
3) the cell is free, i.e., no agent is currently in that cell
# compute number of possible transitions in the current
# cell used to check for invalid actions
new_direction, transition_valid = check_action(agent, action)
new_position = get_new_position(position, new_direction)
new_cell_valid = (
fast_position_equal( # Check the new position is still in the grid
fast_clip(new_position, [0, 0], [self.height - 1, self.width - 1]))
and # check the new position has some transitions (ie is not an empty cell)
rail.get_full_transitions(*new_position) > 0)
# If transition validity hasn't been checked yet.
if transition_valid is None:
transition_valid = rail.get_transition(
(*position, direction),
# only call cell_free() if new cell is inside the scene
if new_cell_valid:
# Check the new position is not the same as any of the existing agent positions
# (including itself, for simplicity, since it is moving)
cell_free = self.cell_free(new_position)
# if new cell is outside of scene -> cell_free is False
cell_free = False
return cell_free, new_cell_valid, new_direction, new_position, transition_valid
from flatland.envs.agent_utils import TrainState
from flatland.envs.rail_env import RailEnvActions
class ActionSaver:
def __init__(self):
self.saved_action = None
def is_action_saved(self):
return not RailEnvActions.is_moving_action(self.saved_action)
def save_action_if_allowed(self, action):
if not self.is_action_saved and RailEnvActions.is_moving_action(action):
self.saved_action = action
def clear_saved_action(self):
self.saved_action = None
import numpy as np
from flatland.envs.agent_utils import TrainState
class SpeedTracker:
def __init__(self, speed):
self.speed = speed
self.max_count = int(np.ceil(1/speed))
def update_counter(self, state):
if state == TrainState.MOVING:
self.counter += 1
self.counter = self.counter % self.max_count
def is_cell_exit(self):
return self.counter == 0
def is_cell_entry(self):
return self.counter == self.max_count - 1
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment