Commit 3f501522 authored by Dipam Chakraborty's avatar Dipam Chakraborty
Browse files

change RailAgentStatus to TrainState in observatons and predictions

parent 8c357c5b
from typing import Tuple, Optional, NamedTuple
from flatland.envs.rail_trainrun_data_structures import Waypoint
import numpy as np
from typing import Tuple, Optional, NamedTuple, List
from attr import attr, attrs, attrib, Factory
......@@ -124,6 +127,31 @@ class EnvAgent:
state_machine=self.state_machine,
malfunction_handler=self.malfunction_handler)
def get_shortest_path(self, distance_map) -> List[Waypoint]:
from flatland.envs.rail_env_shortest_paths import get_shortest_paths # Circular dep fix
return get_shortest_paths(distance_map=distance_map, agent_handle=self.handle)[self.handle]
def get_travel_time_on_shortest_path(self, distance_map) -> int:
shortest_path = self.get_shortest_path(distance_map)
if shortest_path is not None:
distance = len(shortest_path)
else:
distance = 0
speed = self.speed_data['speed']
return int(np.ceil(distance / speed))
def get_time_remaining_until_latest_arrival(self, elapsed_steps: int) -> int:
return self.latest_arrival - elapsed_steps
def get_current_delay(self, elapsed_steps: int, distance_map) -> int:
'''
+ve if arrival time is projected before latest arrival
-ve if arrival time is projected after latest arrival
'''
return self.get_time_remaining_until_latest_arrival(elapsed_steps) - \
self.get_travel_time_on_shortest_path(distance_map)
@classmethod
def from_line(cls, line: Line):
""" Create a list of EnvAgent from lists of positions, directions and targets
......
......@@ -196,14 +196,12 @@ class TreeObsForRailEnv(ObservationBuilder):
if handle > len(self.env.agents):
print("ERROR: obs _get - handle ", handle, " len(agents)", len(self.env.agents))
agent = self.env.agents[handle] # TODO: handle being treated as index
if agent.status == RailAgentStatus.WAITING:
agent_virtual_position = agent.initial_position
elif agent.status == RailAgentStatus.READY_TO_DEPART:
if agent.state.is_off_map_state():
agent_virtual_position = agent.initial_position
elif agent.status == RailAgentStatus.ACTIVE:
elif agent.state.is_on_map_state():
agent_virtual_position = agent.position
elif agent.status == RailAgentStatus.DONE:
elif agent.state == TrainState.DONE:
agent_virtual_position = agent.target
else:
return None
......@@ -343,7 +341,7 @@ class TreeObsForRailEnv(ObservationBuilder):
self._reverse_dir(
self.predicted_dir[predicted_time][ca])] == 1 and tot_dist < potential_conflict:
potential_conflict = tot_dist
if self.env.agents[ca].status == RailAgentStatus.DONE and tot_dist < potential_conflict:
if self.env.agents[ca].state == TrainState.DONE and tot_dist < potential_conflict:
potential_conflict = tot_dist
# Look for conflicting paths at distance num_step-1
......@@ -354,7 +352,7 @@ class TreeObsForRailEnv(ObservationBuilder):
and cell_transitions[self._reverse_dir(self.predicted_dir[pre_step][ca])] == 1 \
and tot_dist < potential_conflict: # noqa: E125
potential_conflict = tot_dist
if self.env.agents[ca].status == RailAgentStatus.DONE and tot_dist < potential_conflict:
if self.env.agents[ca].state == TrainState.DONE and tot_dist < potential_conflict:
potential_conflict = tot_dist
# Look for conflicting paths at distance num_step+1
......@@ -365,7 +363,7 @@ class TreeObsForRailEnv(ObservationBuilder):
self.predicted_dir[post_step][ca])] == 1 \
and tot_dist < potential_conflict: # noqa: E125
potential_conflict = tot_dist
if self.env.agents[ca].status == RailAgentStatus.DONE and tot_dist < potential_conflict:
if self.env.agents[ca].state == TrainState.DONE and tot_dist < potential_conflict:
potential_conflict = tot_dist
if position in self.location_has_target and position != agent.target:
......
......@@ -5,11 +5,11 @@ Collection of environment-specific PredictionBuilder.
import numpy as np
from flatland.core.env_prediction_builder import PredictionBuilder
from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.distance_map import DistanceMap
from flatland.envs.rail_env_action import RailEnvActions
from flatland.envs.rail_env_shortest_paths import get_shortest_paths
from flatland.utils.ordered_set import OrderedSet
from flatland.envs.step_utils.states import TrainState
class DummyPredictorForRailEnv(PredictionBuilder):
......@@ -48,7 +48,7 @@ class DummyPredictorForRailEnv(PredictionBuilder):
prediction_dict = {}
for agent in agents:
if agent.status != RailAgentStatus.ACTIVE:
if not agent.state.is_on_map_state():
# TODO make this generic
continue
action_priorities = [RailEnvActions.MOVE_FORWARD, RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT]
......@@ -126,13 +126,11 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
prediction_dict = {}
for agent in agents:
if agent.status == RailAgentStatus.WAITING:
if agent.state.is_off_map_state():
agent_virtual_position = agent.initial_position
elif agent.status == RailAgentStatus.READY_TO_DEPART:
agent_virtual_position = agent.initial_position
elif agent.status == RailAgentStatus.ACTIVE:
elif agent.state.is_on_map_state():
agent_virtual_position = agent.position
elif agent.status == RailAgentStatus.DONE:
elif agent.state == TrainState.DONE:
agent_virtual_position = agent.target
else:
......
......@@ -527,11 +527,10 @@ class RailEnv(Environment):
direction=new_direction,
preprocessed_action=preprocessed_action)
# This is for checking conflicts of agents trying to occupy same cell
# This is for storing and later checking for conflicts of agents trying to occupy same cell
self.motionCheck.addAgent(i_agent, agent.position, new_position)
# Find conflicts
# Find conflicts between trains trying to occupy same cell
self.motionCheck.find_conflicts()
for agent in self.agents:
......
from enum import IntEnum
from dataclasses import dataclass
class TrainState(IntEnum):
WAITING = 0
READY_TO_DEPART = 1
......@@ -22,6 +24,7 @@ class TrainState(IntEnum):
def is_on_map_state(self):
return self.value in [self.MOVING, self.STOPPED, self.MALFUNCTION]
@dataclass(repr=True)
class StateTransitionSignals:
malfunction_onset : bool = False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment