Commit a236be35 authored by Dipam Chakraborty

update temporary positions complete

parent 95884b64
from enum import IntEnum
from flatland.envs.malfunction_generators import Malfunction
from flatland.envs.step_utils.states import TrainState
from itertools import starmap
from typing import Tuple, Optional, NamedTuple
......@@ -8,19 +8,15 @@ from attr import attr, attrs, attrib, Factory
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs.schedule_utils import Schedule
from flatland.envs.step_utils.action_saver import ActionSaver
from flatland.envs.step_utils.speed_counter import SpeedCounter
class RailAgentStatus(IntEnum):
READY_TO_DEPART = 0 # not in grid yet (position is None) -> prediction as if it were at initial position
ACTIVE = 1 # in grid (position is not None), not done -> prediction is remaining path
DONE = 2 # in grid (position is not None), but done -> prediction is stay at target forever
DONE_REMOVED = 3 # removed from grid (position is None) -> prediction is None
class TrainState(IntEnum):
WAITING = 0
READY_TO_DEPART = 1
MOVING = 2
STOPPED = 3
MALFUNCTION = 4
DONE = 5
Agent = NamedTuple('Agent', [('initial_position', Tuple[int, int]),
('initial_direction', Grid4TransitionsEnum),
......@@ -35,7 +31,11 @@ Agent = NamedTuple('Agent', [('initial_position', Tuple[int, int]),
('status', RailAgentStatus),
('position', Tuple[int, int]),
('old_direction', Grid4TransitionsEnum),
('old_position', Tuple[int, int]),
('speed_counter', SpeedCounter),
('action_saver', ActionSaver),
('state', TrainState),
])
@attrs
......@@ -66,6 +66,11 @@ class EnvAgent:
handle = attrib(default=None)
# Env step facelift
action_saver = attrib(default=None)
speed_counter = attrib(default=None)
state = attrib(default=TrainState.WAITING, type=TrainState)
status = attrib(default=RailAgentStatus.READY_TO_DEPART, type=RailAgentStatus)
position = attrib(default=None, type=Optional[Tuple[int, int]])
......@@ -73,6 +78,7 @@ class EnvAgent:
old_direction = attrib(default=None)
old_position = attrib(default=None)
def reset(self):
"""
Resets the agents to their initial values of the episode
......@@ -94,12 +100,26 @@ class EnvAgent:
self.malfunction_data['nr_malfunctions'] = 0
self.malfunction_data['moving_before_malfunction'] = False
self.action_saver.clear_saved_action()
self.speed_counter.reset_counter()
def to_agent(self) -> Agent:
return Agent(initial_position=self.initial_position,
initial_direction=self.initial_direction,
direction=self.direction,
target=self.target,
moving=self.moving,
earliest_departure=self.earliest_departure,
latest_arrival=self.latest_arrival,
speed_data=self.speed_data,
malfunction_data=self.malfunction_data,
handle=self.handle,
status=self.status,
position=self.position,
old_direction=self.old_direction,
old_position=self.old_position,
speed_counter=self.speed_counter,
action_saver=self.action_saver,
state=self.state)
@classmethod
def from_schedule(cls, schedule: Schedule):
......@@ -120,7 +140,15 @@ class EnvAgent:
'next_malfunction': 0,
'nr_malfunctions': 0})
action_savers = []
speed_counters = []
num_agents = len(schedule.agent_positions)
agent_speeds = schedule.agent_speeds or [1.0] * num_agents
for speed in agent_speeds:
speed_counters.append(SpeedCounter(speed=speed))
action_savers.append(ActionSaver())
return list(starmap(EnvAgent, zip(schedule.agent_positions, # TODO : Dipam - Really want to change this way of loading agents
schedule.agent_directions, # initial_direction
schedule.agent_directions, # direction
schedule.agent_targets,
......@@ -129,7 +157,10 @@ class EnvAgent:
[None] * len(schedule.agent_positions), # latest_arrival
speed_datas,
malfunction_datas,
range(len(schedule.agent_positions)),
action_savers,
speed_counters,
)))
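For reference, the starmap-over-zip idiom above builds one EnvAgent per index from parallel per-agent lists. A minimal standalone sketch of the same pattern, illustrative only and not part of the commit (Point and the lists are made-up names):

from itertools import starmap
from typing import NamedTuple

class Point(NamedTuple):
    x: int
    y: int

xs = [0, 1, 2]
ys = [10, 11, 12]
# zip groups the i-th entries together; starmap unpacks each tuple positionally into the constructor
points = list(starmap(Point, zip(xs, ys)))
assert points[1] == Point(1, 11)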
@classmethod
def load_legacy_static_agent(cls, static_agents_data: Tuple):
......
......@@ -14,7 +14,7 @@ from flatland.core.env import Environment
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.grid.grid4 import Grid4TransitionsEnum, Grid4Transitions
from flatland.core.grid.grid4_utils import get_new_position
from flatland.core.grid.grid_utils import IntVector2D
from flatland.core.grid.grid_utils import IntVector2D, position_to_coordinate
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.agent_utils import EnvAgent, RailAgentStatus
from flatland.envs.distance_map import DistanceMap
......@@ -38,6 +38,11 @@ from gym.utils import seeding
# NEW : Imports
from flatland.envs.schedule_time_generators import schedule_time_generator
from flatland.envs.step_utils.states import TrainState
from flatland.envs.step_utils.transition_utils import check_action
# Env Step Facelift imports
from flatland.envs.step_utils.action_preprocessing import preprocess_raw_action, check_moving_action, preprocess_action_when_waiting
# Adrian Egli performance fix (the fast methods brings more than 50%)
def fast_isclose(a, b, rtol):
......@@ -254,6 +259,8 @@ class RailEnv(Environment):
self.close_following = close_following # use close following logic
self.motionCheck = ac.MotionCheck()
self.agent_helpers = {}
def _seed(self, seed=None):
self.np_random, seed = seeding.np_random(seed)
random.seed(seed)
......@@ -475,70 +482,118 @@ class RailEnv(Environment):
return
def preprocess_action(self, *args, **kwargs):
# TODO : Dipam - Temporarily added - Though I kind of like this system - need thoughts from others?
from flatland.envs.step_utils.action_preprocessing import preprocess_action
return preprocess_action(*args, **kwargs)
def apply_action_independent(self, action, rail, position, direction):
if RailEnvActions.is_moving_action(action):
new_direction, _ = check_action(action, position, direction, rail)
new_position = get_new_position(position, new_direction)
else:
new_position, new_direction = position, direction
return new_position, new_direction
def step(self, action_dict):
self._elapsed_steps += 1
# If we're done, set reward and info_dict and step() is done.
if self.dones["__all__"]:
self.rewards_dict = {}
info_dict = {
"action_required": {},
"malfunction": {},
"speed": {},
"status": {},
}
for i_agent, agent in enumerate(self.agents):
self.rewards_dict[i_agent] = self.global_reward
info_dict["action_required"][i_agent] = False
info_dict["malfunction"][i_agent] = 0
info_dict["speed"][i_agent] = 0
info_dict["status"][i_agent] = agent.status
return self._get_observations(), self.rewards_dict, self.dones, info_dict
# Reset the step rewards
self.rewards_dict = dict()
info_dict = {
"action_required": {},
"malfunction": {},
"speed": {},
"status": {},
}
have_all_agents_ended = True # boolean flag to check if all agents are done
self.motionCheck = ac.MotionCheck() # reset the motion check
temp_pos_dirs = {} # TODO - Dipam - Needs renaming
for i_agent, agent in enumerate(self.agents):
action = action_dict.get(i_agent, RailEnvActions.DO_NOTHING)
# Preprocess action
action = preprocess_raw_action(action, agent.state)
action = preprocess_action_when_waiting(action, agent.state)
# Try moving actions on current position
current_position, current_direction = agent.position, agent.direction
agent_not_on_map = current_position is None
if agent_not_on_map: # Agent not added on map yet
current_position, current_direction = agent.initial_position, agent.initial_direction
action = check_moving_action(action, agent.state, self.rail, current_position, current_direction)
# Save moving actions if not already saved
agent.action_saver.save_action_if_allowed(action) # TODO : Important - Can't save action in malfunction state?
# Calculate new position
# Add agent to the map if not on it yet
if agent_not_on_map and agent.action_saver.is_action_saved:
temp_new_position = agent.initial_position
temp_new_direction = agent.initial_direction
temp_pos_dirs[i_agent] = temp_new_position, temp_new_direction
# When cell exit occurs apply saved action independent of other agents
elif agent.speed_counter.is_cell_exit and agent.action_saver.is_action_saved:
saved_action = agent.action_saver.saved_action
# Apply saved action independent of other agents and get temporary new position and direction
# TODO: Dipam - Could change name here to make it more obvious that it's without conflict checks
# TODO: Dipam - Important - This won't handle all the possible additions to motion check like "None, None"
# TODO: Dipam - Important - Stop penalty will not be applied if saved_action checks for only moving actions
temp_new_position, temp_new_direction = self.apply_action_independent(saved_action, self.rail, agent.position, agent.direction)
else:
temp_new_position, temp_new_direction = agent.position, agent.direction
temp_pos_dirs[i_agent] = temp_new_position, temp_new_direction # remember the provisional pose for the update loop below
self.motionCheck.addAgent(i_agent, agent.position, temp_new_position)
# Find conflicts
# self.motionCheck.find_conflicts()
# TODO : Important - Modify conflicted positions and select one of them randomly to go to new position
for i_agent, agent in enumerate(self.agents):
## Update positions
final_new_position, final_new_direction = temp_pos_dirs[i_agent]
agent.position = final_new_position
agent.direction = final_new_direction
## Update states
# agent.state_machine.step()
# agent.state = agent.state_machine.state
## Update rewards
# agent.update_rewards()
## Update counters (malfunction and speed)
agent.speed_counter.update_counter(agent.state)
# agent.malfunction_counter.update_counter()
# Clear old action when starting in new cell
if agent.speed_counter.is_cell_entry:
agent.action_saver.clear_saved_action()
self.rewards_dict = {i_agent: 0 for i_agent in range(len(self.agents))}
return self._get_observations(), self.rewards_dict, self.dones, info_dict
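For context, the new step() keeps the same return contract as before (observations, rewards, dones, info). A hypothetical driver loop, not part of the commit and assuming an already constructed RailEnv `env` and a `policy` callable, might look like:

# Assumes env.reset() returns (observations, info) as in recent Flatland releases.
obs, info = env.reset()
dones = {"__all__": False}
while not dones["__all__"]:
    # one RailEnvActions value per agent handle
    actions = {handle: policy(obs[handle]) for handle in range(env.get_num_agents())}
    obs, rewards, dones, info = env.step(actions)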
def step_old(self, action_dict_: Dict[int, RailEnvActions]):
"""
Updates rewards for the agents at a step.
......@@ -579,6 +634,8 @@ class RailEnv(Environment):
self.motionCheck = ac.MotionCheck() # reset the motion check
if not self.close_following:
for i_agent, agent in enumerate(self.agents):
# Reset the step rewards
......@@ -646,144 +703,6 @@ class RailEnv(Environment):
return self._get_observations(), self.rewards_dict, self.dones, info_dict
def _step_agent(self, i_agent, action: Optional[RailEnvActions] = None):
"""
Performs a step on a single agent, applying the step, start and stop penalties, in the following sub-steps:
- malfunction
- action handling if at the beginning of cell
- movement
Parameters
----------
i_agent : int
action_dict_ : Dict[int,RailEnvActions]
"""
agent = self.agents[i_agent]
if agent.status in [RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED]: # this agent has already completed...
return
# agent becomes active through a MOVE_* action, provided its initial cell is free
if agent.status == RailAgentStatus.READY_TO_DEPART:
initial_cell_free = self.cell_free(agent.initial_position)
is_action_starting = action in [
RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT, RailEnvActions.MOVE_FORWARD]
if is_action_starting and initial_cell_free:
agent.status = RailAgentStatus.ACTIVE
self._set_agent_to_initial_position(agent, agent.initial_position)
self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
return
else:
# TODO: Here we need to check for the departure time in future releases with full schedules
self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
return
agent.old_direction = agent.direction
agent.old_position = agent.position
# if agent is broken, actions are ignored and agent does not move.
# full step penalty in this case
if agent.malfunction_data['malfunction'] > 0:
self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
return
# Is the agent at the beginning of the cell? Then, it can take an action.
# As long as the agent is malfunctioning or stopped at the beginning of the cell,
# different actions may be taken!
if fast_isclose(agent.speed_data['position_fraction'], 0.0, rtol=1e-03):
# No action has been supplied for this agent -> set DO_NOTHING as default
if action is None:
action = RailEnvActions.DO_NOTHING
if action < 0 or action > len(RailEnvActions):
print('ERROR: illegal action=', action,
'for agent with index=', i_agent,
'"DO NOTHING" will be executed instead')
action = RailEnvActions.DO_NOTHING
if action == RailEnvActions.DO_NOTHING and agent.moving:
# Keep moving
action = RailEnvActions.MOVE_FORWARD
if action == RailEnvActions.STOP_MOVING and agent.moving:
# Only allow halting an agent on entering new cells.
agent.moving = False
self.rewards_dict[i_agent] += self.stop_penalty
if not agent.moving and not (
action == RailEnvActions.DO_NOTHING or
action == RailEnvActions.STOP_MOVING):
# Allow agent to start with any forward or direction action
agent.moving = True
self.rewards_dict[i_agent] += self.start_penalty
# Store the action if action is moving
# If not moving, the action will be stored when the agent starts moving again.
if agent.moving:
_action_stored = False
_, new_cell_valid, new_direction, new_position, transition_valid = \
self._check_action_on_agent(action, agent)
if all([new_cell_valid, transition_valid]):
agent.speed_data['transition_action_on_cellexit'] = action
_action_stored = True
else:
# But, if the chosen invalid action was LEFT/RIGHT, and the agent is moving,
# try to keep moving forward!
if (action == RailEnvActions.MOVE_LEFT or action == RailEnvActions.MOVE_RIGHT):
_, new_cell_valid, new_direction, new_position, transition_valid = \
self._check_action_on_agent(RailEnvActions.MOVE_FORWARD, agent)
if all([new_cell_valid, transition_valid]):
agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.MOVE_FORWARD
_action_stored = True
if not _action_stored:
# If the agent cannot move due to an invalid transition, we set its state to not moving
self.rewards_dict[i_agent] += self.invalid_action_penalty
self.rewards_dict[i_agent] += self.stop_penalty
agent.moving = False
# Now perform a movement.
# If agent.moving, increment the position_fraction by the speed of the agent
# If the new position fraction is >= 1, reset to 0, and perform the stored
# transition_action_on_cellexit if the cell is free.
if agent.moving:
agent.speed_data['position_fraction'] += agent.speed_data['speed']
if agent.speed_data['position_fraction'] > 1.0 or fast_isclose(agent.speed_data['position_fraction'], 1.0,
rtol=1e-03):
# Perform stored action to transition to the next cell as soon as cell is free
# Notice that we've already checked new_cell_valid and transition valid when we stored the action,
# so we only have to check cell_free now!
# Traditional check that next cell is free
# cell and transition validity was checked when we stored transition_action_on_cellexit!
cell_free, new_cell_valid, new_direction, new_position, transition_valid = self._check_action_on_agent(
agent.speed_data['transition_action_on_cellexit'], agent)
# N.B. validity of new_cell and transition should have been verified before the action was stored!
assert new_cell_valid
assert transition_valid
if cell_free:
self._move_agent_to_new_position(agent, new_position)
agent.direction = new_direction
agent.speed_data['position_fraction'] = 0.0
# has the agent reached its target?
if np.equal(agent.position, agent.target).all():
agent.status = RailAgentStatus.DONE
self.dones[i_agent] = True
self.active_agents.remove(i_agent)
agent.moving = False
self._remove_agent_from_scene(agent)
else:
self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
else:
# step penalty if not moving (stopped now or before)
self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
def _step_agent_cf(self, i_agent, action: Optional[RailEnvActions] = None):
""" "close following" version of step_agent.
"""
......
from flatland.core.grid.grid_utils import position_to_coordinate
from flatland.envs.agent_utils import TrainState
from flatland.envs.rail_env_action import RailEnvActions
from flatland.envs.step_utils.transition_utils import check_valid_action
def process_illegal_action(action: RailEnvActions):
# TODO - Dipam : This check is kind of weird, change this
if action is None or action not in RailEnvActions._value2member_map_:
return RailEnvActions.DO_NOTHING
else:
return action
def process_do_nothing(state: TrainState):
if state == TrainState.MOVING:
action = RailEnvActions.MOVE_FORWARD
else:
action = RailEnvActions.STOP_MOVING
return action
def process_left_right(action, state, rail, position, direction):
if not check_valid_action(action, state, rail, position, direction):
action = RailEnvActions.MOVE_FORWARD
return action
def check_valid_action(action, state, rail, position, direction):
_, new_cell_valid, _, _, transition_valid = check_action_on_agent(action, state, rail, position, direction)
action_is_valid = new_cell_valid and transition_valid
return action_is_valid
def preprocess_action(action, state, rail, position, direction):
"""
Preprocesses actions to handle different situations of usage of action based on context
- LEFT/RIGHT is converted to FORWARD if left/right is not available and train is moving
- DO_NOTHING is converted to FORWARD if train is moving
- DO_NOTHING is converted to STOP_MOVING if train is not moving
"""
if state == TrainState.WAITING:
action = RailEnvActions.DO_NOTHING
action = process_illegal_action(action)
if action == RailEnvActions.DO_NOTHING:
action = process_do_nothing(state)
elif action in [RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT]:
action = process_left_right(action, state, rail, position, direction)
if not check_valid_action(action, state, rail, position, direction):
action = RailEnvActions.STOP_MOVING
return action
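To make the preprocessing rules above concrete, a small illustrative check (not part of the commit; it uses only the helpers defined in this module and the imports at its top):

# Illegal or missing raw actions collapse to DO_NOTHING before any other rule applies
assert process_illegal_action(None) == RailEnvActions.DO_NOTHING
assert process_illegal_action(42) == RailEnvActions.DO_NOTHING
assert process_illegal_action(RailEnvActions.MOVE_LEFT) == RailEnvActions.MOVE_LEFT

# DO_NOTHING keeps a moving train moving and stops a train in any other state
assert process_do_nothing(TrainState.MOVING) == RailEnvActions.MOVE_FORWARD
assert process_do_nothing(TrainState.STOPPED) == RailEnvActions.STOP_MOVING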
# TODO - Placeholder - these will be renamed and moved out later
from flatland.envs.rail_env import fast_position_equal, fast_count_nonzero, fast_argmax, fast_clip, get_new_position
# TODO - Dipam - Improve these functions?
def check_action(action, position, direction, rail):
"""
Parameters
----------
action : RailEnvActions
position : Tuple[int, int]
direction : Grid4TransitionsEnum
rail : GridTransitionMap
Returns
-------
Tuple[Grid4TransitionsEnum, Optional[bool]]
The new direction and whether the transition is already known to be valid
(None if validity still has to be checked against the transition map).
"""
transition_valid = None
possible_transitions = rail.get_transitions(*position, direction)
num_transitions = fast_count_nonzero(possible_transitions)
new_direction = direction
if action == RailEnvActions.MOVE_LEFT:
new_direction = direction - 1
if num_transitions <= 1:
transition_valid = False
elif action == RailEnvActions.MOVE_RIGHT:
new_direction = direction + 1
if num_transitions <= 1:
transition_valid = False
new_direction %= 4 # Dipam : Why? -- wraps the heading back into the four Grid4 directions (values 0-3)
if action == RailEnvActions.MOVE_FORWARD and num_transitions == 1:
# - dead-end, straight line or curved line;
# new_direction will be the only valid transition
# - take only available transition
new_direction = fast_argmax(possible_transitions)
transition_valid = True
return new_direction, transition_valid
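The LEFT/RIGHT arithmetic above relies on the Grid4 direction encoding; a small illustration, not part of the commit:

from flatland.core.grid.grid4 import Grid4TransitionsEnum

# NORTH=0, EAST=1, SOUTH=2, WEST=3: LEFT rotates counter-clockwise, RIGHT clockwise,
# and the modulo wraps the heading back into the 0-3 range (hence new_direction %= 4)
direction = Grid4TransitionsEnum.NORTH
assert (direction - 1) % 4 == Grid4TransitionsEnum.WEST   # MOVE_LEFT from NORTH
assert (direction + 1) % 4 == Grid4TransitionsEnum.EAST   # MOVE_RIGHT from NORTH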
def check_action_on_agent(action, state, rail, position, direction):
"""
Parameters
----------
action : RailEnvActions
agent : EnvAgent
Returns
-------
bool
Is it a legal move?
1) transition allows the new_direction in the cell,
2) the new cell is not empty (case 0),
3) the cell is free, i.e., no agent is currently in that cell
"""
# compute number of possible transitions in the current
# cell used to check for invalid actions
new_direction, transition_valid = check_action(action, position, direction, rail)
new_position = get_new_position(position, new_direction)
new_cell_valid = (
fast_position_equal( # Check the new position is still in the grid
new_position,
fast_clip(new_position, [0, 0], [rail.height - 1, rail.width - 1]))
and # check the new position has some transitions (ie is not an empty cell)
rail.get_full_transitions(*new_position) > 0)
# If transition validity hasn't been checked yet.
if transition_valid is None:
transition_valid = rail.get_transition(
(*position, direction),
new_direction)
# only call cell_free() if new cell is inside the scene
if new_cell_valid: