Commit 0fa980ba authored by Dipam Chakraborty's avatar Dipam Chakraborty
Browse files

change arbirary dictionaries to dataclasses

parent 359633c4
......@@ -2,25 +2,22 @@
Definition of the RailEnv environment.
"""
import random
# TODO: _ this is a global method --> utils or remove later
from typing import List, NamedTuple, Optional, Dict, Tuple
import numpy as np
from numpy.lib.shape_base import vsplit
from numpy.testing._private.utils import import_nose
from typing import List, Optional, Dict, Tuple
import numpy as np
from gym.utils import seeding
from dataclasses import dataclass
from flatland.core.env import Environment
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.grid.grid4 import Grid4TransitionsEnum, Grid4Transitions
from flatland.core.grid.grid4 import Grid4Transitions
from flatland.core.grid.grid4_utils import get_new_position
from flatland.core.grid.grid_utils import IntVector2D, position_to_coordinate
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.agent_utils import Agent, EnvAgent
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.distance_map import DistanceMap
from flatland.envs.rail_env_action import RailEnvActions
# Need to use circular imports for persistence.
from flatland.envs import malfunction_generators as mal_gen
from flatland.envs import rail_generators as rail_gen
from flatland.envs import line_generators as line_gen
......@@ -29,31 +26,11 @@ from flatland.envs import persistence
from flatland.envs import agent_chains as ac
from flatland.envs.observations import GlobalObsForRailEnv
from gym.utils import seeding
# Direct import of objects / classes does not work with circular imports.
# from flatland.envs.malfunction_generators import no_malfunction_generator, Malfunction, MalfunctionProcessData
# from flatland.envs.observations import GlobalObsForRailEnv
# from flatland.envs.rail_generators import random_rail_generator, RailGenerator
# from flatland.envs.line_generators import random_line_generator, LineGenerator
from flatland.envs.timetable_generators import timetable_generator
from flatland.envs.step_utils.states import TrainState
from flatland.envs.step_utils.transition_utils import check_action
# Env Step Facelift imports
from flatland.envs.step_utils.action_preprocessing import preprocess_raw_action, preprocess_moving_action, preprocess_action_when_waiting
# Adrian Egli performance fix (the fast methods brings more than 50%)
def fast_isclose(a, b, rtol):
return (a < (b + rtol)) or (a < (b - rtol))
def fast_position_equal(pos_1: (int, int), pos_2: (int, int)) -> bool:
if pos_1 is None: # TODO: Dipam - Consider making default of agent.position as (-1, -1) instead of None
return False
else:
return pos_1[0] == pos_2[0] and pos_1[1] == pos_2[1]
from flatland.envs.step_utils.states import TrainState, StateTransitionSignals
from flatland.envs.step_utils import transition_utils
from flatland.envs.step_utils import action_preprocessing
class RailEnv(Environment):
"""
......@@ -406,22 +383,35 @@ class RailEnv(Environment):
def apply_action_independent(self, action, rail, position, direction):
if action.is_moving_action():
new_direction, _ = check_action(action, position, direction, rail)
new_direction, _ = transition_utils.check_action(action, position, direction, rail)
new_position = get_new_position(position, new_direction)
else:
new_position, new_direction = position, direction
return new_position, direction
def generate_state_transition_signals(self, agent, preprocessed_action, movement_allowed):
st_signals = {}
st_signals = StateTransitionSignals()
st_signals['malfunction_onset'] = agent.malfunction_handler.in_malfunction
st_signals['malfunction_counter_complete'] = agent.malfunction_handler.malfunction_counter_complete
st_signals['earliest_departure_reached'] = self._elapsed_steps >= agent.earliest_departure
st_signals['stop_action_given'] = (preprocessed_action == RailEnvActions.STOP_MOVING)
st_signals['valid_movement_action_given'] = preprocessed_action.is_moving_action()
st_signals['target_reached'] = fast_position_equal(agent.position, agent.target)
st_signals['movement_conflict'] = (not movement_allowed) and agent.speed_counter.is_cell_exit # TODO: Modify motion check to provide proper conflict information
# Malfunction onset - Malfunction starts
st_signals.malfunction_onset = agent.malfunction_handler.in_malfunction
# Malfunction counter complete - Malfunction ends next timestep
st_signals.malfunction_counter_complete = agent.malfunction_handler.malfunction_counter_complete
# Earliest departure reached - Train is allowed to move now
st_signals.earliest_departure_reached = self._elapsed_steps >= agent.earliest_departure
# Stop Action Given
st_signals.stop_action_given = (preprocessed_action == RailEnvActions.STOP_MOVING)
# Valid Movement action Given
st_signals.valid_movement_action_given = preprocessed_action.is_moving_action()
# Target Reached
st_signals.target_reached = fast_position_equal(agent.position, agent.target)
# Movement conflict - Multiple trains trying to move into same cell
st_signals.movement_conflict = (not movement_allowed) and agent.speed_counter.is_cell_exit # TODO: Modify motion check to provide proper conflict information
return st_signals
......@@ -489,7 +479,7 @@ class RailEnv(Environment):
self.motionCheck = ac.MotionCheck() # reset the motion check
temp_saved_data = {} # TODO : Change name
temp_transition_data = {}
for i_agent, agent in enumerate(self.agents): # TODO: Important - Do not use i_agent like this, use agent.handle if needed
# Generate malfunction
......@@ -500,15 +490,15 @@ class RailEnv(Environment):
# TODO: Add the bottom stuff to separate function(s)
# Preprocess action
action = preprocess_raw_action(action, agent.state)
action = preprocess_action_when_waiting(action, agent.state)
action = action_preprocessing.preprocess_raw_action(action, agent.state)
action = action_preprocessing.preprocess_action_when_waiting(action, agent.state)
# Try moving actions on current position
current_position, current_direction = agent.position, agent.direction
agent_not_on_map = current_position is None
if agent_not_on_map: # Agent not added on map yet
current_position, current_direction = agent.initial_position, agent.initial_direction
action = preprocess_moving_action(action, self.rail, current_position, current_direction)
action = action_preprocessing.preprocess_moving_action(action, self.rail, current_position, current_direction)
# Save moving actions in not already saved
agent.action_saver.save_action_if_allowed(action, agent.state)
......@@ -516,24 +506,25 @@ class RailEnv(Environment):
# Calculate new position
# Add agent to the map if not on it yet
if agent_not_on_map and agent.action_saver.is_action_saved:
temp_new_position = agent.initial_position
temp_new_direction = agent.initial_direction
new_position = agent.initial_position
new_direction = agent.initial_direction
preprocessed_action = action
# When cell exit occurs apply saved action independent of other agents
elif agent.speed_counter.is_cell_exit and agent.action_saver.is_action_saved:
saved_action = agent.action_saver.saved_action
# Apply action independent of other agents and get temporary new position and direction
temp_pd = self.apply_action_independent(saved_action, self.rail, agent.position, agent.direction)
temp_new_position, temp_new_direction = temp_pd
pd = self.apply_action_independent(saved_action, self.rail, agent.position, agent.direction)
new_position, new_direction = pd
preprocessed_action = saved_action
else:
temp_new_position, temp_new_direction = agent.position, agent.direction
new_position, new_direction = agent.position, agent.direction
preprocessed_action = action
# TODO: Saving temporary positon shouldn't be needed if recheck of position is not needed later (see TAG#1)
temp_saved_data[i_agent] = temp_new_position, temp_new_direction, preprocessed_action
self.motionCheck.addAgent(i_agent, agent.position, temp_new_position)
temp_transition_data[i_agent] = AgentTransitionData(position=new_position,
direction=new_direction,
preprocessed_action=preprocessed_action)
self.motionCheck.addAgent(i_agent, agent.position, new_position)
# Find conflicts
# TODO : Important - Modify conflicted positions and select one of them randomly to go to new position
......@@ -541,23 +532,19 @@ class RailEnv(Environment):
for agent in self.agents:
i_agent = agent.handle
agent_transition_data = temp_transition_data[i_agent]
## Update positions
movement_allowed, _ = self.motionCheck.check_motion(i_agent, agent.position) # TODO: Remove final_new_postion from motioncheck
# TODO : Important : Original code rechecks the next position here again - not sure why? TAG#1
preprocessed_action = temp_saved_data[i_agent][2] # TODO : Important : Make this namedtuple or class
# TODO : Looks like a hacky conditionm, improve the handling
if agent.malfunction_handler.in_malfunction:
movement_allowed = False
else:
movement_allowed, _ = self.motionCheck.check_motion(i_agent, agent.position) # TODO: Remove final_new_postion from motioncheck
if movement_allowed:
final_new_position, final_new_direction = temp_saved_data[i_agent][:2] # TODO : Important : Make this namedtuple or class
else:
final_new_position = agent.position
final_new_direction = agent.direction
agent.position = final_new_position
agent.direction = final_new_direction
agent.position = agent_transition_data.position
agent.direction = agent_transition_data.direction
preprocessed_action = agent_transition_data.preprocessed_action
## Update states
state_transition_signals = self.generate_state_transition_signals(agent, preprocessed_action, movement_allowed)
......@@ -565,8 +552,8 @@ class RailEnv(Environment):
agent.state_machine.step()
agent.state = agent.state_machine.state # TODO : Make this a property instead?
# TODO : Important : Looks like a hacky condiition, improve the handling
if agent.state == TrainState.DONE:
# Remove agent is required
if self.remove_agents_at_target and agent.state == TrainState.DONE:
agent.position = None
## Update rewards
......@@ -661,3 +648,21 @@ class RailEnv(Environment):
def save(self, filename):
print("deprecated call to env.save() - pls call RailEnvPersister.save()")
persistence.RailEnvPersister.save(self, filename)
@dataclass(repr=True)
class AgentTransitionData:
""" Class for keeping track of temporary agent data for position update """
position : Tuple[int, int]
direction : Grid4Transitions
preprocessed_action : RailEnvActions
# Adrian Egli performance fix (the fast methods brings more than 50%)
def fast_isclose(a, b, rtol):
return (a < (b + rtol)) or (a < (b - rtol))
def fast_position_equal(pos_1: (int, int), pos_2: (int, int)) -> bool:
if pos_1 is None: # TODO: Dipam - Consider making default of agent.position as (-1, -1) instead of None
return False
else:
return pos_1[0] == pos_2[0] and pos_1[1] == pos_2[1]
......@@ -20,7 +20,7 @@ class RailEnvActions(IntEnum):
}[a]
@classmethod
def check_valid_action(cls, action):
def is_action_valid(cls, action):
return action in cls._value2member_map_
def is_moving_action(self):
......
......@@ -5,7 +5,7 @@ from flatland.envs.step_utils.transition_utils import check_valid_action
def process_illegal_action(action: RailEnvActions):
if not RailEnvActions.check_valid_action(action):
if not RailEnvActions.is_action_valid(action):
return RailEnvActions.DO_NOTHING
else:
return RailEnvActions(action)
......
from attr import s
from flatland.envs.step_utils.states import TrainState
from flatland.envs.step_utils.states import TrainState, StateTransitionSignals
class TrainStateMachine:
def __init__(self, initial_state=TrainState.WAITING):
self._initial_state = initial_state
self._state = initial_state
self.st_signals = {} # State Transition Signals # TODO: Make this namedtuple
self.st_signals = StateTransitionSignals()
self.next_state = None
def _handle_waiting(self):
......@@ -13,25 +13,25 @@ class TrainStateMachine:
# TODO: Important - The malfunction handling is not like proper state machine
# Both transition signals can happen at the same time
# Atleast mention it in the diagram
if self.st_signals['malfunction_onset']:
if self.st_signals.malfunction_onset:
self.next_state = TrainState.MALFUNCTION_OFF_MAP
elif self.st_signals['earliest_departure_reached']:
elif self.st_signals.earliest_departure_reached:
self.next_state = TrainState.READY_TO_DEPART
else:
self.next_state = TrainState.WAITING
def _handle_ready_to_depart(self):
""" Can only go to MOVING if a valid action is provided """
if self.st_signals['malfunction_onset']:
if self.st_signals.malfunction_onset:
self.next_state = TrainState.MALFUNCTION_OFF_MAP
elif self.st_signals['valid_movement_action_given']:
elif self.st_signals.valid_movement_action_given:
self.next_state = TrainState.MOVING
else:
self.next_state = TrainState.READY_TO_DEPART
def _handle_malfunction_off_map(self):
if self.st_signals['malfunction_counter_complete']:
if self.st_signals['earliest_departure_reached']:
if self.st_signals.malfunction_counter_complete:
if self.st_signals.earliest_departure_reached:
self.next_state = TrainState.READY_TO_DEPART
else:
self.next_state = TrainState.STOPPED
......@@ -39,29 +39,29 @@ class TrainStateMachine:
self.next_state = TrainState.WAITING
def _handle_moving(self):
if self.st_signals['malfunction_onset']:
if self.st_signals.malfunction_onset:
self.next_state = TrainState.MALFUNCTION
elif self.st_signals['target_reached']:
elif self.st_signals.target_reached:
self.next_state = TrainState.DONE
elif self.st_signals['stop_action_given'] or self.st_signals['movement_conflict']:
elif self.st_signals.stop_action_given or self.st_signals.movement_conflict:
self.next_state = TrainState.STOPPED
else:
self.next_state = TrainState.MOVING
def _handle_stopped(self):
if self.st_signals['malfunction_onset']:
if self.st_signals.malfunction_onset:
self.next_state = TrainState.MALFUNCTION
elif self.st_signals['valid_movement_action_given']:
elif self.st_signals.valid_movement_action_given:
self.next_state = TrainState.MOVING
else:
self.next_state = TrainState.STOPPED
def _handle_malfunction(self):
if self.st_signals['malfunction_counter_complete'] and \
self.st_signals['valid_movement_action_given']:
if self.st_signals.malfunction_counter_complete and \
self.st_signals.valid_movement_action_given:
self.next_state = TrainState.MOVING
elif self.st_signals['malfunction_counter_complete'] and \
(self.st_signals['stop_action_given'] or self.st_signals['movement_conflict']):
elif self.st_signals.malfunction_counter_complete and \
(self.st_signals.stop_action_given or self.st_signals.movement_conflict):
self.next_state = TrainState.STOPPED
else:
self.next_state = TrainState.MALFUNCTION
......@@ -134,7 +134,7 @@ class TrainStateMachine:
return self.st_signals
def set_transition_signals(self, state_transition_signals):
self.st_signals = state_transition_signals # TODO: Important: Check all keys are present and if not raise error
self.st_signals = state_transition_signals
from enum import IntEnum
from dataclasses import dataclass
class TrainState(IntEnum):
WAITING = 0
READY_TO_DEPART = 1
......@@ -22,6 +22,13 @@ class TrainState(IntEnum):
def is_on_map_state(self):
return self.value in [self.MOVING, self.STOPPED, self.MALFUNCTION]
@dataclass(repr=True)
class StateTransitionSignals:
malfunction_onset : bool = False
malfunction_counter_complete : bool = False
earliest_departure_reached : bool = False
stop_action_given : bool = False
valid_movement_action_given : bool = False
target_reached : bool = False
movement_conflict : bool = False
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment